mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
977c3adfb2
* python: replace pre-commit with prek, add PEP 723 script deps, clean up dev dependencies - Replace pre-commit with prek (Rust-native, faster pre-commit alternative) - Move supported hooks to repo: builtin for zero-clone speed - Add new builtin hooks: trailing-whitespace, check-merge-conflict, detect-private-key, check-added-large-files - Update all hook versions to latest (pre-commit-hooks v6, pyupgrade v3.21.2, bandit 1.9.3, uv-pre-commit 0.10.0) - Add PEP 723 inline script metadata to 34 samples with external deps - Remove autogen-agentchat/autogen-ext from dev deps (now declared per-sample) - Remove unused dev deps: pytest-env, tomli-w - Add agent-framework-core>=1.0.0b260130 lower bound to all 21 packages - Update CI workflow to use j178/prek-action - Update docs: DEV_SETUP.md, AGENTS.md, CODING_STANDARD.md, SAMPLE_GUIDELINES.md * updated lock * python: fix prek config paths for local execution and CI workflow Remove global 'files: ^python/' filter and strip python/ prefix from all path patterns in .pre-commit-config.yaml so prek finds files when run from the python/ directory. Update CI workflow to use --cd python instead of --config path. Include trailing whitespace fixes and dev dependency cleanup. * python: move helper scripts to scripts/ folder and exclude from checks * python: exclude AGENTS.md from prek markdown code lint * python: exclude AGENTS.md and azure_ai_search sample from markdown lint * fix m365 sample * python: ignore CPY rule for samples with PEP 723 headers * fix in dev_setup * python: replace aiofiles with regular open in samples * python: suppress reportUnusedImport in markdown code block checker * python: use samples pyright config for markdown code block checker Write a temp pyrightconfig.json matching pyrightconfig.samples.json rules (typeCheckingMode=off, only reportMissingImports and reportAttributeAccessIssue). 
Filter output to only fail on these rules since syntax-level errors (top-level await, undefined vars) are expected in README documentation snippets. * python: use markdown-code-lint with fixed globs instead of prek file list The prek-markdown-code-lint task received all changed files including non-README markdown and files with pre-existing broken imports. Replace with the standard markdown-code-lint task which uses the correct glob patterns (README.md, packages/**/README.md, samples/**/*.md). * python: exclude READMEs with pre-existing broken imports from markdown lint * python: fix broken README code snippets instead of excluding them - ag-ui: replace TextContent (removed) with content.type == 'text' - durabletask: fix import path to durabletask.worker.TaskHubGrpcWorker - orchestrations: use constructor params instead of .participants() method - observability: mark deprecated code blocks as plain text, filter reportMissingImports to agent_framework modules only - remove README excludes from markdown-code-lint task * add revision to gaia download * feat(python): parallelize checks across packages Run (package × task) cross-product in parallel using ThreadPoolExecutor and subprocesses. Key changes: - Add scripts/task_runner.py with shared parallel execution engine - Update run_tasks_in_packages_if_exists.py to accept multiple tasks - Update run_tasks_in_changed_packages.py with --files flag and parallel support - Add check-packages poe task (fmt+lint+pyright+mypy in parallel) - Add prek-markdown-code-lint and prek-samples-check with change detection - Split CI code quality workflow into parallel prek and mypy jobs - Update DEV_SETUP.md to document new parallel behavior Core package changes still trigger checks on all packages. 
* feat(ci): split code quality into 4 parallel jobs Split the single prek job into parallel jobs: - pre-commit-hooks: lightweight hooks (SKIP=poe-check) - package-checks: fmt/lint/pyright/mypy via check-packages - samples-markdown: samples-lint, samples-syntax, markdown-code-lint - mypy: change-detected mypy checks All 4 jobs run concurrently (×2 Python versions = 8 runners). * feat(ci): use only Python 3.10 for code quality checks * refactor(python): add future annotations and remove quoted types Add `from __future__ import annotations` to 93 package files that used quoted string annotations, then run pyupgrade --py310-plus to remove the now-unnecessary quotes. Fixes https://github.com/microsoft/agent-framework/issues/3578
232 lines
8.6 KiB
Python
232 lines
8.6 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Advanced example showing multi-agent RL training using Tau2 benchmark.
|
|
|
|
This demonstrates:
|
|
- LitAgent class-based approach (vs @rollout decorator)
|
|
- Multi-agent scenarios with agent filtering
|
|
- Resource management for complex setups
|
|
- Integration with external benchmarks
|
|
|
|
Builds on concepts from train_math_agent.py with additional complexity.
|
|
Requires one GPU of at least 80GB of memory.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import random
|
|
import time
|
|
import traceback
|
|
from pathlib import Path
|
|
from typing import TypedDict, cast
|
|
|
|
from agent_framework.lab.lightning import AgentFrameworkTracer
|
|
from agent_framework.lab.tau2 import ASSISTANT_AGENT_ID, patch_env_set_state # type: ignore
|
|
from agent_framework.lab.tau2 import TaskRunner as Tau2TaskRunner # type: ignore
|
|
from agent_framework.openai import OpenAIChatClient
|
|
from agentlightning import LLM, Dataset, LitAgent, NamedResources, Rollout, Trainer
|
|
from agentlightning.algorithm.verl import VERL
|
|
from tau2.data_model.tasks import Task as Tau2Task # type: ignore[import-untyped]
|
|
|
|
|
|
# Tau2 tasks are nested objects; during distributed training they are carried around as JSON strings.
class SerializedTask(TypedDict):
    """Transport form of one Tau2 task: its id plus the whole task encoded as a JSON string."""

    id: str  # the task's identifier
    data: str  # JSON-serialized task data, kept as a plain string to prevent HuggingFace conversion issues
