agent-framework/python/packages/lab/lightning/samples/train_tau2_agent.py
Eduard van Valkenburg 977c3adfb2 Python: replace pre-commit with prek, add PEP 723 script deps, clean up dev dependencies (#3748)
2026-02-09 17:51:01 +00:00


# Copyright (c) Microsoft. All rights reserved.

"""Advanced example showing multi-agent RL training using Tau2 benchmark.

This demonstrates:
- LitAgent class-based approach (vs @rollout decorator)
- Multi-agent scenarios with agent filtering
- Resource management for complex setups
- Integration with external benchmarks

Builds on concepts from train_math_agent.py with additional complexity.
Requires one GPU of at least 80GB of memory.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import os
import random
import time
import traceback
from pathlib import Path
from typing import TypedDict, cast

from agent_framework.lab.lightning import AgentFrameworkTracer
from agent_framework.lab.tau2 import ASSISTANT_AGENT_ID, patch_env_set_state  # type: ignore
from agent_framework.lab.tau2 import TaskRunner as Tau2TaskRunner  # type: ignore
from agent_framework.openai import OpenAIChatClient
from agentlightning import LLM, Dataset, LitAgent, NamedResources, Rollout, Trainer
from agentlightning.algorithm.verl import VERL
from tau2.data_model.tasks import Task as Tau2Task  # type: ignore[import-untyped]


# Tau2 tasks are complex objects that need special handling during distributed training
class SerializedTask(TypedDict):
    """Tau2 task object type."""

    id: str
    data: str  # JSON-serialized task data to prevent HuggingFace conversion issues


def _load_dataset() -> tuple[Dataset[SerializedTask], Dataset[SerializedTask]]:
    """Load and prepare Tau2 dataset with proper serialization.

    Requires the external data directory named by TAU2_DATA_DIR and uses a
    deterministic train/val split for reproducibility.
    """
    data_dir = os.getenv("TAU2_DATA_DIR")
    if data_dir is None:
        raise ValueError("TAU2_DATA_DIR must be set")
    tasks_path = Path(data_dir) / "tau2/domains/airline/tasks.json"
    with tasks_path.open("r") as f:
        dataset = json.load(f)

    # Serialize complex task objects to prevent HuggingFace tokenizer issues
    dataset = [{"id": task["id"], "data": json.dumps(task)} for task in dataset]

    # Deterministic train/val split (25/25) for reproducible experiments
    random_state = random.Random(42)  # noqa: S311
    indices = list(range(len(dataset)))
    random_state.shuffle(indices)
    train_indices = indices[: int(len(dataset) * 0.5)]
    val_indices = indices[int(len(dataset) * 0.5) :]
    print(f"Train indices: {train_indices}")
    print(f"Val indices: {val_indices}")
    train_dataset = [dataset[i] for i in train_indices]
    val_dataset = [dataset[i] for i in val_indices]
    return cast(Dataset[SerializedTask], train_dataset), cast(Dataset[SerializedTask], val_dataset)


# Alternative to @rollout: LitAgent class for advanced scenarios
# Use this approach when you need:
# - Agent filtering (training only specific agents in multi-agent setup)
# - Resource management (multiple LLMs, databases, etc.)
# - Complex initialization logic
class Tau2Agent(LitAgent):
    """Class-based agent with advanced resource management and agent filtering."""

    async def rollout_async(self, task: SerializedTask, resources: NamedResources, rollout: Rollout) -> float:
        """The main rollout method. Similar to @rollout but with more control."""
        llm = resources.get("main_llm")
        if not isinstance(llm, LLM):
            raise ValueError("main_llm must be an instance of LLM")

        openai_base_url = os.getenv("OPENAI_BASE_URL")
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if openai_base_url is None:
            raise ValueError("OPENAI_BASE_URL must be set")
        if openai_api_key is None:
            raise ValueError("OPENAI_API_KEY must be set")

        # Deserialize the complex task object
        task_data = json.loads(task["data"])
        task_obj = Tau2Task(**task_data)

        # Multi-agent setup: assistant (trainable) + user simulator (fixed)
        runner = Tau2TaskRunner(
            max_steps=100,
            assistant_window_size=4000,
            assistant_sampling_temperature=llm.sampling_parameters.get("temperature", 0.0),
        )

        # Assistant agent: uses the model being trained
        assistant_chat_client = OpenAIChatClient(
            base_url=llm.endpoint,  # vLLM endpoint for the model being trained
            api_key=openai_api_key,
            model_id=llm.model,  # Model ID being trained
        )

        # User simulator: uses a fixed, capable model for consistent simulation
        user_simulator_chat_client = OpenAIChatClient(
            base_url=openai_base_url,  # External API endpoint
            api_key=openai_api_key,
            model_id="gpt-4.1",  # Fixed model for user simulator
        )

        try:
            # Run the multi-agent conversation
            conversation = await runner.run(task_obj, assistant_chat_client, user_simulator_chat_client)
        except Exception:
            # Handle failures gracefully - assign low reward to discourage problematic behavior
            # Common issues: tool calling errors, timeout, invalid responses
            traceback.print_exc()
            return 0.0

        # Use Tau2's built-in evaluation metrics
        evaluation = runner.evaluate(task_obj, conversation, runner.termination_reason)

        # Return the evaluation score
        return evaluation  # noqa: RET504


def main():
    """Main entrypoint."""
    # RL config with higher resource requirements and W&B logging
    rl_training_config = {
        "algorithm": {"adv_estimator": "grpo"},
        "data": {
            "train_batch_size": 8,
            "max_prompt_length": 8192,
            "max_response_length": 2048,
        },
        "actor_rollout_ref": {
            "rollout": {
                "tensor_model_parallel_size": 1,
                "n": 8,  # Higher repetition for more data per task
                "log_prob_micro_batch_size_per_gpu": 4,
                "multi_turn": {"format": "hermes"},
                "name": "vllm",
                "gpu_memory_utilization": 0.8,  # Higher utilization for 80GB GPU
            },
            "actor": {
                "ppo_mini_batch_size": 8,
                "ppo_micro_batch_size_per_gpu": 4,
                "optim": {"lr": 1e-6},
                "use_kl_loss": False,
                "clip_ratio_low": 0.2,
                "clip_ratio_high": 0.3,
                "fsdp_config": {
                    "param_offload": True,
                    "optimizer_offload": True,
                },
            },
            # Reference model config
            "ref": {
                "log_prob_micro_batch_size_per_gpu": 8,
                "fsdp_config": {"param_offload": True},
            },
            # Common configs for the model
            "model": {
                "path": "Qwen/Qwen2.5-1.5B-Instruct",
                "use_remove_padding": True,
                "enable_gradient_checkpointing": True,
            },
        },
        "trainer": {
            "n_gpus_per_node": 1,
            "val_before_train": True,
            "logger": ["console", "wandb"],  # Wandb for experiment tracking
            "project_name": "agent-framework-lab-lightning",
            "experiment_name": "tau2_agent",
            "nnodes": 1,
            "test_freq": 4,
            "total_epochs": 8,
        },
    }

    patch_env_set_state()  # Tau2-specific environment setup
    train_dataset, val_dataset = _load_dataset()

    # Key difference with math_agent: trained_agents parameter specifies which agents to train
    # Only the assistant agent is trained; user simulator remains fixed
    tau2_agent = Tau2Agent(trained_agents=ASSISTANT_AGENT_ID)

    tracer = AgentFrameworkTracer()
    trainer = Trainer(algorithm=VERL(rl_training_config), tracer=tracer, n_workers=4)
    trainer.fit(tau2_agent, train_dataset, val_dataset=val_dataset)


def debug():
    """Debug mode for testing multi-agent setup and Tau2 integration."""
    train_dataset, _ = _load_dataset()
    tau2_agent = Tau2Agent(trained_agents=ASSISTANT_AGENT_ID)

    openai_base_url = os.getenv("OPENAI_BASE_URL")
    if openai_base_url is None:
        raise ValueError("OPENAI_BASE_URL must be set")

    patch_env_set_state()  # Required for Tau2 environment

    # Test with resources dict (different from @rollout LLM parameter)
    asyncio.run(
        tau2_agent.rollout_async(
            train_dataset[0],
            resources={"main_llm": LLM(model="gpt-4.1", endpoint=openai_base_url)},
            rollout=Rollout(rollout_id="dummy", input="dummy_input", start_time=time.time()),
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if args.debug:
        debug()
    else:
        main()
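
The serialization and seeded split performed by `_load_dataset()` can be tried out without TAU2_DATA_DIR or a GPU. The following is a minimal sketch under stated assumptions: the toy dicts only stand in for real Tau2 tasks, and `serialize_and_split`, `train_fraction` and `toy_tasks` are hypothetical names introduced for illustration, not part of this sample or of agent-framework.

```python
import json
import random


def serialize_and_split(tasks, seed=42, train_fraction=0.5):
    """Hypothetical helper mirroring the split logic of _load_dataset() on in-memory dicts."""
    # Keep the id as a plain field and pack the whole task into one JSON string.
    records = [{"id": task["id"], "data": json.dumps(task)} for task in tasks]
    # A private Random instance keeps the shuffle independent of global random state.
    rng = random.Random(seed)
    indices = list(range(len(records)))
    rng.shuffle(indices)
    cut = int(len(records) * train_fraction)
    train = [records[i] for i in indices[:cut]]
    val = [records[i] for i in indices[cut:]]
    return train, val


# Toy stand-ins for Tau2 tasks (assumption: real tasks carry many more nested fields).
toy_tasks = [{"id": str(i), "description": {"purpose": f"task {i}"}} for i in range(6)]
train, val = serialize_and_split(toy_tasks)

# The same seed reproduces the same split.
assert (train, val) == serialize_and_split(toy_tasks)
# The JSON string round-trips back to a dict carrying the same id.
assert json.loads(train[0]["data"])["id"] == train[0]["id"]
```

Because the shuffle is driven by a dedicated `random.Random(seed)` instance, other uses of the `random` module in the same process should not change which tasks end up in the training half.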