# Copyright (c) Microsoft. All rights reserved.

"""Advanced example showing multi-agent RL training using Tau2 benchmark.

This demonstrates:
- LitAgent class-based approach (vs @rollout decorator)
- Multi-agent scenarios with agent filtering
- Resource management for complex setups
- Integration with external benchmarks

Builds on concepts from train_math_agent.py with additional complexity.
Requires one GPU of at least 80GB of memory.
"""

import argparse
import asyncio
import json
import os
import random
import time
import traceback
from pathlib import Path
from typing import TypedDict, cast

from agent_framework.lab.lightning import AgentFrameworkTracer
from agent_framework.lab.tau2 import ASSISTANT_AGENT_ID, patch_env_set_state  # type: ignore
from agent_framework.lab.tau2 import TaskRunner as Tau2TaskRunner  # type: ignore
from agent_framework.openai import OpenAIChatClient
from agentlightning import LLM, Dataset, LitAgent, NamedResources, Rollout, Trainer
from agentlightning.algorithm.verl import VERL
from tau2.data_model.tasks import Task as Tau2Task  # type: ignore[import-untyped]


# Tau2 tasks are complex objects that need special handling during distributed training
class SerializedTask(TypedDict):
    """Tau2 task object type.

    Each task is carried as its ID plus the full task re-encoded as a JSON
    string, so the dataset only contains flat string fields.
    """

    id: str
    data: str  # JSON-serialized task data to prevent HuggingFace conversion issues


def _load_dataset() -> tuple[Dataset[SerializedTask], Dataset[SerializedTask]]:
    """Load and prepare Tau2 dataset with proper serialization.

    Reads ``tau2/domains/airline/tasks.json`` under the directory named by the
    ``TAU2_DATA_DIR`` environment variable, wraps every task as a
    ``SerializedTask`` and splits the result in half using a fixed-seed
    shuffle, so the split is identical across runs.

    Returns:
        A ``(train, val)`` pair. The first half of the shuffled indices goes to
        training and the remainder to validation, so with an odd number of
        tasks the validation set receives the extra one.

    Raises:
        ValueError: If ``TAU2_DATA_DIR`` is not set.
    """
    data_dir = os.getenv("TAU2_DATA_DIR")
    if data_dir is None:
        raise ValueError("TAU2_DATA_DIR must be set")

    tasks_path = Path(data_dir) / "tau2/domains/airline/tasks.json"
    # JSON is read as UTF-8 explicitly so that the result does not depend on
    # the platform's locale encoding.
    with tasks_path.open("r", encoding="utf-8") as f:
        dataset = json.load(f)

    # Serialize complex task objects to prevent HuggingFace tokenizer issues
    dataset = [{"id": task["id"], "data": json.dumps(task)} for task in dataset]

    # Deterministic 50/50 train/val split for reproducible experiments.
    # NOTE(review): the original comment called this "25/25", which presumably
    # assumes the airline task file holds 50 tasks - confirm against the data.
    random_state = random.Random(42)  # noqa: S311
    indices = list(range(len(dataset)))
    random_state.shuffle(indices)

    split_point = int(len(dataset) * 0.5)
    train_indices = indices[:split_point]
    val_indices = indices[split_point:]
    print(f"Train indices: {train_indices}")
    print(f"Val indices: {val_indices}")

    train_dataset = [dataset[i] for i in train_indices]
    val_dataset = [dataset[i] for i in val_indices]

    return cast(Dataset[SerializedTask], train_dataset), cast(Dataset[SerializedTask], val_dataset)


# Alternative to @rollout: LitAgent class for advanced scenarios
# Use this approach when you need:
# - Agent filtering (training only specific agents in multi-agent setup)
# - Resource management (multiple LLMs, databases, etc.)
# - Complex initialization logic
class Tau2Agent(LitAgent):
    """Class-based agent with advanced resource management and agent filtering."""

    async def rollout_async(self, task: SerializedTask, resources: NamedResources, rollout: Rollout) -> float:
        """The main rollout method. Similar to @rollout but with more control.

        Runs one Tau2 conversation between the assistant (served by the
        ``main_llm`` resource) and a fixed user simulator, then scores it.

        Args:
            task: The serialized Tau2 task; its ``data`` field is decoded back
                into a ``Tau2Task``.
            resources: Named resources; must contain an ``LLM`` under the key
                ``"main_llm"``.
            rollout: Rollout metadata. Not read by this method.

        Returns:
            The value returned by the Tau2 runner's ``evaluate`` call, or
            ``0.0`` if running the conversation raised an exception.

        Raises:
            ValueError: If ``main_llm`` is missing or not an ``LLM``, or if
                ``OPENAI_BASE_URL`` / ``OPENAI_API_KEY`` is not set.
        """
        llm = resources.get("main_llm")
        if not isinstance(llm, LLM):
            raise ValueError("main_llm must be an instance of LLM")

        openai_base_url = os.getenv("OPENAI_BASE_URL")
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if openai_base_url is None:
            raise ValueError("OPENAI_BASE_URL must be set")
        if openai_api_key is None:
            raise ValueError("OPENAI_API_KEY must be set")

        # Deserialize the complex task object
        task_data = json.loads(task["data"])
        task_obj = Tau2Task(**task_data)

        # Multi-agent setup: assistant (trainable) + user simulator (fixed)
        runner = Tau2TaskRunner(
            max_steps=100,
            assistant_window_size=4000,
            # Falls back to 0.0 when the resource carries no temperature setting
            assistant_sampling_temperature=llm.sampling_parameters.get("temperature", 0.0),
        )

        # Assistant agent: uses the model being trained
        assistant_chat_client = OpenAIChatClient(
            base_url=llm.endpoint,  # vLLM endpoint for the model being trained
            api_key=openai_api_key,
            model_id=llm.model,  # Model ID being trained
        )

        # User simulator: uses a fixed, capable model for consistent simulation
        user_simulator_chat_client = OpenAIChatClient(
            base_url=openai_base_url,  # External API endpoint
            api_key=openai_api_key,
            model_id="gpt-4.1",  # Fixed model for user simulator
        )

        try:
            # Run the multi-agent conversation
            conversation = await runner.run(task_obj, assistant_chat_client, user_simulator_chat_client)
        except Exception:
            # Handle failures gracefully - assign low reward to discourage problematic behavior
            # Common issues: tool calling errors, timeout, invalid responses
            # Deliberately broad: one failed rollout must not abort the run.
            traceback.print_exc()
            return 0.0

        # Use Tau2's built-in evaluation metrics
        evaluation = runner.evaluate(task_obj, conversation, runner.termination_reason)

        # Return the evaluation score
        return evaluation  # noqa: RET504