|
|
|
|
|
|
def _load_dataset() -> tuple[Dataset[SerializedTask], Dataset[SerializedTask]]:
    """Load the Tau2 airline tasks and split them into train and validation sets.

    Reads ``tau2/domains/airline/tasks.json`` under the directory named by the
    ``TAU2_DATA_DIR`` environment variable. Each task becomes a ``SerializedTask``
    whose ``data`` field is the task JSON-encoded as a string, so HuggingFace
    dataset conversion does not try to interpret the nested task structure.

    The split is a deterministic 50/50 shuffle (seed 42): the first half of the
    shuffled indices is the training set, the rest is the validation set, so
    repeated runs see the same partition.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple.

    Raises:
        ValueError: If ``TAU2_DATA_DIR`` is unset or empty.
    """
    data_dir = os.getenv("TAU2_DATA_DIR")
    # An empty value would silently resolve to a path relative to the CWD; treat it as unset.
    if not data_dir:
        raise ValueError("TAU2_DATA_DIR must be set")
    tasks_path = Path(data_dir) / "tau2/domains/airline/tasks.json"
    # JSON is UTF-8; don't depend on the platform's locale-default encoding.
    with tasks_path.open("r", encoding="utf-8") as f:
        raw_tasks = json.load(f)

    # Serialize complex task objects to prevent HuggingFace conversion issues
    dataset: list[SerializedTask] = [{"id": task["id"], "data": json.dumps(task)} for task in raw_tasks]

    # Deterministic 50/50 train/val split with a fixed seed for reproducible experiments.
    # Not security-sensitive, so the non-cryptographic PRNG is fine.
    rng = random.Random(42)  # noqa: S311
    indices = list(range(len(dataset)))
    rng.shuffle(indices)
    split = int(len(dataset) * 0.5)
    train_indices = indices[:split]
    val_indices = indices[split:]
    print(f"Train indices: {train_indices}")
    print(f"Val indices: {val_indices}")
    train_dataset = [dataset[i] for i in train_indices]
    val_dataset = [dataset[i] for i in val_indices]

    # Quoted type expressions: the cast is static-only and needs no runtime generic alias.
    return cast("Dataset[SerializedTask]", train_dataset), cast("Dataset[SerializedTask]", val_dataset)
