mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add tau2 benchmark integration with comprehensive testing and documentation (#817)
* first commit to tau2-bench * tau2-bench agent * tau2 agent * add condition * checkpoint * bug fix * add tests * fix tests * add comments * add comments * minor fix * fix * batch test script * . * init.bak -> init.py * fix mypy * update readme * fix env * remove temp files * setup tests * fix gaia tasks * fix tau2 tests * fix coverage * fix default version * update cookiecutter template --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
52790b9f6a
commit
205cd700c8
+3
-2
@@ -36,6 +36,7 @@ requires = ["setuptools>=64", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools]
|
||||
package-dir = {"" = "src"}
|
||||
packages = ["agent_framework_lab_{{ cookiecutter.package_name }}", "agent_framework.lab.{{ cookiecutter.package_name }}"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
@@ -66,7 +67,7 @@ warn_return_any = true
|
||||
warn_unreachable = true
|
||||
show_error_codes = true
|
||||
implicit_reexport = true
|
||||
packages = ["agent_framework_lab_{{ cookiecutter.package_name }}"]
|
||||
packages = ["src.agent_framework_lab_{{ cookiecutter.package_name }}"]
|
||||
|
||||
{% if cookiecutter.within_microsoft_agent_framework_repo == "y" %}
|
||||
[tool.poe]
|
||||
@@ -86,7 +87,7 @@ mypy = "mypy agent_framework_lab_{{ cookiecutter.package_name }}"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
pythonpath = ["src"]
|
||||
addopts = "--strict-markers --strict-config"
|
||||
markers = [
|
||||
"unit: marks tests as unit tests",
|
||||
|
||||
+7
-2
@@ -4,13 +4,18 @@
|
||||
{{ cookiecutter.package_description }}
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
# Import your main exports here
|
||||
# from .main_module import MainClass, main_function
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
__version__ = "0.0.0" # Fallback for development mode
|
||||
|
||||
__all__ = [
|
||||
# List your exports here
|
||||
# "MainClass",
|
||||
# "main_function",
|
||||
]
|
||||
|
||||
__version__ = "{{ cookiecutter.version }}"
|
||||
@@ -4,9 +4,16 @@
|
||||
GAIA benchmark module for Agent Framework.
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from ._types import Evaluation, Evaluator, Prediction, Task, TaskResult, TaskRunner
|
||||
from .gaia import GAIA, GAIATelemetryConfig, gaia_scorer, viewer_main
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
__version__ = "0.0.0" # Fallback for development mode
|
||||
|
||||
__all__ = [
|
||||
"GAIA",
|
||||
"GAIATelemetryConfig",
|
||||
@@ -19,5 +26,3 @@ __all__ = [
|
||||
"TaskRunner",
|
||||
"Evaluator",
|
||||
]
|
||||
|
||||
__version__ = "0.1.0b1"
|
||||
|
||||
@@ -72,7 +72,9 @@ packages = ["agent_framework_lab_gaia"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
test = "pytest --cov=agent_framework_lab_gaia --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
||||
@@ -4,6 +4,13 @@
|
||||
RL Module for Microsoft Agent Framework
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
__version__ = "0.0.0" # Fallback for development mode
|
||||
|
||||
# Import your main exports here
|
||||
# from .main_module import MainClass, main_function
|
||||
|
||||
@@ -12,5 +19,3 @@ __all__ = [
|
||||
# "MainClass",
|
||||
# "main_function",
|
||||
]
|
||||
|
||||
__version__ = "0.1.0b1"
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,190 @@
|
||||
# Agent Framework Lab - τ²-bench
|
||||
|
||||
τ²-bench implements a simulation framework for evaluating customer service agents across various domains.
|
||||
|
||||
The framework orchestrates conversations between two AI agents:
|
||||
- **Customer Service Agent**: Follows domain-specific policies and has access to tools (e.g., booking systems, databases)
|
||||
- **User Simulator**: Simulates realistic customer behavior with specific goals and scenarios
|
||||
|
||||
Each evaluation runs a multi-turn conversation where the user simulator presents a customer service scenario, and the agent must resolve it following the domain policy while using available tools appropriately. The results are evaluated using τ²'s comprehensive evaluation system.
|
||||
|
||||
## Supported Domains
|
||||
|
||||
| Domain | Status | Description |
|
||||
|--------|--------|-------------|
|
||||
| **airline** | ✅ Supported | Customer service for airline booking, changes, and support |
|
||||
| **retail** | 🚧 In Development | E-commerce customer support scenarios |
|
||||
| **telecom** | 🚧 In Development | Telecommunications service support |
|
||||
|
||||
*Note: Currently only the airline domain is fully supported.*
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install agent-framework-lab-tau2
|
||||
```
|
||||
|
||||
Download data from [Tau2-Bench](https://github.com/sierra-research/tau2-bench):
|
||||
|
||||
```bash
|
||||
git clone https://github.com/sierra-research/tau2-bench.git
|
||||
mv tau2-bench/data/ .
|
||||
rm -rf tau2-bench
|
||||
```
|
||||
|
||||
Export the data directory to `TAU2_DATA_DIR` environment variable:
|
||||
|
||||
```bash
|
||||
export TAU2_DATA_DIR="data"
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Running a Single Task
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from agent_framework_lab_tau2 import TaskRunner
|
||||
from tau2.domains.airline.environment import get_tasks
|
||||
|
||||
async def run_single_task():
|
||||
# Initialize the task runner
|
||||
runner = TaskRunner(max_steps=50)
|
||||
|
||||
# Set up your LLM clients
|
||||
assistant_client = OpenAIChatClient(
|
||||
base_url="https://api.openai.com/v1",
|
||||
api_key="your-api-key",
|
||||
ai_model_id="gpt-4o"
|
||||
)
|
||||
user_client = OpenAIChatClient(
|
||||
base_url="https://api.openai.com/v1",
|
||||
api_key="your-api-key",
|
||||
ai_model_id="gpt-4o-mini"
|
||||
)
|
||||
|
||||
# Get a task and run it
|
||||
tasks = get_tasks()
|
||||
task = tasks[0] # Run the first task
|
||||
|
||||
conversation = await runner.run(task, assistant_client, user_client)
|
||||
reward = runner.evaluate(task, conversation, runner.termination_reason)
|
||||
|
||||
print(f"Task completed with reward: {reward}")
|
||||
|
||||
# Run the example
|
||||
asyncio.run(run_single_task())
|
||||
```
|
||||
|
||||
### Running the Full Benchmark
|
||||
|
||||
Use the provided script to run the complete benchmark:
|
||||
|
||||
```bash
|
||||
# Run with default models (gpt-4.1 for both agent and user)
|
||||
python samples/run_benchmark.py
|
||||
|
||||
# Use custom models
|
||||
python samples/run_benchmark.py --assistant gpt-4o --user gpt-4o-mini
|
||||
|
||||
# Debug a specific task
|
||||
python samples/run_benchmark.py --debug-task-id task_001 --assistant gpt-4o
|
||||
|
||||
# Limit conversation length
|
||||
python samples/run_benchmark.py --max-steps 20
|
||||
```
|
||||
|
||||
## Results (on Airline Domain)
|
||||
|
||||
The following results are reproduced from our implementation of τ²-bench with `samples/run_benchmark.py`. It shows the average success rate over the dataset of 50 tasks.
|
||||
|
||||
| Agent Model | User Model | Success Rate |
|
||||
|-------------|------------|----------|
|
||||
| gpt-5 | gpt-4.1 | 62.0% |
|
||||
| gpt-5-mini | gpt-4.1 | 52.0% |
|
||||
| gpt-4.1 | gpt-4.1 | 60.0% |
|
||||
| gpt-4.1-mini | gpt-4.1 | 50.0% |
|
||||
| gpt-4.1 | gpt-4o-mini | 42.0% |
|
||||
| gpt-4o | gpt-4.1 | 42.0% |
|
||||
| gpt-4o-mini | gpt-4.1 | 26.0% |
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
Set required environment variables:
|
||||
|
||||
```bash
|
||||
export OPENAI_BASE_URL="https://api.openai.com/v1"
|
||||
export OPENAI_API_KEY="your-api-key"
|
||||
|
||||
# Optional: for custom endpoints
|
||||
export OPENAI_BASE_URL="https://your-custom-endpoint.com/v1"
|
||||
```
|
||||
|
||||
### Custom Agent Implementation
|
||||
|
||||
```python
|
||||
from agent_framework_lab_tau2 import TaskRunner
|
||||
from agent_framework import ChatAgent
|
||||
|
||||
class CustomTaskRunner(TaskRunner):
|
||||
def assistant_agent(self, assistant_chat_client):
|
||||
# Override to customize the assistant agent
|
||||
return ChatAgent(
|
||||
chat_client=assistant_chat_client,
|
||||
instructions="Your custom system prompt here",
|
||||
# Add custom tools, temperature, etc.
|
||||
)
|
||||
|
||||
def user_simulator(self, user_chat_client, task):
|
||||
# Override to customize the user simulator
|
||||
return ChatAgent(
|
||||
chat_client=user_chat_client,
|
||||
instructions="Custom user simulator prompt",
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Workflow Integration
|
||||
|
||||
```python
|
||||
from agent_framework._workflow import WorkflowBuilder, AgentExecutor
|
||||
from agent_framework_lab_tau2 import TaskRunner
|
||||
|
||||
class WorkflowTaskRunner(TaskRunner):
|
||||
def build_conversation_workflow(self, assistant_agent, user_simulator_agent):
|
||||
# Build a custom workflow
|
||||
builder = WorkflowBuilder()
|
||||
|
||||
# Create agent executors
|
||||
assistant_executor = AgentExecutor(assistant_agent, id="assistant_agent")
|
||||
user_executor = AgentExecutor(user_simulator_agent, id="user_simulator")
|
||||
|
||||
# Add workflow edges and conditions
|
||||
builder.set_start_executor(assistant_executor)
|
||||
builder.add_edge(assistant_executor, user_executor)
|
||||
builder.add_edge(user_executor, assistant_executor, condition=self.should_not_stop)
|
||||
|
||||
return builder.build()
|
||||
```
|
||||
|
||||
### Utility Functions
|
||||
|
||||
```python
|
||||
from agent_framework_lab_tau2 import patch_env_set_state, unpatch_env_set_state
|
||||
|
||||
# Enable compatibility patches for τ²-bench integration
|
||||
patch_env_set_state()
|
||||
|
||||
# Disable patches when done
|
||||
unpatch_env_set_state()
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
This package is part of the Microsoft Agent Framework Lab. Please see the main repository for contribution guidelines.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT License - see the LICENSE file for details.
|
||||
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""
|
||||
Tau2 Benchmark for Agent Framework.
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from ._tau2_utils import patch_env_set_state, unpatch_env_set_state
|
||||
from .runner import TaskRunner
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
__version__ = "0.0.0" # Fallback for development mode
|
||||
|
||||
__all__ = [
|
||||
"TaskRunner",
|
||||
"patch_env_set_state",
|
||||
"unpatch_env_set_state",
|
||||
]
|
||||
@@ -0,0 +1,112 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from agent_framework._types import ChatMessage, Contents, Role
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def flip_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
|
||||
"""Flip message roles between assistant and user for role-playing scenarios.
|
||||
|
||||
Used in agent simulations where the assistant's messages become user inputs
|
||||
and vice versa. Function calls are filtered out when flipping assistant
|
||||
messages to user messages (since users typically don't make function calls).
|
||||
"""
|
||||
|
||||
def filter_out_function_calls(messages: list[Contents]) -> list[Contents]:
|
||||
"""Remove function call content from message contents."""
|
||||
return [content for content in messages if content.type != "function_call"]
|
||||
|
||||
flipped_messages = []
|
||||
for msg in messages:
|
||||
if msg.role == Role.ASSISTANT:
|
||||
# Flip assistant to user
|
||||
contents = filter_out_function_calls(msg.contents)
|
||||
if contents:
|
||||
flipped_msg = ChatMessage(
|
||||
role=Role.USER,
|
||||
# The function calls will cause 400 when role is user
|
||||
contents=contents,
|
||||
author_name=msg.author_name,
|
||||
message_id=msg.message_id,
|
||||
)
|
||||
flipped_messages.append(flipped_msg)
|
||||
elif msg.role == Role.USER:
|
||||
# Flip user to assistant
|
||||
flipped_msg = ChatMessage(
|
||||
role=Role.ASSISTANT, contents=msg.contents, author_name=msg.author_name, message_id=msg.message_id
|
||||
)
|
||||
flipped_messages.append(flipped_msg)
|
||||
elif msg.role == Role.TOOL:
|
||||
# Skip tool messages
|
||||
pass
|
||||
else:
|
||||
# Keep other roles as-is (system, tool, etc.)
|
||||
flipped_messages.append(msg)
|
||||
return flipped_messages
|
||||
|
||||
|
||||
def log_messages(messages: list[ChatMessage]) -> None:
|
||||
"""Log messages with colored output based on role and content type.
|
||||
|
||||
Provides visual debugging by color-coding different message roles and
|
||||
content types. Escapes HTML-like characters to prevent log formatting issues.
|
||||
"""
|
||||
_logger = logger.opt(colors=True)
|
||||
for msg in messages:
|
||||
# Handle different content types
|
||||
if hasattr(msg, "contents") and msg.contents:
|
||||
for content in msg.contents:
|
||||
if hasattr(content, "type"):
|
||||
if content.type == "text":
|
||||
escape_text = content.text.replace("<", r"\<")
|
||||
if msg.role == Role.SYSTEM:
|
||||
_logger.info(f"<cyan>[SYSTEM]</cyan> {escape_text}")
|
||||
elif msg.role == Role.USER:
|
||||
_logger.info(f"<green>[USER]</green> {escape_text}")
|
||||
elif msg.role == Role.ASSISTANT:
|
||||
_logger.info(f"<blue>[ASSISTANT]</blue> {escape_text}")
|
||||
elif msg.role == Role.TOOL:
|
||||
_logger.info(f"<yellow>[TOOL]</yellow> {escape_text}")
|
||||
else:
|
||||
_logger.info(f"<magenta>[{msg.role.value.upper()}]</magenta> {escape_text}")
|
||||
elif content.type == "function_call":
|
||||
function_call_text = f"{content.name}({content.arguments})"
|
||||
function_call_text = function_call_text.replace("<", r"\<")
|
||||
_logger.info(f"<yellow>[TOOL_CALL]</yellow> 🔧 {function_call_text}")
|
||||
elif content.type == "function_result":
|
||||
function_result_text = f"ID:{content.call_id} -> {content.result}"
|
||||
function_result_text = function_result_text.replace("<", r"\<")
|
||||
_logger.info(f"<yellow>[TOOL_RESULT]</yellow> 🔨 {function_result_text}")
|
||||
else:
|
||||
content_text = str(content).replace("<", r"\<")
|
||||
_logger.info(f"<magenta>[{msg.role.value.upper()}] ({content.type})</magenta> {content_text}")
|
||||
else:
|
||||
# Fallback for content without type
|
||||
text_content = str(content).replace("<", r"\<")
|
||||
if msg.role == Role.SYSTEM:
|
||||
_logger.info(f"<cyan>[SYSTEM]</cyan> {text_content}")
|
||||
elif msg.role == Role.USER:
|
||||
_logger.info(f"<green>[USER]</green> {text_content}")
|
||||
elif msg.role == Role.ASSISTANT:
|
||||
_logger.info(f"<blue>[ASSISTANT]</blue> {text_content}")
|
||||
elif msg.role == Role.TOOL:
|
||||
_logger.info(f"<yellow>[TOOL]</yellow> {text_content}")
|
||||
else:
|
||||
_logger.info(f"<magenta>[{msg.role.value.upper()}]</magenta> {text_content}")
|
||||
elif hasattr(msg, "text") and msg.text:
|
||||
# Handle simple text messages
|
||||
text_content = msg.text.replace("<", r"\<")
|
||||
if msg.role == Role.SYSTEM:
|
||||
_logger.info(f"<cyan>[SYSTEM]</cyan> {text_content}")
|
||||
elif msg.role == Role.USER:
|
||||
_logger.info(f"<green>[USER]</green> {text_content}")
|
||||
elif msg.role == Role.ASSISTANT:
|
||||
_logger.info(f"<blue>[ASSISTANT]</blue> {text_content}")
|
||||
elif msg.role == Role.TOOL:
|
||||
_logger.info(f"<yellow>[TOOL]</yellow> {text_content}")
|
||||
else:
|
||||
_logger.info(f"<magenta>[{msg.role.value.upper()}]</magenta> {text_content}")
|
||||
else:
|
||||
# Fallback for other message formats
|
||||
text_content = str(msg).replace("<", r"\<")
|
||||
_logger.info(f"<magenta>[{msg.role.value.upper()}]</magenta> {text_content}")
|
||||
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
from agent_framework._threads import ChatMessageList
|
||||
from agent_framework._types import ChatMessage, Role
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class SlidingWindowChatMessageList(ChatMessageList):
|
||||
"""A token-aware sliding window implementation of ChatMessageList.
|
||||
|
||||
Maintains two message lists: complete history and truncated window.
|
||||
Automatically removes oldest messages when token limit is exceeded.
|
||||
Also removes leading tool messages to ensure valid conversation flow.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Sequence[ChatMessage] | None = None,
|
||||
max_tokens: int = 3800,
|
||||
system_message: str | None = None,
|
||||
tool_definitions: Any | None = None,
|
||||
):
|
||||
super().__init__(messages)
|
||||
self._truncated_messages = self._messages.copy() # Separate truncated view
|
||||
self.max_tokens = max_tokens
|
||||
self.system_message = system_message # Included in token count
|
||||
self.tool_definitions = tool_definitions
|
||||
# An estimation based on a commonly used vocab table
|
||||
self.encoding = tiktoken.get_encoding("o200k_base")
|
||||
|
||||
async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
|
||||
await super().add_messages(messages)
|
||||
|
||||
self._truncated_messages = self._messages.copy()
|
||||
self.truncate_messages()
|
||||
|
||||
async def list_messages(self) -> list[ChatMessage]:
|
||||
"""Get the current list of messages, which may be truncated."""
|
||||
return self._truncated_messages
|
||||
|
||||
async def list_all_messages(self) -> list[ChatMessage]:
|
||||
"""Get all messages from the store including the truncated ones."""
|
||||
return self._messages
|
||||
|
||||
def truncate_messages(self) -> None:
|
||||
while len(self._truncated_messages) > 0 and self.get_token_count() > self.max_tokens:
|
||||
logger.warning("Messages exceed max tokens. Truncating oldest message.")
|
||||
self._truncated_messages.pop(0)
|
||||
# Remove leading tool messages
|
||||
while len(self._truncated_messages) > 0 and self._truncated_messages[0].role == Role.TOOL:
|
||||
logger.warning("Removing leading tool message because tool result cannot be the first message.")
|
||||
self._truncated_messages.pop(0)
|
||||
|
||||
def get_token_count(self) -> int:
|
||||
"""Estimate token count for a list of messages using tiktoken.
|
||||
Args:
|
||||
messages: List of ChatMessage objects
|
||||
system_message: Optional system message to include in count
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
total_tokens = 0
|
||||
|
||||
# Add system message tokens if provided
|
||||
if self.system_message:
|
||||
total_tokens += len(self.encoding.encode(self.system_message))
|
||||
total_tokens += 4 # Extra tokens for system message formatting
|
||||
|
||||
for msg in self._truncated_messages:
|
||||
# Add 4 tokens per message for role, formatting, etc.
|
||||
total_tokens += 4
|
||||
|
||||
# Handle different content types
|
||||
if hasattr(msg, "contents") and msg.contents:
|
||||
for content in msg.contents:
|
||||
if hasattr(content, "type"):
|
||||
if content.type == "text":
|
||||
total_tokens += len(self.encoding.encode(content.text))
|
||||
elif content.type == "function_call":
|
||||
total_tokens += 4
|
||||
# Serialize function call and count tokens
|
||||
func_call_data = {
|
||||
"name": content.name,
|
||||
"arguments": content.arguments,
|
||||
}
|
||||
total_tokens += self.estimate_any_object_token_count(func_call_data)
|
||||
elif content.type == "function_result":
|
||||
total_tokens += 4
|
||||
# Serialize function result and count tokens
|
||||
func_result_data = {
|
||||
"call_id": content.call_id,
|
||||
"result": content.result,
|
||||
}
|
||||
total_tokens += self.estimate_any_object_token_count(func_result_data)
|
||||
else:
|
||||
# For other content types, serialize the whole content
|
||||
total_tokens += self.estimate_any_object_token_count(content)
|
||||
else:
|
||||
# Content without type, treat as text
|
||||
total_tokens += self.estimate_any_object_token_count(content)
|
||||
elif hasattr(msg, "text") and msg.text:
|
||||
# Simple text message
|
||||
total_tokens += self.estimate_any_object_token_count(msg.text)
|
||||
else:
|
||||
# Skip it
|
||||
pass
|
||||
|
||||
if total_tokens > self.max_tokens / 2:
|
||||
logger.opt(colors=True).warning(
|
||||
f"<yellow>Total tokens {total_tokens} is "
|
||||
f"{total_tokens / self.max_tokens * 100:.0f}% "
|
||||
f"of max tokens {self.max_tokens}</yellow>"
|
||||
)
|
||||
elif total_tokens > self.max_tokens:
|
||||
logger.opt(colors=True).warning(
|
||||
f"<red>Total tokens {total_tokens} is over max tokens {self.max_tokens}. Will truncate messages.</red>"
|
||||
)
|
||||
|
||||
return total_tokens
|
||||
|
||||
def estimate_any_object_token_count(self, obj: Any) -> int:
|
||||
try:
|
||||
serialized = json.dumps(obj)
|
||||
except Exception:
|
||||
serialized = str(obj)
|
||||
return len(self.encoding.encode(serialized))
|
||||
@@ -0,0 +1,244 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from agent_framework._tools import AIFunction
|
||||
from agent_framework._types import ChatMessage
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from tau2.data_model.message import ( # type: ignore[import-untyped]
|
||||
AssistantMessage,
|
||||
Message,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from tau2.data_model.tasks import EnvFunctionCall, InitializationData # type: ignore[import-untyped]
|
||||
from tau2.environment.environment import Environment # type: ignore[import-untyped]
|
||||
from tau2.environment.tool import Tool # type: ignore[import-untyped]
|
||||
|
||||
_original_set_state = Environment.set_state
|
||||
|
||||
|
||||
def convert_tau2_tool_to_ai_function(tau2_tool: Tool) -> AIFunction[Any, Any]:
|
||||
"""Convert a tau2 Tool to an AIFunction for agent framework compatibility.
|
||||
|
||||
Creates a wrapper that preserves the tool's interface while ensuring
|
||||
results are deep-copied to prevent unintended mutations.
|
||||
"""
|
||||
|
||||
def wrapped_func(**kwargs: Any) -> Any:
|
||||
result = tau2_tool(**kwargs)
|
||||
# Deep copy to prevent mutations of returned data
|
||||
if isinstance(result, BaseModel):
|
||||
result = result.model_copy(deep=True)
|
||||
else:
|
||||
result = deepcopy(result)
|
||||
return result
|
||||
|
||||
return AIFunction(
|
||||
name=tau2_tool.name,
|
||||
description=tau2_tool._get_description(),
|
||||
func=wrapped_func,
|
||||
input_model=tau2_tool.params,
|
||||
)
|
||||
|
||||
|
||||
def convert_agent_framework_messages_to_tau2_messages(messages: list[ChatMessage]) -> list[Message]:
|
||||
"""Convert agent framework ChatMessages to tau2 Message objects.
|
||||
|
||||
Handles role mapping, text extraction, function calls, and function results.
|
||||
Function results are converted to separate ToolMessage instances.
|
||||
"""
|
||||
|
||||
tau2_messages = []
|
||||
|
||||
for msg in messages:
|
||||
role_str = str(msg.role)
|
||||
|
||||
# Extract text content from all text-type contents
|
||||
text_content = None
|
||||
text_contents = [c for c in msg.contents if hasattr(c, "text") and hasattr(c, "type") and c.type == "text"]
|
||||
if text_contents:
|
||||
text_content = " ".join(c.text for c in text_contents)
|
||||
|
||||
# Extract function calls and convert to ToolCall objects
|
||||
function_calls = [c for c in msg.contents if hasattr(c, "type") and c.type == "function_call"]
|
||||
tool_calls = None
|
||||
if function_calls:
|
||||
tool_calls = []
|
||||
for fc in function_calls:
|
||||
arguments = fc.parse_arguments() or {}
|
||||
tool_call = ToolCall(
|
||||
id=fc.call_id,
|
||||
name=fc.name,
|
||||
arguments=arguments,
|
||||
requestor="assistant" if role_str == "assistant" else "user",
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
# Extract function results for separate ToolMessage creation
|
||||
function_results = [c for c in msg.contents if hasattr(c, "type") and c.type == "function_result"]
|
||||
|
||||
# Create main message based on role
|
||||
if role_str == "system":
|
||||
tau2_messages.append(SystemMessage(role="system", content=text_content))
|
||||
elif role_str == "user":
|
||||
tau2_messages.append(UserMessage(role="user", content=text_content, tool_calls=tool_calls))
|
||||
elif role_str == "assistant":
|
||||
tau2_messages.append(AssistantMessage(role="assistant", content=text_content, tool_calls=tool_calls))
|
||||
elif role_str == "tool":
|
||||
# Tool messages are handled as function results below
|
||||
pass
|
||||
|
||||
# Convert function results to separate ToolMessage instances
|
||||
for fr in function_results:
|
||||
dumpable_content = _dump_function_result(fr.result)
|
||||
content = dumpable_content if isinstance(dumpable_content, str) else json.dumps(dumpable_content)
|
||||
tool_msg = ToolMessage(
|
||||
id=fr.call_id,
|
||||
role="tool",
|
||||
content=content,
|
||||
requestor="assistant", # Most tool calls originate from assistant
|
||||
error=fr.exception is not None,
|
||||
)
|
||||
tau2_messages.append(tool_msg)
|
||||
|
||||
return tau2_messages
|
||||
|
||||
|
||||
def patch_env_set_state() -> None:
|
||||
"""Patch Environment.set_state to allow inconsistent tool call results.
|
||||
|
||||
Modifies the original method to log warnings instead of raising errors
|
||||
when actual tool results differ from expected results, enabling more
|
||||
flexible testing and development workflows.
|
||||
"""
|
||||
|
||||
def set_state(
|
||||
self: Any,
|
||||
initialization_data: InitializationData | None,
|
||||
initialization_actions: list[EnvFunctionCall] | None,
|
||||
message_history: list[Message],
|
||||
) -> None:
|
||||
if self.solo_mode:
|
||||
if any(isinstance(message, UserMessage) for message in message_history):
|
||||
raise ValueError("User messages are not allowed in solo mode")
|
||||
|
||||
def get_actions_from_messages(
|
||||
messages: list[Message],
|
||||
) -> list[tuple[ToolCall, ToolMessage]]:
|
||||
"""
|
||||
Get the actions from the messages.
|
||||
"""
|
||||
messages = deepcopy(messages)[::-1]
|
||||
actions = []
|
||||
while messages:
|
||||
message = messages.pop()
|
||||
if isinstance(message, ToolMessage):
|
||||
raise ValueError("Tool message not expected. Tool messages should always follow a tool call.")
|
||||
if isinstance(message, (AssistantMessage, UserMessage)) and message.is_tool_call():
|
||||
tool_calls = message.tool_calls
|
||||
if tool_calls is None:
|
||||
raise ValueError("Tool message expected. Got None.")
|
||||
for tc in tool_calls:
|
||||
if len(messages) == 0:
|
||||
raise ValueError("Tool message expected. Got None.")
|
||||
tm = messages.pop()
|
||||
if not isinstance(tm, ToolMessage):
|
||||
raise ValueError(f"Tool message expected. Got {type(tm)}")
|
||||
if tc.id != tm.id:
|
||||
raise ValueError(f"Tool call id mismatch. Got {tc.id} and {tm.id}")
|
||||
actions.append((tc, tm))
|
||||
|
||||
return actions
|
||||
|
||||
if initialization_data is not None:
|
||||
if initialization_data.agent_data is not None:
|
||||
self.tools.update_db(initialization_data.agent_data)
|
||||
if initialization_data.user_data is not None:
|
||||
self.user_tools.update_db(initialization_data.user_data)
|
||||
|
||||
if initialization_actions is not None:
|
||||
for action in initialization_actions:
|
||||
self.run_env_function_call(action)
|
||||
|
||||
action_responses = get_actions_from_messages(message_history)
|
||||
for tool_call, expected_response in action_responses:
|
||||
response = self.get_response(tool_call)
|
||||
content = _recursive_json_deserialize(response.content)
|
||||
expected_content = _recursive_json_deserialize(expected_response.content)
|
||||
if content != expected_content:
|
||||
diff = f"Tool call:\n{tool_call}\n\nReturned:\n{response}\n\nExpected:\n{expected_response}"
|
||||
if isinstance(content, str) and content.startswith("Error:"):
|
||||
# If the tool call resulted in an error, the difference can be ignored
|
||||
logger.warning(f"Tool call resulted in an error. Ignoring the difference.\n{diff}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Tool call:\n{tool_call}\n\nReturned:\n{response}\n\nExpected:\n{expected_response}"
|
||||
)
|
||||
self.sync_tools()
|
||||
|
||||
Environment.set_state = set_state
|
||||
|
||||
|
||||
def unpatch_env_set_state() -> None:
|
||||
Environment.set_state = _original_set_state
|
||||
|
||||
|
||||
def _dump_function_result(result: Any) -> Any:
|
||||
if isinstance(result, BaseModel):
|
||||
return result.model_dump_json()
|
||||
elif isinstance(result, list):
|
||||
return [_dump_function_result(item) for item in result]
|
||||
elif isinstance(result, dict):
|
||||
return {k: _dump_function_result(v) for k, v in result.items()}
|
||||
elif result is None:
|
||||
return None
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
def _to_native(obj: Any) -> Any:
|
||||
"""Convert data retrieved from Panquet to data usable in AGL server."""
|
||||
# 1) Arrays -> list (then recurse)
|
||||
if isinstance(obj, np.ndarray):
|
||||
return _to_native(obj.tolist())
|
||||
|
||||
# 2) NumPy scalar types -> Python scalars
|
||||
if isinstance(obj, np.generic):
|
||||
return _to_native(obj.item())
|
||||
|
||||
# 3) Dict-like -> dict
|
||||
if isinstance(obj, Mapping):
|
||||
return {_to_native(k): _to_native(v) for k, v in obj.items()}
|
||||
|
||||
# 4) Lists/Tuples/Sets -> list
|
||||
if isinstance(obj, (list, tuple, set)):
|
||||
return [_to_native(x) for x in obj]
|
||||
|
||||
# 5) Anything else: leave as-is
|
||||
return obj
|
||||
|
||||
|
||||
def _recursive_json_deserialize(obj: Any) -> Any:
|
||||
"""
|
||||
Recursively deserialize a JSON object.
|
||||
"""
|
||||
if isinstance(obj, str):
|
||||
try:
|
||||
deserialized = json.loads(obj)
|
||||
return _recursive_json_deserialize(deserialized)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return obj
|
||||
elif isinstance(obj, list):
|
||||
return [_recursive_json_deserialize(item) for item in obj]
|
||||
elif isinstance(obj, dict):
|
||||
return {k: _recursive_json_deserialize(v) for k, v in obj.items()}
|
||||
else:
|
||||
return obj
|
||||
@@ -0,0 +1,424 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
from agent_framework._agents import ChatAgent
|
||||
from agent_framework._types import AgentRunResponse, ChatMessage, Role
|
||||
from agent_framework._workflow import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
FunctionExecutor,
|
||||
Workflow,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
)
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from loguru import logger
|
||||
from tau2.data_model.simulation import SimulationRun, TerminationReason # type: ignore[import-untyped]
|
||||
from tau2.data_model.tasks import Task # type: ignore[import-untyped]
|
||||
from tau2.domains.airline.environment import get_environment # type: ignore[import-untyped]
|
||||
from tau2.evaluator.evaluator import EvaluationType, RewardInfo, evaluate_simulation # type: ignore[import-untyped]
|
||||
from tau2.user.user_simulator import ( # type: ignore[import-untyped]
|
||||
OUT_OF_SCOPE,
|
||||
STOP,
|
||||
TRANSFER,
|
||||
get_global_user_sim_guidelines,
|
||||
)
|
||||
from tau2.utils.utils import get_now # type: ignore[import-untyped]
|
||||
|
||||
from ._message_utils import flip_messages, log_messages
|
||||
from ._sliding_window import SlidingWindowChatMessageList
|
||||
from ._tau2_utils import convert_agent_framework_messages_to_tau2_messages, convert_tau2_tool_to_ai_function
|
||||
|
||||
# Agent instructions matching tau2's LLMAgent
|
||||
ASSISTANT_AGENT_INSTRUCTION = """
|
||||
You are a customer service agent that helps the user according to the <policy> provided below.
|
||||
In each turn you can either:
|
||||
- Send a message to the user.
|
||||
- Make a tool call.
|
||||
You cannot do both at the same time.
|
||||
Try to be helpful and always follow the policy. Always make sure you generate valid JSON only.
|
||||
""".strip()
|
||||
|
||||
# Default first message from agent (matching tau2)
|
||||
DEFAULT_FIRST_AGENT_MESSAGE = "Hi! How can I help you today?"
|
||||
|
||||
# Constants of Agent executor IDs
|
||||
ASSISTANT_AGENT_ID = "assistant_agent"
|
||||
USER_SIMULATOR_ID = "user_simulator"
|
||||
ORCHESTRATOR_ID = "orchestrator"
|
||||
|
||||
|
||||
class TaskRunner:
|
||||
"""Orchestrates task execution using agent framework workflows for tau2 benchmarks.
|
||||
|
||||
Manages conversation flow between assistant agents and user simulators,
|
||||
handles termination conditions, and evaluates performance using tau2 metrics.
|
||||
|
||||
Only "airline" domain is supported for now.
|
||||
"""
|
||||
|
||||
# State tracking
|
||||
step_count: int
|
||||
full_conversation: list[ChatMessage]
|
||||
termination_reason: TerminationReason | None
|
||||
full_reward_info: RewardInfo | None
|
||||
_final_user_message: list[ChatMessage] | None
|
||||
_assistant_executor: AgentExecutor | None
|
||||
_user_executor: AgentExecutor | None
|
||||
|
||||
# Configuration
|
||||
max_steps: int
|
||||
assistant_sampling_temperature: float
|
||||
assistant_window_size: int
|
||||
|
||||
def __init__(self, max_steps: int, assistant_sampling_temperature: float = 0.0, assistant_window_size: int = 32768):
|
||||
"""Initialize the TaskRunner.
|
||||
|
||||
Args:
|
||||
max_steps: The maximum number of steps to run.
|
||||
assistant_sampling_temperature: The sampling temperature for the assistant agent.
|
||||
assistant_window_size: The window size for the assistant agent.
|
||||
"""
|
||||
self.assistant_sampling_temperature = assistant_sampling_temperature
|
||||
self.assistant_window_size = assistant_window_size
|
||||
self.max_steps = max_steps
|
||||
self.reinit()
|
||||
|
||||
def reinit(self) -> "TaskRunner":
|
||||
"""Reset all state for a new task run."""
|
||||
self.step_count = 0
|
||||
self.full_conversation = []
|
||||
self.termination_reason = None
|
||||
self.full_reward_info = None
|
||||
self._final_user_message = None
|
||||
self._assistant_executor = None
|
||||
self._user_executor = None
|
||||
logger.info("TaskRunner has been re-initialized.")
|
||||
return self
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"TaskRunner(max_steps={self.max_steps}, step_count={self.step_count}, "
|
||||
f"full_conversation_length={len(self.full_conversation)}, "
|
||||
f"termination_reason={self.termination_reason}, full_reward_info={self.full_reward_info})"
|
||||
)
|
||||
|
||||
def should_not_stop(self, response: AgentExecutorResponse) -> bool:
|
||||
"""Based on the response, check whether we should or not stop the conversation."""
|
||||
|
||||
# Determine who sent this based on executor_id
|
||||
is_from_agent = response.executor_id == ASSISTANT_AGENT_ID
|
||||
is_from_user = response.executor_id == USER_SIMULATOR_ID
|
||||
|
||||
self.step_count += 1
|
||||
|
||||
logger.opt(colors=True).info(
|
||||
f"<bold>[Step {self.step_count}] Received the following response from "
|
||||
f"{'<blue>assistant</blue>' if is_from_agent else '<green>user</green>'}</bold>, "
|
||||
f"routing to {'<green>user</green>' if is_from_agent else '<blue>assistant</blue>'}:"
|
||||
)
|
||||
log_messages(response.agent_run_response.messages)
|
||||
|
||||
if self.step_count >= self.max_steps:
|
||||
logger.info(f"Max steps ({self.max_steps}) reached - terminating conversation")
|
||||
self.termination_reason = TerminationReason.MAX_STEPS
|
||||
# Terminate the workflow
|
||||
return False
|
||||
|
||||
response_text = response.agent_run_response.text
|
||||
if is_from_agent and self._is_agent_stop(response_text):
|
||||
logger.info("Agent requested stop - terminating conversation")
|
||||
self.termination_reason = TerminationReason.AGENT_STOP
|
||||
return False
|
||||
|
||||
if is_from_user and self._is_user_stop(response_text):
|
||||
logger.info(f"User requested stop with message: '{response_text}' - terminating conversation")
|
||||
self.termination_reason = TerminationReason.USER_STOP
|
||||
# The final user message won't appear in the assistant's message store,
|
||||
# because it will never arrive there.
|
||||
# We need to store it because it's needed for evaluation.
|
||||
self._final_user_message = flip_messages(response.agent_run_response.messages)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _is_agent_stop(self, _: str) -> bool:
|
||||
"""Check if agent wants to stop the conversation."""
|
||||
# Could check for specific stop tokens if agent uses them
|
||||
return False # Agent doesn't have explicit stop in this setup
|
||||
|
||||
def _is_user_stop(self, text: str) -> bool:
|
||||
"""Check if user wants to stop the conversation."""
|
||||
return STOP in text or TRANSFER in text or OUT_OF_SCOPE in text
|
||||
|
||||
def assistant_agent(self, assistant_chat_client: OpenAIChatClient) -> ChatAgent:
|
||||
"""Create an assistant agent.
|
||||
|
||||
Users can override this method to provide a custom assistant agent.
|
||||
|
||||
Args:
|
||||
assistant_chat_client: The chat client for the assistant agent.
|
||||
|
||||
Returns:
|
||||
The assistant agent.
|
||||
"""
|
||||
|
||||
# Initialize tau2 environment and extract tools/policy
|
||||
# This provides the domain-specific context (airline customer service in this case)
|
||||
env = get_environment()
|
||||
tools = env.get_tools() # Available actions the assistant can take
|
||||
policy = env.get_policy() # Guidelines the assistant must follow
|
||||
|
||||
logger.info(
|
||||
f"Environment has {len(env.get_tools())} tools: {', '.join([tool.name for tool in env.get_tools()])}"
|
||||
)
|
||||
|
||||
# Convert tau2 tools to agent framework AIFunction format
|
||||
# This bridges the gap between tau2's tool system and agent framework's expectations
|
||||
ai_functions = [convert_tau2_tool_to_ai_function(tool) for tool in tools]
|
||||
|
||||
# Combines general customer service behavior with specific policy guidelines
|
||||
assistant_system_prompt = f"""<instructions>
|
||||
{ASSISTANT_AGENT_INSTRUCTION}
|
||||
</instructions>
|
||||
<policy>
|
||||
{policy}
|
||||
</policy>"""
|
||||
|
||||
# Assistant agent has:
|
||||
# - Access to all domain tools (booking, cancellation, etc.)
|
||||
# - Sliding window memory to handle long conversations within token limits
|
||||
# - Temperature-controlled response generation
|
||||
return ChatAgent(
|
||||
chat_client=assistant_chat_client,
|
||||
instructions=assistant_system_prompt,
|
||||
tools=ai_functions, # type: ignore
|
||||
temperature=self.assistant_sampling_temperature,
|
||||
chat_message_store_factory=lambda: SlidingWindowChatMessageList(
|
||||
system_message=assistant_system_prompt,
|
||||
tool_definitions=[tool.openai_schema for tool in tools],
|
||||
max_tokens=self.assistant_window_size,
|
||||
),
|
||||
)
|
||||
|
||||
def user_simulator(self, user_simuator_chat_client: OpenAIChatClient, task: Task) -> ChatAgent:
|
||||
"""Create a user simulator agent.
|
||||
|
||||
Users can override this method to provide a custom user simulator agent.
|
||||
|
||||
Args:
|
||||
user_simuator_chat_client: The chat client for the user simulator agent.
|
||||
task: The task to be executed.
|
||||
|
||||
Returns:
|
||||
The user simulator agent.
|
||||
"""
|
||||
|
||||
# User simulator follows tau2's guidelines for realistic customer behavior
|
||||
# No tools available - users typically don't have direct system access
|
||||
user_sim_guidelines = get_global_user_sim_guidelines(use_tools=False)
|
||||
|
||||
# User simulator prompt combines general guidelines with task-specific scenario
|
||||
user_sim_system_prompt = f"""{user_sim_guidelines}
|
||||
<scenario>
|
||||
{task.user_scenario.instructions}
|
||||
</scenario>"""
|
||||
|
||||
return ChatAgent(
|
||||
chat_client=user_simuator_chat_client,
|
||||
instructions=user_sim_system_prompt,
|
||||
temperature=0.0,
|
||||
# No sliding window for user simulator to maintain full conversation context
|
||||
# TODO(yuge): Consider adding user tools in future for more realistic scenarios
|
||||
)
|
||||
|
||||
async def conversation_orchestrator(
|
||||
self, response: AgentExecutorResponse, ctx: WorkflowContext[AgentExecutorRequest]
|
||||
) -> None:
|
||||
"""Orchestrate conversation flow between assistant and user simulator.
|
||||
|
||||
This is the central routing hub that:
|
||||
|
||||
1. Receives responses from either the assistant agent or user simulator
|
||||
2. Flips message roles to create proper conversation flow (assistant->user, user->assistant)
|
||||
3. Routes the flipped messages to the appropriate target agent
|
||||
4. Maintains the conversation loop until termination conditions are met
|
||||
|
||||
Args:
|
||||
response: The response from either assistant or user simulator agent
|
||||
ctx: Workflow context for sending messages to other executors
|
||||
"""
|
||||
# Flip message roles for proper conversation flow
|
||||
# Assistant messages become user messages and vice versa
|
||||
flipped = flip_messages(response.agent_run_response.messages)
|
||||
|
||||
# Determine source to route to correct target
|
||||
is_from_agent = response.executor_id == ASSISTANT_AGENT_ID
|
||||
|
||||
# Send flipped messages to the opposite agent
|
||||
# Critical: Target ID must be specified to prevent broadcasting to both agents
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=flipped, should_respond=True),
|
||||
target_id=USER_SIMULATOR_ID if is_from_agent else ASSISTANT_AGENT_ID,
|
||||
)
|
||||
|
||||
def build_conversation_workflow(self, assistant_agent: ChatAgent, user_simulator_agent: ChatAgent) -> Workflow:
|
||||
"""Build the conversation workflow.
|
||||
|
||||
Users can override this method to provide a custom conversation workflow.
|
||||
|
||||
Args:
|
||||
assistant_agent: The assistant agent.
|
||||
user_simulator_agent: The user simulator agent.
|
||||
|
||||
Returns:
|
||||
The conversation workflow.
|
||||
"""
|
||||
|
||||
# STEP 1: Create workflow executors
|
||||
# Each executor wraps an agent or function for workflow orchestration
|
||||
self._assistant_executor = AgentExecutor(assistant_agent, id=ASSISTANT_AGENT_ID)
|
||||
self._user_executor = AgentExecutor(user_simulator_agent, id=USER_SIMULATOR_ID)
|
||||
orchestrator = FunctionExecutor(func=self.conversation_orchestrator, id=ORCHESTRATOR_ID)
|
||||
|
||||
# STEP 2: Build the conversation workflow
|
||||
# Creates a cyclic workflow: Orchestrator -> Assistant -> Orchestrator -> User -> Orchestrator...
|
||||
# The orchestrator acts as a message router that flips roles and routes to appropriate agent
|
||||
workflow = (
|
||||
WorkflowBuilder(max_iterations=10000) # Unlimited - we control termination via should_not_stop
|
||||
.set_start_executor(orchestrator) # Orchestrator manages the conversation flow
|
||||
.add_edge(orchestrator, self._assistant_executor) # Route messages to assistant
|
||||
.add_edge(
|
||||
self._assistant_executor, orchestrator, condition=self.should_not_stop
|
||||
) # Check termination after assistant
|
||||
.add_edge(orchestrator, self._user_executor) # Route messages to user simulator
|
||||
.add_edge(self._user_executor, orchestrator, condition=self.should_not_stop) # Check termination after user
|
||||
.build()
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
async def run(
|
||||
self,
|
||||
task: Task,
|
||||
assistant_chat_client: OpenAIChatClient,
|
||||
user_simuator_chat_client: OpenAIChatClient,
|
||||
) -> list[ChatMessage]:
|
||||
"""Run a tau2 task using workflow-based agent orchestration.
|
||||
|
||||
This method orchestrates a complex multi-agent simulation:
|
||||
|
||||
1. Sets up tau2 environment and converts tools for agent framework compatibility
|
||||
2. Creates two agents: assistant (with tools) and user simulator (without tools)
|
||||
3. Builds a workflow with orchestrated message routing between agents
|
||||
4. Manages conversation flow until termination conditions are met
|
||||
5. Returns complete conversation history for evaluation
|
||||
|
||||
Args:
|
||||
task: Tau2 task containing scenario, policy, and evaluation criteria
|
||||
assistant_chat_client: LLM client for the assistant agent
|
||||
user_simuator_chat_client: LLM client for the user simulator
|
||||
|
||||
Returns:
|
||||
Complete conversation history as ChatMessage list for evaluation
|
||||
"""
|
||||
|
||||
logger.info(f"Starting workflow agent for task {task.id}: {task.description.purpose}") # type: ignore[unused-ignore]
|
||||
logger.info(f"Assistant chat client: {assistant_chat_client}")
|
||||
logger.info(f"User simulator chat client: {user_simuator_chat_client}")
|
||||
|
||||
# STEP 1: Create agents
|
||||
assistant_agent = self.assistant_agent(assistant_chat_client)
|
||||
user_simulator_agent = self.user_simulator(user_simuator_chat_client, task)
|
||||
|
||||
# STEP 2: Create the conversation workflow
|
||||
workflow = self.build_conversation_workflow(assistant_agent, user_simulator_agent)
|
||||
|
||||
# STEP 3: Initialize conversation with standard greeting
|
||||
# Matches tau2's expected conversation start pattern
|
||||
logger.info(f"Starting workflow with hardcoded greeting: '{DEFAULT_FIRST_AGENT_MESSAGE}'")
|
||||
|
||||
first_message = ChatMessage(Role.ASSISTANT, text=DEFAULT_FIRST_AGENT_MESSAGE)
|
||||
initial_greeting = AgentExecutorResponse(
|
||||
executor_id=ASSISTANT_AGENT_ID,
|
||||
agent_run_response=AgentRunResponse(messages=[first_message]),
|
||||
full_conversation=[ChatMessage(Role.ASSISTANT, text=DEFAULT_FIRST_AGENT_MESSAGE)],
|
||||
)
|
||||
|
||||
# STEP 4: Execute the workflow and collect results
|
||||
# The workflow runs until termination conditions are met (max steps, stop signals, etc.)
|
||||
await workflow.run(initial_greeting)
|
||||
|
||||
# STEP 5: Ensemble the conversation history needed for evaluation.
|
||||
# It's coming from three parts:
|
||||
# 1. The initial greeting
|
||||
# 2. The assistant's message store (not just the truncated window)
|
||||
# 3. The final user message (if any)
|
||||
assistant_executor = cast(AgentExecutor, self._assistant_executor)
|
||||
message_store = cast(SlidingWindowChatMessageList, assistant_executor._agent_thread.message_store)
|
||||
full_conversation = [first_message] + await message_store.list_all_messages()
|
||||
if self._final_user_message is not None:
|
||||
full_conversation.extend(self._final_user_message)
|
||||
|
||||
logger.opt(colors=True).info(
|
||||
f"<green>WORKFLOW COMPLETED WITH {len(full_conversation)} MESSAGES. "
|
||||
f"Termination reason: {self.termination_reason}.</green>"
|
||||
)
|
||||
log_messages(full_conversation)
|
||||
|
||||
return full_conversation
|
||||
|
||||
def evaluate(
|
||||
self, task_input: Task, conversation: list[ChatMessage], termination_reason: TerminationReason | None
|
||||
) -> float:
|
||||
"""Evaluate agent performance using tau2's comprehensive evaluation system.
|
||||
|
||||
Bridges agent framework conversation results with tau2's evaluation pipeline.
|
||||
Considers task completion, policy adherence, conversation quality, and tool usage.
|
||||
|
||||
Args:
|
||||
task_input: Original tau2 task containing evaluation criteria
|
||||
conversation: Complete conversation history from workflow execution
|
||||
termination_reason: How/why the conversation ended (affects scoring)
|
||||
|
||||
Returns:
|
||||
Numeric reward score (0.0-1.0) representing overall performance
|
||||
|
||||
Side Effects:
|
||||
Stores detailed evaluation results in self.full_reward_info
|
||||
"""
|
||||
|
||||
# Handle missing termination reason (can happen with unexpected workflow endings)
|
||||
if termination_reason is None:
|
||||
termination_reason = TerminationReason.TOO_MANY_ERRORS
|
||||
|
||||
# Convert agent framework ChatMessages to tau2 Message format for evaluation
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(conversation)
|
||||
|
||||
# Package conversation and metadata for tau2's evaluation system
|
||||
simulation = SimulationRun(
|
||||
id=str(uuid.uuid4()), # Unique identifier for this evaluation run
|
||||
task_id=task_input.id, # Links evaluation back to original task
|
||||
start_time=get_now(), # Timestamp for evaluation records
|
||||
end_time=get_now(), # Duration is 0 since this is post-hoc evaluation
|
||||
duration=0.0,
|
||||
termination_reason=termination_reason, # Context for how conversation ended
|
||||
messages=tau2_messages, # The actual conversation to evaluate
|
||||
)
|
||||
|
||||
# Run comprehensive multi-dimensional evaluation
|
||||
# EvaluationType.ALL: evaluates task completion, policy adherence, conversation quality, ...
|
||||
# solo_mode=False: indicates multi-agent conversation (assistant + user simulator)
|
||||
self.full_reward_info = evaluate_simulation(
|
||||
simulation=simulation,
|
||||
task=task_input,
|
||||
evaluation_type=EvaluationType.ALL,
|
||||
solo_mode=False,
|
||||
domain="airline",
|
||||
)
|
||||
|
||||
logger.info(f"Evaluation completed - Reward: {self.full_reward_info.reward}, Info: {self.full_reward_info}")
|
||||
return self.full_reward_info.reward # type: ignore[no-any-return]
|
||||
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
# This makes agent_framework a namespace package
|
||||
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
|
||||
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
# This makes agent_framework.lab a namespace package
|
||||
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
|
||||
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
# Import and re-export from the actual implementation
|
||||
from agent_framework_lab_tau2 import * # noqa: F403, F401
|
||||
@@ -0,0 +1,99 @@
|
||||
[project]
|
||||
name = "agent-framework-lab-tau2"
|
||||
description = "Tau2 Benchmark for Agent Framework."
|
||||
authors = [{ name = "Microsoft", email = "SK-Support@microsoft.com"}]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
version = "0.1.0b1"
|
||||
license-files = ["LICENSE"]
|
||||
urls.homepage = "https://learn.microsoft.com/en-us/semantic-kernel/overview/"
|
||||
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
|
||||
urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true"
|
||||
urls.issues = "https://github.com/microsoft/agent-framework/issues"
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 2 - Pre-Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
]
|
||||
dependencies = [
|
||||
"agent-framework",
|
||||
"pydantic>=2.0.0",
|
||||
"tiktoken>=0.11.0",
|
||||
"loguru>=0.7.3",
|
||||
"tau2@git+https://github.com/sierra-research/tau2-bench@5ba9e3e56db57c5e4114bf7f901291f09b2c5619",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=64", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["agent_framework_lab_tau2", "agent_framework.lab.tau2"]
|
||||
|
||||
[tool.setuptools.package-dir]
|
||||
"agent_framework.lab.tau2" = "namespace/agent_framework/lab/tau2"
|
||||
"agent_framework_lab_tau2" = "agent_framework_lab_tau2"
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
agent_framework_lab_tau2 = ["py.typed"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py310"
|
||||
extend-exclude = ["tests", "__pycache__"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "W", "UP", "C4", "N"]
|
||||
ignore = ["N803", "N806", "N999", "UP007"]
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
strict = true
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_ignores = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
show_error_codes = true
|
||||
implicit_reexport = true
|
||||
packages = ["agent_framework_lab_tau2"]
|
||||
exclude = [
|
||||
"data",
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
exclude = ["**/data"]
|
||||
|
||||
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
test = "pytest --cov=agent_framework_lab_tau2 --cov-report=term-missing:skip-covered tests"
|
||||
mypy = "mypy agent_framework_lab_tau2"
|
||||
setup-data = "python tests/setup_data.py"
|
||||
purge-data = "python tests/purge_data.py"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
addopts = "--strict-markers --strict-config"
|
||||
markers = [
|
||||
"unit: marks tests as unit tests",
|
||||
"integration: marks tests as integration tests",
|
||||
]
|
||||
env = [
|
||||
"TAU2_DATA_DIR=data",
|
||||
]
|
||||
@@ -0,0 +1,4 @@
|
||||
TAU2_DATA_DIR=/path/to/your/data
|
||||
|
||||
OPENAI_API_KEY=dummy
|
||||
OPENAI_BASE_URL=http://127.0.0.1:12345/
|
||||
@@ -0,0 +1,240 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from loguru import logger
|
||||
from tau2.domains.airline.environment import get_tasks
|
||||
|
||||
from agent_framework_lab_tau2 import TaskRunner, patch_env_set_state
|
||||
|
||||
|
||||
def to_dumpable(result: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert benchmark result to JSONL-serializable format.
|
||||
|
||||
Handles both successful runs and error cases, ensuring consistent output
|
||||
format for downstream analysis. Converts Pydantic models to dictionaries
|
||||
and extracts enum values for JSON compatibility.
|
||||
"""
|
||||
if "error" in result:
|
||||
# Error case: minimal structure with zero reward
|
||||
return {
|
||||
"id": result["task"].id,
|
||||
"error": result["error"],
|
||||
"evaluation": {
|
||||
"reward": 0.0, # Standard zero reward for failed runs
|
||||
},
|
||||
"config": result["config"],
|
||||
"task": result["task"].model_dump(),
|
||||
}
|
||||
else:
|
||||
# Success case: full result structure
|
||||
return {
|
||||
"id": result["task"].id,
|
||||
"evaluation": result["evaluation"].model_dump(), # Detailed evaluation metrics
|
||||
"config": result["config"], # Model configuration used
|
||||
"termination_reason": result["termination_reason"].value, # Enum to string
|
||||
"messages": [m.model_dump() for m in result["messages"]], # Full conversation
|
||||
"task": result["task"].model_dump(), # Task specification
|
||||
}
|
||||
|
||||
|
||||
async def run_benchmark(assistant_model: str, user_model: str, debug_task_id: str | None, max_steps: int):
|
||||
"""Run comprehensive tau2 benchmark evaluation using agent framework.
|
||||
|
||||
This is the main function that:
|
||||
|
||||
1. Sets up output file handling (full benchmark vs debug mode)
|
||||
2. Loads tau2 task dataset and configures LLM clients
|
||||
3. Runs each task through the agent framework workflow
|
||||
4. Evaluates performance using tau2's multi-dimensional metrics
|
||||
5. Aggregates results and calculates overall benchmark scores
|
||||
|
||||
Args:
|
||||
assistant_model: Model ID for the customer service agent (e.g., "gpt-4o")
|
||||
user_model: Model ID for the user simulator (e.g., "gpt-4o")
|
||||
debug_task_id: Optional specific task ID to run (disables batch processing)
|
||||
max_steps: Maximum conversation steps before forced termination
|
||||
|
||||
Output:
|
||||
Creates timestamped JSONL file with detailed results for analysis
|
||||
Prints summary statistics to console with colored logging
|
||||
"""
|
||||
|
||||
# STEP 1: Configure output handling based on execution mode
|
||||
result_fp = None
|
||||
if debug_task_id is None:
|
||||
# Full benchmark mode: create timestamped results file
|
||||
timestamp = datetime.now().strftime("%m%d%H%M") # Format: MMDDHHMM
|
||||
result_filename = f"results/{assistant_model}_user-{user_model}_{timestamp}.jsonl"
|
||||
os.makedirs("results", exist_ok=True)
|
||||
result_fp = open(result_filename, "a") # Append mode for resumability
|
||||
logger.info(f"Results will be saved to: {result_filename}")
|
||||
else:
|
||||
# Debug mode: single task, no file output, verbose logging
|
||||
logger.info(f"Debug mode: targeting task ID {debug_task_id}")
|
||||
|
||||
# STEP 2: Load tau2 dataset and validate environment
|
||||
tasks = get_tasks() # Loads all tau2 airline customer service tasks
|
||||
logger.info(f"Found {len(tasks)} tasks in the dataset")
|
||||
|
||||
_logger = logger.opt(colors=True) # Enable colored console output
|
||||
|
||||
# Validate required OpenAI configuration
|
||||
# Both models use the same endpoint but can be different model types
|
||||
openai_base_url = os.getenv("OPENAI_BASE_URL")
|
||||
if openai_base_url is None:
|
||||
raise ValueError("OPENAI_BASE_URL must be set")
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
if openai_api_key is None:
|
||||
raise ValueError("OPENAI_API_KEY must be set")
|
||||
|
||||
# STEP 3: Initialize LLM clients for both agent roles
|
||||
# Assistant: handles customer service with access to tools and policies
|
||||
assistant_chat_client = OpenAIChatClient(
|
||||
base_url=openai_base_url,
|
||||
api_key=openai_api_key,
|
||||
ai_model_id=assistant_model,
|
||||
)
|
||||
|
||||
# User simulator: simulates realistic customer behavior and requests
|
||||
user_simulator_chat_client = OpenAIChatClient(
|
||||
base_url=openai_base_url,
|
||||
api_key=openai_api_key,
|
||||
ai_model_id=user_model,
|
||||
)
|
||||
|
||||
# STEP 4: Filter task set for debug mode
|
||||
if debug_task_id is not None:
|
||||
tasks = [task for task in tasks if task.id == debug_task_id]
|
||||
if not tasks:
|
||||
logger.error(f"Task ID {debug_task_id} not found in dataset")
|
||||
return
|
||||
|
||||
# STEP 5: Initialize evaluation tracking
|
||||
all_rewards: list[float] = [] # Stores reward scores for final statistics
|
||||
task_runner = TaskRunner(max_steps=max_steps) # Reusable workflow orchestrator
|
||||
|
||||
# STEP 6: Execute benchmark across all tasks
|
||||
for task in tasks:
|
||||
_logger.info(f"<red>Testing task #{task.id}</red>")
|
||||
_logger.info(f"<cyan>Purpose:</cyan> {task.description.purpose}") # type: ignore
|
||||
|
||||
# Initialize result structure for this task
|
||||
result: dict[str, Any] = {
|
||||
"config": {
|
||||
"assistant": assistant_chat_client.ai_model_id,
|
||||
"user": user_simulator_chat_client.ai_model_id,
|
||||
},
|
||||
"task": task,
|
||||
}
|
||||
|
||||
# Log user scenario context for transparency
|
||||
if task.user_scenario and task.user_scenario.instructions:
|
||||
_logger.info(f"<cyan>User scenario:</cyan> {task.user_scenario.instructions.reason_for_call}") # type: ignore
|
||||
|
||||
try:
|
||||
# Execute the workflow: agent + user simulator conversation
|
||||
conversation = await task_runner.run(task, assistant_chat_client, user_simulator_chat_client)
|
||||
|
||||
# Evaluate performance using tau2's comprehensive metrics
|
||||
reward_value = task_runner.evaluate(task, conversation, task_runner.termination_reason)
|
||||
|
||||
# Store detailed results for analysis
|
||||
result["evaluation"] = task_runner.full_reward_info # Full evaluation breakdown
|
||||
result["messages"] = conversation # Complete conversation history
|
||||
result["termination_reason"] = task_runner.termination_reason # How conversation ended
|
||||
|
||||
# Log evaluation results (escape HTML for colored output)
|
||||
reward_str = str(task_runner.full_reward_info).replace("<", r"\<")
|
||||
_logger.info(f"<cyan>Final evaluation:</cyan> {reward_str}")
|
||||
|
||||
except Exception as e:
|
||||
# Robust error handling: capture all failures for analysis
|
||||
_logger.error(f"<red>Error testing task #{task.id}:</red> {e}")
|
||||
result["error"] = traceback.format_exc() # Full stack trace for debugging
|
||||
|
||||
traceback.print_exc() # Console output for immediate debugging
|
||||
reward_value = 0.0 # Zero score for failed runs
|
||||
|
||||
# STEP 7: Persist results incrementally (enables partial analysis)
|
||||
if result_fp is not None:
|
||||
result_fp.write(json.dumps(to_dumpable(result), default=str) + "\n")
|
||||
|
||||
all_rewards.append(reward_value) # Track for final statistics
|
||||
|
||||
# Reset runner state for next task
|
||||
task_runner.reinit()
|
||||
|
||||
# STEP 8: Finalize and report aggregate results
|
||||
if result_fp is not None:
|
||||
result_fp.close()
|
||||
|
||||
# Calculate overall benchmark performance
|
||||
all_accuracy = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0
|
||||
|
||||
# Report final statistics with colored formatting
|
||||
_logger.info("<green>Final Results:</green>")
|
||||
_logger.info(f"<cyan>All tasks accuracy:</cyan> {all_accuracy:.2f} ({int(sum(all_rewards))}/{len(tasks)})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Command-line interface for tau2 benchmark execution.
|
||||
|
||||
Provides flexible execution modes:
|
||||
|
||||
- Full benchmark: Runs all tasks and generates timestamped results file
|
||||
- Debug mode: Single task execution with verbose logging for development
|
||||
- Environment patching: Optional compatibility layer for tau2-bench integration
|
||||
|
||||
Usage Examples:
|
||||
# Full benchmark with default models
|
||||
python run_benchmark.py
|
||||
|
||||
# Custom models
|
||||
python run_benchmark.py --assistant gpt-4o --user gpt-4o-mini
|
||||
|
||||
# Debug specific task
|
||||
python run_benchmark.py --debug-task-id task_123
|
||||
|
||||
# Disable environment patching for testing
|
||||
python run_benchmark.py --disable-env-patch
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run tau2-agent-framework model test")
|
||||
|
||||
# Model configuration arguments
|
||||
parser.add_argument("--assistant", type=str, default="gpt-4.1", help="Assistant model id, e.g., gpt-4.1-mini")
|
||||
parser.add_argument("--user", type=str, default="gpt-4.1", help="User model id")
|
||||
|
||||
# Execution mode arguments
|
||||
parser.add_argument(
|
||||
"--debug-task-id", type=str, default=None, help="Debug a specific task ID (disables result file creation)"
|
||||
)
|
||||
parser.add_argument("--max-steps", type=int, default=100, help="Maximum number of steps to run")
|
||||
|
||||
# Environment configuration arguments
|
||||
parser.add_argument("--disable-env-patch", action="store_true", help="Disable patching tau2-bench environment")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Apply environment patch for tau2-bench compatibility
|
||||
# This modifies tau2's environment to be more flexible with tool call validation
|
||||
if not args.disable_env_patch:
|
||||
patch_env_set_state()
|
||||
|
||||
# Execute benchmark with configured parameters
|
||||
asyncio.run(
|
||||
run_benchmark(
|
||||
assistant_model=args.assistant,
|
||||
user_model=args.user,
|
||||
debug_task_id=args.debug_task_id,
|
||||
max_steps=args.max_steps,
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def purge_tau2_data():
|
||||
"""Purge tau2 data directory if it exists."""
|
||||
|
||||
data_dir = Path.cwd() / "data"
|
||||
|
||||
if data_dir.exists():
|
||||
shutil.rmtree(data_dir)
|
||||
print(f"Data directory at {data_dir} has been purged.")
|
||||
else:
|
||||
print("Data directory not found. Skipping purge.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
purge_tau2_data()
|
||||
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def setup_tau2_data():
|
||||
"""Set up tau2 data directory by cloning repository if needed."""
|
||||
|
||||
# Get project directory (parent of tests directory)
|
||||
data_dir = Path.cwd() / "data"
|
||||
print(data_dir)
|
||||
|
||||
print("Setting up tau2 data directory...")
|
||||
|
||||
# Check if data directory already exists
|
||||
if data_dir.exists():
|
||||
print(f"Data directory already exists at {data_dir}")
|
||||
else:
|
||||
print("Data directory not found. Cloning tau2-bench repository...")
|
||||
|
||||
try:
|
||||
# Clone the repository
|
||||
print("Cloning https://github.com/sierra-research/tau2-bench.git...")
|
||||
subprocess.run(
|
||||
["git", "clone", "https://github.com/sierra-research/tau2-bench.git"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
# Move data directory
|
||||
print("Moving data directory...")
|
||||
tau2_bench_dir = Path.cwd() / "tau2-bench"
|
||||
tau2_data_dir = tau2_bench_dir / "data"
|
||||
|
||||
if tau2_data_dir.exists():
|
||||
shutil.move(str(tau2_data_dir), str(data_dir))
|
||||
else:
|
||||
raise FileNotFoundError(f"Data directory not found in cloned repository: {tau2_data_dir}")
|
||||
|
||||
# Clean up cloned repository
|
||||
print("Cleaning up cloned repository...")
|
||||
shutil.rmtree(tau2_bench_dir)
|
||||
|
||||
print("Data directory setup completed successfully!")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"ERROR: Failed to clone repository: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to set up data directory: {e}")
|
||||
raise
|
||||
|
||||
print(f"TAU2_DATA_DIR should be set to: {data_dir}")
|
||||
|
||||
return str(data_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup_tau2_data()
|
||||
@@ -0,0 +1,265 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework._types import ChatMessage, Role, TextContent, FunctionCallContent, FunctionResultContent
|
||||
from agent_framework_lab_tau2._message_utils import flip_messages, log_messages
|
||||
|
||||
|
||||
def test_flip_messages_user_to_assistant():
|
||||
"""Test flipping user message to assistant."""
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role=Role.USER, contents=[TextContent(text="Hello assistant")], author_name="User1", message_id="msg_001"
|
||||
)
|
||||
]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
|
||||
assert len(flipped) == 1
|
||||
assert flipped[0].role == Role.ASSISTANT
|
||||
assert flipped[0].text == "Hello assistant"
|
||||
assert flipped[0].author_name == "User1"
|
||||
assert flipped[0].message_id == "msg_001"
|
||||
|
||||
|
||||
def test_flip_messages_assistant_to_user():
|
||||
"""Test flipping assistant message to user."""
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[TextContent(text="Hello user")],
|
||||
author_name="Assistant1",
|
||||
message_id="msg_002",
|
||||
)
|
||||
]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
|
||||
assert len(flipped) == 1
|
||||
assert flipped[0].role == Role.USER
|
||||
assert flipped[0].text == "Hello user"
|
||||
assert flipped[0].author_name == "Assistant1"
|
||||
assert flipped[0].message_id == "msg_002"
|
||||
|
||||
|
||||
def test_flip_messages_assistant_with_function_calls_filtered():
|
||||
"""Test that function calls are filtered out when flipping assistant to user."""
|
||||
function_call = FunctionCallContent(call_id="call_123", name="test_function", arguments={"param": "value"})
|
||||
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[TextContent(text="I'll call a function"), function_call, TextContent(text="After the call")],
|
||||
message_id="msg_003",
|
||||
)
|
||||
]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
|
||||
assert len(flipped) == 1
|
||||
assert flipped[0].role == Role.USER
|
||||
# Function call should be filtered out
|
||||
assert len(flipped[0].contents) == 2
|
||||
assert all(content.type == "text" for content in flipped[0].contents)
|
||||
assert "I'll call a function" in flipped[0].text
|
||||
assert "After the call" in flipped[0].text
|
||||
|
||||
|
||||
def test_flip_messages_assistant_with_only_function_calls_skipped():
|
||||
"""Test that assistant messages with only function calls are skipped."""
|
||||
function_call = FunctionCallContent(call_id="call_456", name="another_function", arguments={"key": "value"})
|
||||
|
||||
messages = [
|
||||
ChatMessage(role=Role.ASSISTANT, contents=[function_call], message_id="msg_004") # Only function call, no text
|
||||
]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
|
||||
# Should be empty since the message had no text content after filtering
|
||||
assert len(flipped) == 0
|
||||
|
||||
|
||||
def test_flip_messages_tool_messages_skipped():
|
||||
"""Test that tool messages are skipped."""
|
||||
function_result = FunctionResultContent(call_id="call_789", result={"success": True})
|
||||
|
||||
messages = [ChatMessage(role=Role.TOOL, contents=[function_result])]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
|
||||
# Tool messages should be skipped
|
||||
assert len(flipped) == 0
|
||||
|
||||
|
||||
def test_flip_messages_system_messages_preserved():
|
||||
"""Test that system messages are preserved as-is."""
|
||||
messages = [ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System instruction")], message_id="sys_001")]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
|
||||
assert len(flipped) == 1
|
||||
assert flipped[0].role == Role.SYSTEM
|
||||
assert flipped[0].text == "System instruction"
|
||||
assert flipped[0].message_id == "sys_001"
|
||||
|
||||
|
||||
def test_flip_messages_mixed_conversation():
|
||||
"""Test flipping a mixed conversation."""
|
||||
function_call = FunctionCallContent(call_id="call_mixed", name="mixed_function", arguments={})
|
||||
|
||||
function_result = FunctionResultContent(call_id="call_mixed", result="function result")
|
||||
|
||||
messages = [
|
||||
ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System prompt")]),
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text="User question")]),
|
||||
ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Assistant response"), function_call]),
|
||||
ChatMessage(role=Role.TOOL, contents=[function_result]),
|
||||
ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Final response")]),
|
||||
]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
|
||||
# Should have: system (unchanged), assistant (from user), user (from assistant, filtered), assistant (from final assistant)
|
||||
assert len(flipped) == 4
|
||||
|
||||
# Check each flipped message
|
||||
assert flipped[0].role == Role.SYSTEM
|
||||
assert flipped[0].text == "System prompt"
|
||||
|
||||
assert flipped[1].role == Role.ASSISTANT
|
||||
assert flipped[1].text == "User question"
|
||||
|
||||
assert flipped[2].role == Role.USER
|
||||
assert flipped[2].text == "Assistant response" # Function call filtered out
|
||||
|
||||
# Tool message skipped
|
||||
|
||||
assert flipped[3].role == Role.USER
|
||||
assert flipped[3].text == "Final response"
|
||||
|
||||
|
||||
def test_flip_messages_empty_list():
|
||||
"""Test flipping empty message list."""
|
||||
messages = []
|
||||
flipped = flip_messages(messages)
|
||||
assert len(flipped) == 0
|
||||
|
||||
|
||||
def test_flip_messages_preserves_metadata():
|
||||
"""Test that message metadata is preserved during flipping."""
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role=Role.USER, contents=[TextContent(text="Test message")], author_name="TestUser", message_id="test_123"
|
||||
)
|
||||
]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
|
||||
assert len(flipped) == 1
|
||||
assert flipped[0].author_name == "TestUser"
|
||||
assert flipped[0].message_id == "test_123"
|
||||
|
||||
|
||||
@patch("agent_framework_lab_tau2._message_utils.logger")
|
||||
def test_log_messages_text_content(mock_logger):
|
||||
"""Test logging messages with text content."""
|
||||
messages = [
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]),
|
||||
ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hi there!")]),
|
||||
]
|
||||
|
||||
log_messages(messages)
|
||||
|
||||
# Should have called logger.info for each message
|
||||
assert mock_logger.opt.return_value.info.call_count == 2
|
||||
|
||||
|
||||
@patch("agent_framework_lab_tau2._message_utils.logger")
|
||||
def test_log_messages_function_call(mock_logger):
|
||||
"""Test logging messages with function calls."""
|
||||
function_call = FunctionCallContent(call_id="call_log", name="log_function", arguments={"param": "value"})
|
||||
|
||||
messages = [ChatMessage(role=Role.ASSISTANT, contents=[function_call])]
|
||||
|
||||
log_messages(messages)
|
||||
|
||||
# Should log the function call
|
||||
mock_logger.opt.return_value.info.assert_called()
|
||||
call_args = mock_logger.opt.return_value.info.call_args[0][0]
|
||||
assert "TOOL_CALL" in call_args
|
||||
assert "log_function" in call_args
|
||||
|
||||
|
||||
@patch("agent_framework_lab_tau2._message_utils.logger")
|
||||
def test_log_messages_function_result(mock_logger):
|
||||
"""Test logging messages with function results."""
|
||||
function_result = FunctionResultContent(call_id="call_result", result="success")
|
||||
|
||||
messages = [ChatMessage(role=Role.TOOL, contents=[function_result])]
|
||||
|
||||
log_messages(messages)
|
||||
|
||||
# Should log the function result
|
||||
mock_logger.opt.return_value.info.assert_called()
|
||||
call_args = mock_logger.opt.return_value.info.call_args[0][0]
|
||||
assert "TOOL_RESULT" in call_args
|
||||
|
||||
|
||||
@patch("agent_framework_lab_tau2._message_utils.logger")
|
||||
def test_log_messages_different_roles(mock_logger):
|
||||
"""Test logging messages with different roles get different colors."""
|
||||
messages = [
|
||||
ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System")]),
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text="User")]),
|
||||
ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Assistant")]),
|
||||
ChatMessage(role=Role.TOOL, contents=[TextContent(text="Tool")]),
|
||||
]
|
||||
|
||||
log_messages(messages)
|
||||
|
||||
# Should have called logger for each message
|
||||
assert mock_logger.opt.return_value.info.call_count == 4
|
||||
|
||||
# Check that different color tags are used
|
||||
calls = mock_logger.opt.return_value.info.call_args_list
|
||||
system_call = calls[0][0][0]
|
||||
user_call = calls[1][0][0]
|
||||
assistant_call = calls[2][0][0]
|
||||
tool_call = calls[3][0][0]
|
||||
|
||||
assert "cyan" in system_call or "SYSTEM" in system_call
|
||||
assert "green" in user_call or "USER" in user_call
|
||||
assert "blue" in assistant_call or "ASSISTANT" in assistant_call
|
||||
assert "yellow" in tool_call or "TOOL" in tool_call
|
||||
|
||||
|
||||
@patch("agent_framework_lab_tau2._message_utils.logger")
|
||||
def test_log_messages_escapes_html(mock_logger):
|
||||
"""Test that HTML-like characters are properly escaped in log output."""
|
||||
messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Message with <tag> content")])]
|
||||
|
||||
log_messages(messages)
|
||||
|
||||
mock_logger.opt.return_value.info.assert_called()
|
||||
call_args = mock_logger.opt.return_value.info.call_args[0][0]
|
||||
# Should escape < characters
|
||||
assert "\\<tag>" in call_args or "<tag>" in call_args
|
||||
|
||||
|
||||
@patch("agent_framework_lab_tau2._message_utils.logger")
|
||||
def test_log_messages_mixed_content_types(mock_logger):
|
||||
"""Test logging messages with mixed content types."""
|
||||
function_call = FunctionCallContent(call_id="mixed_call", name="mixed_function", arguments={"key": "value"})
|
||||
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[TextContent(text="I'll call a function"), function_call, TextContent(text="Done!")],
|
||||
)
|
||||
]
|
||||
|
||||
log_messages(messages)
|
||||
|
||||
# Should log multiple times for different content types
|
||||
assert mock_logger.opt.return_value.info.call_count == 3
|
||||
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for sliding window message list."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework._types import ChatMessage, Role, TextContent, FunctionCallContent, FunctionResultContent
|
||||
from agent_framework_lab_tau2._sliding_window import SlidingWindowChatMessageList
|
||||
|
||||
|
||||
def test_initialization_empty():
|
||||
"""Test initializing with no messages."""
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
|
||||
|
||||
assert sliding_window.max_tokens == 1000
|
||||
assert sliding_window.system_message is None
|
||||
assert sliding_window.tool_definitions is None
|
||||
assert len(sliding_window._messages) == 0
|
||||
assert len(sliding_window._truncated_messages) == 0
|
||||
|
||||
|
||||
def test_initialization_with_parameters():
|
||||
"""Test initializing with system message and tool definitions."""
|
||||
system_msg = "You are a helpful assistant"
|
||||
tool_defs = [{"name": "test_tool", "description": "A test tool"}]
|
||||
|
||||
sliding_window = SlidingWindowChatMessageList(
|
||||
max_tokens=2000, system_message=system_msg, tool_definitions=tool_defs
|
||||
)
|
||||
|
||||
assert sliding_window.max_tokens == 2000
|
||||
assert sliding_window.system_message == system_msg
|
||||
assert sliding_window.tool_definitions == tool_defs
|
||||
|
||||
|
||||
def test_initialization_with_messages():
|
||||
"""Test initializing with existing messages."""
|
||||
messages = [
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]),
|
||||
ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hi there!")]),
|
||||
]
|
||||
|
||||
sliding_window = SlidingWindowChatMessageList(messages=messages, max_tokens=1000)
|
||||
|
||||
assert len(sliding_window._messages) == 2
|
||||
assert len(sliding_window._truncated_messages) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_messages_simple():
|
||||
"""Test adding messages without truncation."""
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=10000) # Large limit
|
||||
|
||||
new_messages = [
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text="What's the weather?")]),
|
||||
ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="I can help with that.")]),
|
||||
]
|
||||
|
||||
await sliding_window.add_messages(new_messages)
|
||||
|
||||
messages = await sliding_window.list_messages()
|
||||
assert len(messages) == 2
|
||||
assert messages[0].text == "What's the weather?"
|
||||
assert messages[1].text == "I can help with that."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_messages_vs_list_messages():
|
||||
"""Test difference between list_all_messages and list_messages."""
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=50) # Small limit to force truncation
|
||||
|
||||
# Add many messages to trigger truncation
|
||||
messages = [
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text=f"Message {i} with some content")]) for i in range(10)
|
||||
]
|
||||
|
||||
await sliding_window.add_messages(messages)
|
||||
|
||||
truncated_messages = await sliding_window.list_messages()
|
||||
all_messages = await sliding_window.list_all_messages()
|
||||
|
||||
# All messages should contain everything
|
||||
assert len(all_messages) == 10
|
||||
|
||||
# Truncated messages should be fewer due to token limit
|
||||
assert len(truncated_messages) < len(all_messages)
|
||||
|
||||
|
||||
def test_get_token_count_basic():
|
||||
"""Test basic token counting."""
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
|
||||
sliding_window._truncated_messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
|
||||
|
||||
token_count = sliding_window.get_token_count()
|
||||
|
||||
# Should be more than 0 (exact count depends on encoding)
|
||||
assert token_count > 0
|
||||
|
||||
|
||||
def test_get_token_count_with_system_message():
|
||||
"""Test token counting includes system message."""
|
||||
system_msg = "You are a helpful assistant"
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=1000, system_message=system_msg)
|
||||
|
||||
# Without messages
|
||||
token_count_empty = sliding_window.get_token_count()
|
||||
|
||||
# Add a message
|
||||
sliding_window._truncated_messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
|
||||
token_count_with_message = sliding_window.get_token_count()
|
||||
|
||||
# With message should be more tokens
|
||||
assert token_count_with_message > token_count_empty
|
||||
assert token_count_empty > 0 # System message contributes tokens
|
||||
|
||||
|
||||
def test_get_token_count_function_call():
|
||||
"""Test token counting with function calls."""
|
||||
function_call = FunctionCallContent(call_id="call_123", name="test_function", arguments={"param": "value"})
|
||||
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
|
||||
sliding_window._truncated_messages = [ChatMessage(role=Role.ASSISTANT, contents=[function_call])]
|
||||
|
||||
token_count = sliding_window.get_token_count()
|
||||
assert token_count > 0
|
||||
|
||||
|
||||
def test_get_token_count_function_result():
|
||||
"""Test token counting with function results."""
|
||||
function_result = FunctionResultContent(call_id="call_123", result={"success": True, "data": "result"})
|
||||
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
|
||||
sliding_window._truncated_messages = [ChatMessage(role=Role.TOOL, contents=[function_result])]
|
||||
|
||||
token_count = sliding_window.get_token_count()
|
||||
assert token_count > 0
|
||||
|
||||
|
||||
@patch("agent_framework_lab_tau2._sliding_window.logger")
|
||||
def test_truncate_messages_removes_old_messages(mock_logger):
|
||||
"""Test that truncation removes old messages when token limit exceeded."""
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=20) # Very small limit
|
||||
|
||||
# Create messages that will exceed the limit
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role=Role.USER,
|
||||
contents=[TextContent(text="This is a very long message that should exceed the token limit")],
|
||||
),
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[TextContent(text="This is another very long message that should also exceed the token limit")],
|
||||
),
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text="Short msg")]),
|
||||
]
|
||||
|
||||
sliding_window._truncated_messages = messages.copy()
|
||||
sliding_window.truncate_messages()
|
||||
|
||||
# Should have fewer messages after truncation
|
||||
assert len(sliding_window._truncated_messages) < len(messages)
|
||||
|
||||
# Should have logged warnings
|
||||
assert mock_logger.warning.called
|
||||
|
||||
|
||||
@patch("agent_framework_lab_tau2._sliding_window.logger")
|
||||
def test_truncate_messages_removes_leading_tool_messages(mock_logger):
|
||||
"""Test that truncation removes leading tool messages."""
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=10000) # Large limit
|
||||
|
||||
# Create messages starting with tool message
|
||||
tool_message = ChatMessage(role=Role.TOOL, contents=[FunctionResultContent(call_id="call_123", result="result")])
|
||||
user_message = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])
|
||||
|
||||
sliding_window._truncated_messages = [tool_message, user_message]
|
||||
sliding_window.truncate_messages()
|
||||
|
||||
# Tool message should be removed from the beginning
|
||||
assert len(sliding_window._truncated_messages) == 1
|
||||
assert sliding_window._truncated_messages[0].role == Role.USER
|
||||
|
||||
# Should have logged warning about removing tool message
|
||||
mock_logger.warning.assert_called()
|
||||
|
||||
|
||||
def test_estimate_any_object_token_count_dict():
|
||||
"""Test token counting for dictionary objects."""
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
|
||||
|
||||
test_dict = {"key": "value", "number": 42}
|
||||
token_count = sliding_window.estimate_any_object_token_count(test_dict)
|
||||
|
||||
assert token_count > 0
|
||||
|
||||
|
||||
def test_estimate_any_object_token_count_string():
|
||||
"""Test token counting for string objects."""
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
|
||||
|
||||
test_string = "This is a test string"
|
||||
token_count = sliding_window.estimate_any_object_token_count(test_string)
|
||||
|
||||
assert token_count > 0
|
||||
|
||||
|
||||
def test_estimate_any_object_token_count_non_serializable():
|
||||
"""Test token counting for non-JSON-serializable objects."""
|
||||
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
|
||||
|
||||
# Create an object that can't be JSON serialized
|
||||
class CustomObject:
|
||||
def __str__(self):
|
||||
return "CustomObject instance"
|
||||
|
||||
custom_obj = CustomObject()
|
||||
token_count = sliding_window.estimate_any_object_token_count(custom_obj)
|
||||
|
||||
# Should fall back to string representation
|
||||
assert token_count > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_world_scenario():
|
||||
"""Test a realistic conversation scenario."""
|
||||
sliding_window = SlidingWindowChatMessageList(
|
||||
max_tokens=30, system_message="You are a helpful assistant" # Moderate limit
|
||||
)
|
||||
|
||||
# Simulate a conversation
|
||||
conversation = [
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text="Hello, how are you?")]),
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT, contents=[TextContent(text="I'm doing well, thank you! How can I help you today?")]
|
||||
),
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text="Can you tell me about the weather?")]),
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[
|
||||
TextContent(
|
||||
text="I'd be happy to help with weather information, but I don't have access to current weather data."
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text="What about telling me a joke instead?")]),
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[TextContent(text="Sure! Why don't scientists trust atoms? Because they make up everything!")],
|
||||
),
|
||||
]
|
||||
|
||||
await sliding_window.add_messages(conversation)
|
||||
|
||||
current_messages = await sliding_window.list_messages()
|
||||
all_messages = await sliding_window.list_all_messages()
|
||||
|
||||
# All messages should be preserved
|
||||
assert len(all_messages) == 6
|
||||
|
||||
# Current messages might be truncated
|
||||
assert len(current_messages) <= 6
|
||||
|
||||
# Token count should be within or close to limit
|
||||
token_count = sliding_window.get_token_count()
|
||||
# Allow some margin since truncation happens when exceeded
|
||||
assert token_count <= sliding_window.max_tokens * 1.1
|
||||
@@ -0,0 +1,189 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for tau2 utils module."""
|
||||
|
||||
from typing import Any, cast
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework._tools import AIFunction
|
||||
from agent_framework._types import ChatMessage, Role, TextContent, FunctionCallContent, FunctionResultContent
|
||||
from agent_framework_lab_tau2._tau2_utils import (
|
||||
convert_tau2_tool_to_ai_function,
|
||||
convert_agent_framework_messages_to_tau2_messages,
|
||||
)
|
||||
from tau2.data_model.message import SystemMessage, UserMessage, AssistantMessage, ToolMessage, ToolCall
|
||||
from tau2.domains.airline.environment import get_environment
|
||||
|
||||
|
||||
def test_convert_tau2_tool_to_ai_function_basic():
|
||||
"""Test basic conversion from tau2 tool to AIFunction."""
|
||||
# Get real tools from tau2 environment
|
||||
env = get_environment()
|
||||
tools = env.get_tools()
|
||||
|
||||
# Use the first available tool for testing
|
||||
assert len(tools) > 0, "No tools available in environment"
|
||||
tau2_tool = tools[0]
|
||||
|
||||
# Convert the tool
|
||||
ai_function = convert_tau2_tool_to_ai_function(tau2_tool)
|
||||
|
||||
# Verify the conversion
|
||||
assert isinstance(ai_function, AIFunction)
|
||||
assert ai_function.name == tau2_tool.name
|
||||
assert ai_function.description == tau2_tool._get_description()
|
||||
assert ai_function.input_model == tau2_tool.params
|
||||
|
||||
# Test that the function is callable (we won't call it with real params to avoid side effects)
|
||||
assert callable(ai_function.func)
|
||||
|
||||
|
||||
def test_convert_tau2_tool_to_ai_function_multiple_tools():
|
||||
"""Test conversion with multiple tau2 tools."""
|
||||
# Get real tools from tau2 environment
|
||||
env = get_environment()
|
||||
tools = env.get_tools()
|
||||
|
||||
# Convert multiple tools
|
||||
ai_functions = [convert_tau2_tool_to_ai_function(tool) for tool in tools[:3]] # Test first 3 tools
|
||||
|
||||
# Verify all conversions
|
||||
for ai_function, tau2_tool in zip(ai_functions, tools[:3]):
|
||||
assert isinstance(ai_function, AIFunction)
|
||||
assert ai_function.name == tau2_tool.name
|
||||
assert ai_function.description == tau2_tool._get_description()
|
||||
assert ai_function.input_model == tau2_tool.params
|
||||
assert callable(ai_function.func)
|
||||
|
||||
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_system():
|
||||
"""Test converting system message."""
|
||||
messages = [ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System instruction")])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
assert len(tau2_messages) == 1
|
||||
assert isinstance(tau2_messages[0], SystemMessage)
|
||||
assert tau2_messages[0].role == "system"
|
||||
assert tau2_messages[0].content == "System instruction"
|
||||
|
||||
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_user():
|
||||
"""Test converting user message."""
|
||||
messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello assistant")])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
assert len(tau2_messages) == 1
|
||||
assert isinstance(tau2_messages[0], UserMessage)
|
||||
assert tau2_messages[0].role == "user"
|
||||
assert tau2_messages[0].content == "Hello assistant"
|
||||
assert tau2_messages[0].tool_calls is None
|
||||
|
||||
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_assistant():
|
||||
"""Test converting assistant message."""
|
||||
messages = [ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hello user")])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
assert len(tau2_messages) == 1
|
||||
assert isinstance(tau2_messages[0], AssistantMessage)
|
||||
assert tau2_messages[0].role == "assistant"
|
||||
assert tau2_messages[0].content == "Hello user"
|
||||
assert tau2_messages[0].tool_calls is None
|
||||
|
||||
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_with_function_call():
|
||||
"""Test converting message with function call."""
|
||||
function_call = FunctionCallContent(call_id="call_123", name="test_function", arguments={"param": "value"})
|
||||
|
||||
messages = [ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="I'll call a function"), function_call])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
assert len(tau2_messages) == 1
|
||||
assert isinstance(tau2_messages[0], AssistantMessage)
|
||||
assert tau2_messages[0].content == "I'll call a function"
|
||||
assert tau2_messages[0].tool_calls is not None
|
||||
assert len(tau2_messages[0].tool_calls) == 1
|
||||
|
||||
tool_call = tau2_messages[0].tool_calls[0]
|
||||
assert isinstance(tool_call, ToolCall)
|
||||
assert tool_call.id == "call_123"
|
||||
assert tool_call.name == "test_function"
|
||||
assert tool_call.arguments == {"param": "value"}
|
||||
assert tool_call.requestor == "assistant"
|
||||
|
||||
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_with_function_result():
|
||||
"""Test converting message with function result."""
|
||||
function_result = FunctionResultContent(call_id="call_123", result={"success": True, "data": "result data"})
|
||||
|
||||
messages = [ChatMessage(role=Role.TOOL, contents=[function_result])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
assert len(tau2_messages) == 1
|
||||
assert isinstance(tau2_messages[0], ToolMessage)
|
||||
assert tau2_messages[0].id == "call_123"
|
||||
assert tau2_messages[0].role == "tool"
|
||||
assert tau2_messages[0].content is not None
|
||||
assert '{"success": true, "data": "result data"}' in tau2_messages[0].content
|
||||
assert tau2_messages[0].requestor == "assistant"
|
||||
assert tau2_messages[0].error is False
|
||||
|
||||
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_with_error():
|
||||
"""Test converting function result with error."""
|
||||
function_result = FunctionResultContent(
|
||||
call_id="call_456", result="Error occurred", exception=Exception("Test error")
|
||||
)
|
||||
|
||||
messages = [ChatMessage(role=Role.TOOL, contents=[function_result])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
assert len(tau2_messages) == 1
|
||||
assert isinstance(tau2_messages[0], ToolMessage)
|
||||
assert tau2_messages[0].error is True
|
||||
|
||||
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_multiple_text_contents():
|
||||
"""Test converting message with multiple text contents."""
|
||||
messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="First part"), TextContent(text="Second part")])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
assert len(tau2_messages) == 1
|
||||
assert isinstance(tau2_messages[0], UserMessage)
|
||||
assert tau2_messages[0].content == "First part Second part"
|
||||
|
||||
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_complex_scenario():
|
||||
"""Test converting complex scenario with multiple message types."""
|
||||
function_call = FunctionCallContent(call_id="call_789", name="complex_tool", arguments='{"key": "value"}')
|
||||
|
||||
function_result = FunctionResultContent(call_id="call_789", result={"output": "tool result"})
|
||||
|
||||
messages = [
|
||||
ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System prompt")]),
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text="User request")]),
|
||||
ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="I'll help you"), function_call]),
|
||||
ChatMessage(role=Role.TOOL, contents=[function_result]),
|
||||
ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Based on the result...")]),
|
||||
]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
assert len(tau2_messages) == 5
|
||||
assert isinstance(tau2_messages[0], SystemMessage)
|
||||
assert isinstance(tau2_messages[1], UserMessage)
|
||||
assert isinstance(tau2_messages[2], AssistantMessage)
|
||||
assert isinstance(tau2_messages[3], ToolMessage)
|
||||
assert isinstance(tau2_messages[4], AssistantMessage)
|
||||
|
||||
# Check the assistant message with tool call
|
||||
assert tau2_messages[2].tool_calls is not None
|
||||
assert len(tau2_messages[2].tool_calls) == 1
|
||||
assert tau2_messages[2].tool_calls[0].name == "complex_tool"
|
||||
@@ -20,6 +20,7 @@ dev = [
|
||||
"pytest>=8.4.1",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"pytest-cov>=6.2.1",
|
||||
"pytest-env>=1.1.5",
|
||||
"pytest-xdist[psutil]>=3.8.0",
|
||||
"pytest-timeout>=2.3.1",
|
||||
"mypy>=1.16.1",
|
||||
@@ -164,7 +165,7 @@ exclude_dirs = ["tests", "./run_tasks_in_packages_if_exists.py", "./check_md_cod
|
||||
executor.type = "uv"
|
||||
|
||||
[tool.poe.tasks]
|
||||
markdown-code-lint = """uv run python check_md_code_blocks.py README.md ./packages/**/README.md ./samples/**/*.md --exclude cookiecutter-agent-framework-lab"""
|
||||
markdown-code-lint = """uv run python check_md_code_blocks.py README.md ./packages/**/README.md ./samples/**/*.md --exclude cookiecutter-agent-framework-lab --exclude tau2"""
|
||||
samples-code-check = """pyright ./samples"""
|
||||
docs-install = "uv sync --all-packages --all-extras --dev -U --prerelease=if-necessary-or-explicit --group=docs"
|
||||
docs-clean = "rm -rf docs/build"
|
||||
|
||||
Generated
+1311
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user