def main():
    """Main entrypoint.

    Builds the VERL configuration, loads the train/validation split and fits
    the trainer on ``Tau2Agent`` with only the assistant agent marked as
    trained.
    """
    # RL config with higher resource requirements and W&B logging
    rl_training_config = {
        "algorithm": {"adv_estimator": "grpo"},
        "data": {
            "train_batch_size": 8,
            "max_prompt_length": 8192,
            "max_response_length": 2048,
        },
        "actor_rollout_ref": {
            "rollout": {
                "tensor_model_parallel_size": 1,
                "n": 8,  # Higher repetition for more data per task
                "log_prob_micro_batch_size_per_gpu": 4,
                "multi_turn": {"format": "hermes"},
                "name": "vllm",
                "gpu_memory_utilization": 0.8,  # Higher utilization for 80GB GPU
            },
            "actor": {
                "ppo_mini_batch_size": 8,
                "ppo_micro_batch_size_per_gpu": 4,
                "optim": {"lr": 1e-6},
                "use_kl_loss": False,
                "clip_ratio_low": 0.2,
                "clip_ratio_high": 0.3,
                "fsdp_config": {
                    "param_offload": True,
                    "optimizer_offload": True,
                },
            },
            # Reference model config
            "ref": {
                "log_prob_micro_batch_size_per_gpu": 8,
                "fsdp_config": {"param_offload": True},
            },
            # Common configs for the model
            "model": {
                "path": "Qwen/Qwen2.5-1.5B-Instruct",
                "use_remove_padding": True,
                "enable_gradient_checkpointing": True,
            },
        },
        "trainer": {
            "n_gpus_per_node": 1,
            "val_before_train": True,
            "logger": ["console", "wandb"],  # Wandb for experiment tracking
            "project_name": "agent-framework-lab-lightning",
            "experiment_name": "tau2_agent",
            "nnodes": 1,
            "test_freq": 4,
            "total_epochs": 8,
        },
    }

    patch_env_set_state()  # Tau2-specific environment setup

    train_dataset, val_dataset = _load_dataset()

    # Key difference with math_agent: trained_agents parameter specifies which agents to train
    # Only the assistant agent is trained; user simulator remains fixed
    tau2_agent = Tau2Agent(trained_agents=ASSISTANT_AGENT_ID)

    tracer = AgentFrameworkTracer()
    trainer = Trainer(algorithm=VERL(rl_training_config), tracer=tracer, n_workers=4)
    trainer.fit(tau2_agent, train_dataset, val_dataset=val_dataset)


def debug():
    """Debug mode for testing multi-agent setup and Tau2 integration.

    Runs a single rollout on the first training task, using ``gpt-4.1`` at
    ``OPENAI_BASE_URL`` as the assistant model instead of a model under
    training.

    Raises:
        ValueError: If ``OPENAI_BASE_URL`` is not set (or if ``_load_dataset``
            or the rollout rejects the environment).
    """
    train_dataset, _ = _load_dataset()
    tau2_agent = Tau2Agent(trained_agents=ASSISTANT_AGENT_ID)

    openai_base_url = os.getenv("OPENAI_BASE_URL")
    if openai_base_url is None:
        raise ValueError("OPENAI_BASE_URL must be set")

    patch_env_set_state()  # Required for Tau2 environment

    # Test with resources dict (different from @rollout LLM parameter)
    asyncio.run(
        tau2_agent.rollout_async(
            train_dataset[0],
            resources={"main_llm": LLM(model="gpt-4.1", endpoint=openai_base_url)},
            rollout=Rollout(rollout_id="dummy", input="dummy_input", start_time=time.time()),
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if args.debug:
        debug()
    else:
        main()