|
|
|
|
|
|
# Alternative to @rollout: a LitAgent subclass for more advanced scenarios.
# Prefer the class-based form when you need:
# - Agent filtering (training only specific agents in a multi-agent setup)
# - Resource management (multiple LLMs, databases, etc.)
# - Complex initialization logic
class Tau2Agent(LitAgent):
    """LitAgent that plays one Tau2 task as an assistant / user-simulator conversation.

    The assistant is served by the model being trained (the ``main_llm`` resource),
    while the user simulator always uses a fixed external model. Which agents are
    trained is chosen via the ``trained_agents`` argument passed at construction.
    """

    async def rollout_async(self, task: SerializedTask, resources: NamedResources, rollout: Rollout) -> float:
        """Run a single Tau2 task and return its evaluation score as the reward.

        Args:
            task: Serialized Tau2 task; ``task["data"]`` holds the JSON-encoded task.
            resources: Named resources; must map ``"main_llm"`` to an ``LLM``.
            rollout: Rollout metadata (not read by this method).

        Returns:
            The result of ``runner.evaluate`` for the finished conversation, or
            ``0.0`` if running the conversation raised.

        Raises:
            ValueError: If ``main_llm`` is not an ``LLM``, or if ``OPENAI_BASE_URL``
                or ``OPENAI_API_KEY`` is unset.
        """
        trained_llm = resources.get("main_llm")
        if not isinstance(trained_llm, LLM):
            raise ValueError("main_llm must be an instance of LLM")

        base_url = os.getenv("OPENAI_BASE_URL")
        api_key = os.getenv("OPENAI_API_KEY")
        if base_url is None:
            raise ValueError("OPENAI_BASE_URL must be set")
        if api_key is None:
            raise ValueError("OPENAI_API_KEY must be set")

        # Rebuild the Tau2 task object from its JSON-string transport form
        tau2_task = Tau2Task(**json.loads(task["data"]))

        # Two agents: the trainable assistant and the fixed user simulator
        temperature = trained_llm.sampling_parameters.get("temperature", 0.0)
        runner = Tau2TaskRunner(
            max_steps=100,
            assistant_window_size=4000,
            assistant_sampling_temperature=temperature,
        )

        # Assistant: talks to the serving endpoint of the model under training ...
        assistant_client = OpenAIChatClient(
            base_url=trained_llm.endpoint,
            api_key=api_key,
            model_id=trained_llm.model,
        )
        # ... while the user simulator always talks to the same fixed external model.
        user_simulator_client = OpenAIChatClient(
            base_url=base_url,
            api_key=api_key,
            model_id="gpt-4.1",
        )

        try:
            conversation = await runner.run(tau2_task, assistant_client, user_simulator_client)
        except Exception:
            # A failed conversation (e.g. tool-calling errors, timeouts, invalid
            # responses) earns the lowest reward instead of aborting training.
            traceback.print_exc()
            return 0.0

        # The reward is Tau2's own evaluation of the finished conversation
        return runner.evaluate(tau2_task, conversation, runner.termination_reason)
|
|
|
|
|
|
def main():
    """Train the Tau2 assistant agent with GRPO through VERL.

    Builds the VERL configuration (single node, single GPU, console + W&B logging),
    prepares the Tau2 environment, loads the train/val split and fits a
    ``Tau2Agent`` in which only the assistant agent is trained.
    """
    # Rollout generation with vLLM; n=8 samples per task gives more data per task.
    rollout_cfg = {
        "tensor_model_parallel_size": 1,
        "n": 8,
        "log_prob_micro_batch_size_per_gpu": 4,
        "multi_turn": {"format": "hermes"},
        "name": "vllm",
        "gpu_memory_utilization": 0.8,  # Higher utilization for 80GB GPU
    }
    # Policy update settings; parameters and optimizer state are offloaded under FSDP.
    actor_cfg = {
        "ppo_mini_batch_size": 8,
        "ppo_micro_batch_size_per_gpu": 4,
        "optim": {"lr": 1e-6},
        "use_kl_loss": False,
        "clip_ratio_low": 0.2,
        "clip_ratio_high": 0.3,
        "fsdp_config": {
            "param_offload": True,
            "optimizer_offload": True,
        },
    }
    # Reference model config
    ref_cfg = {
        "log_prob_micro_batch_size_per_gpu": 8,
        "fsdp_config": {"param_offload": True},
    }
    # Common configs for the model
    model_cfg = {
        "path": "Qwen/Qwen2.5-1.5B-Instruct",
        "use_remove_padding": True,
        "enable_gradient_checkpointing": True,
    }
    trainer_cfg = {
        "n_gpus_per_node": 1,
        "val_before_train": True,
        "logger": ["console", "wandb"],  # W&B for experiment tracking
        "project_name": "agent-framework-lab-lightning",
        "experiment_name": "tau2_agent",
        "nnodes": 1,
        "test_freq": 4,
        "total_epochs": 8,
    }
    rl_training_config = {
        "algorithm": {"adv_estimator": "grpo"},
        "data": {
            "train_batch_size": 8,
            "max_prompt_length": 8192,
            "max_response_length": 2048,
        },
        "actor_rollout_ref": {
            "rollout": rollout_cfg,
            "actor": actor_cfg,
            "ref": ref_cfg,
            "model": model_cfg,
        },
        "trainer": trainer_cfg,
    }

    patch_env_set_state()  # Tau2-specific environment setup

    train_dataset, val_dataset = _load_dataset()

    # Unlike the math agent example, trained_agents selects which agents are trained:
    # only the assistant learns, the user simulator stays fixed.
    agent = Tau2Agent(trained_agents=ASSISTANT_AGENT_ID)

    tracer = AgentFrameworkTracer()
    algorithm = VERL(rl_training_config)
    trainer = Trainer(algorithm=algorithm, tracer=tracer, n_workers=4)
    trainer.fit(agent, train_dataset, val_dataset=val_dataset)
|
|
|
|
|
|
def debug():
    """Smoke-test the multi-agent Tau2 setup with a single rollout and print its reward.

    Runs ``Tau2Agent.rollout_async`` on the first training task. The hosted
    ``gpt-4.1`` model at ``OPENAI_BASE_URL`` stands in for the model being
    trained, so the conversation can be exercised without starting VERL training.

    Raises:
        ValueError: If ``TAU2_DATA_DIR`` or ``OPENAI_BASE_URL`` is unset
            (the rollout itself additionally requires ``OPENAI_API_KEY``).
    """
    train_dataset, _ = _load_dataset()
    tau2_agent = Tau2Agent(trained_agents=ASSISTANT_AGENT_ID)

    openai_base_url = os.getenv("OPENAI_BASE_URL")
    if openai_base_url is None:
        raise ValueError("OPENAI_BASE_URL must be set")

    patch_env_set_state()  # Required for Tau2 environment

    # LitAgent receives a resources dict (unlike the LLM parameter of @rollout)
    reward = asyncio.run(
        tau2_agent.rollout_async(
            train_dataset[0],
            resources={"main_llm": LLM(model="gpt-4.1", endpoint=openai_base_url)},
            rollout=Rollout(rollout_id="dummy", input="dummy_input", start_time=time.time()),
        )
    )
    # Report the score; otherwise a successful debug run gives no result signal at all.
    print(f"Debug rollout reward: {reward}")
|
|
|
|
|
|
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--debug", action="store_true")
    # --debug runs a single smoke-test rollout; otherwise run full training.
    entrypoint = debug if cli.parse_args().debug else main
    entrypoint()
|