Files
agent-framework/python/packages/lab/lightning/samples/train_math_agent.py
T
Eduard van Valkenburg 977c3adfb2 Python: replace pre-commit with prek, add PEP 723 script deps, clean up dev dependencies (#3748)
* python: replace pre-commit with prek, add PEP 723 script deps, clean up dev dependencies

- Replace pre-commit with prek (Rust-native, faster pre-commit alternative)
- Move supported hooks to repo: builtin for zero-clone speed
- Add new builtin hooks: trailing-whitespace, check-merge-conflict, detect-private-key, check-added-large-files
- Update all hook versions to latest (pre-commit-hooks v6, pyupgrade v3.21.2, bandit 1.9.3, uv-pre-commit 0.10.0)
- Add PEP 723 inline script metadata to 34 samples with external deps
- Remove autogen-agentchat/autogen-ext from dev deps (now declared per-sample)
- Remove unused dev deps: pytest-env, tomli-w
- Add agent-framework-core>=1.0.0b260130 lower bound to all 21 packages
- Update CI workflow to use j178/prek-action
- Update docs: DEV_SETUP.md, AGENTS.md, CODING_STANDARD.md, SAMPLE_GUIDELINES.md

* updated lock

* python: fix prek config paths for local execution and CI workflow

Remove global 'files: ^python/' filter and strip python/ prefix from all path patterns in .pre-commit-config.yaml so prek finds files when run from the python/ directory. Update CI workflow to use --cd python instead of --config path. Include trailing whitespace fixes and dev dependency cleanup.

* python: move helper scripts to scripts/ folder and exclude from checks

* python: exclude AGENTS.md from prek markdown code lint

* python: exclude AGENTS.md and azure_ai_search sample from markdown lint

* fix m365 sample

* python: ignore CPY rule for samples with PEP 723 headers

* fix in dev_setup

* python: replace aiofiles with regular open in samples

* python: suppress reportUnusedImport in markdown code block checker

* python: use samples pyright config for markdown code block checker

Write a temp pyrightconfig.json matching pyrightconfig.samples.json rules (typeCheckingMode=off, only reportMissingImports and reportAttributeAccessIssue). Filter output to only fail on these rules since syntax-level errors (top-level await, undefined vars) are expected in README documentation snippets.

* python: use markdown-code-lint with fixed globs instead of prek file list

The prek-markdown-code-lint task received all changed files including non-README markdown and files with pre-existing broken imports. Replace with the standard markdown-code-lint task which uses the correct glob patterns (README.md, packages/**/README.md, samples/**/*.md).

* python: exclude READMEs with pre-existing broken imports from markdown lint

* python: fix broken README code snippets instead of excluding them

- ag-ui: replace TextContent (removed) with content.type == 'text'
- durabletask: fix import path to durabletask.worker.TaskHubGrpcWorker
- orchestrations: use constructor params instead of .participants() method
- observability: mark deprecated code blocks as plain text, filter
  reportMissingImports to agent_framework modules only
- remove README excludes from markdown-code-lint task

* add revision to gaia download

* feat(python): parallelize checks across packages

Run (package × task) cross-product in parallel using ThreadPoolExecutor
and subprocesses. Key changes:

- Add scripts/task_runner.py with shared parallel execution engine
- Update run_tasks_in_packages_if_exists.py to accept multiple tasks
- Update run_tasks_in_changed_packages.py with --files flag and parallel support
- Add check-packages poe task (fmt+lint+pyright+mypy in parallel)
- Add prek-markdown-code-lint and prek-samples-check with change detection
- Split CI code quality workflow into parallel prek and mypy jobs
- Update DEV_SETUP.md to document new parallel behavior

Core package changes still trigger checks on all packages.

* feat(ci): split code quality into 4 parallel jobs

Split the single prek job into parallel jobs:
- pre-commit-hooks: lightweight hooks (SKIP=poe-check)
- package-checks: fmt/lint/pyright/mypy via check-packages
- samples-markdown: samples-lint, samples-syntax, markdown-code-lint
- mypy: change-detected mypy checks

All 4 jobs run concurrently (×2 Python versions = 8 runners).

* feat(ci): use only Python 3.10 for code quality checks

* refactor(python): add future annotations and remove quoted types

Add `from __future__ import annotations` to 93 package files that
used quoted string annotations, then run pyupgrade --py310-plus to
remove the now-unnecessary quotes.

Fixes https://github.com/microsoft/agent-framework/issues/3578
2026-02-09 17:51:01 +00:00

331 lines
12 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""This sample demonstrates the basic usage pattern of agent-framework-lab-lightning.
It trains a math agent using a dataset in `data/math/` to solve mathematical problems
using an MCP calculator tool.
One GPU with 40GB of memory is sufficient for this sample.
"""
from __future__ import annotations
import argparse
import asyncio
import json
import math
import os
import re
import string
from typing import TypedDict, cast
import sympy # type: ignore[import-untyped,reportMissingImports]
from agent_framework import AgentResponse, ChatAgent, MCPStdioTool
from agent_framework.lab.lightning import AgentFrameworkTracer
from agent_framework.openai import OpenAIChatClient
from agentlightning import LLM, Dataset, Trainer, rollout
from agentlightning.algorithm.verl import VERL
class MathProblem(TypedDict):
"""This TypedDict defines the structure of each training sample.
Your task structure should contain all the information needed for:
- The agent to process the task (e.g., 'question')
- Evaluation (e.g., 'result' for ground truth)
This type is optional. Not necessary to make the example work.
"""
# The fields come from the dataset
id: str
question: str # The math problem for the agent to solve
chain: str # Step-by-step solution (not used in training)
result: str # Ground truth answer for evaluation
source: str
def _load_jsonl(file_path: str) -> Dataset[MathProblem]:
"""Load your dataset as a list of task samples.
Each sample should match your task structure (MathProblem in this case).
"""
with open(file_path) as f:
raw_data = [MathProblem(**json.loads(line)) for line in f]
return cast(Dataset[MathProblem], raw_data)
# Evaluation logic
# These functions evaluate whether the agent's answer matches the ground truth.
# Robust evaluation is crucial for RL training - the reward signal guides learning.
def _normalize_option(option: str) -> str:
return re.sub(r"(\s+|\(|\))", "", option)
def _is_option_result(result: str) -> bool:
return _normalize_option(result) in list(string.ascii_letters)
def _float_eval(input_str: str) -> float:
if " = around " in input_str:
input_str = input_str.split(" = around ")[0]
expr = sympy.parse_expr(input_str, evaluate=True)
return float(expr.evalf())
def _scalar_are_results_same(pred_result: str, true_result: str, rel_tol: float) -> bool:
pred_result = str(pred_result) if pred_result is not None else ""
true_result = str(true_result) if true_result is not None else ""
if pred_result.strip() == true_result.strip():
return True
if _is_option_result(true_result):
# The task is to select correct option
true_result = _normalize_option(true_result)
pred_result = _normalize_option(pred_result)
return pred_result == true_result
# The task is to calculate the result as a number
try:
pred_float = _float_eval(pred_result)
true_float = _float_eval(true_result)
return math.isclose(pred_float, true_float, rel_tol=rel_tol)
except Exception: # noqa: S110
pass
return False
def _is_result_correct(prediction: str, ground_truth: str) -> float:
return float(_scalar_are_results_same(prediction, ground_truth, 1e-2))
def evaluate(result: AgentResponse, ground_truth: str) -> float:
"""Main evaluation function that extracts the agent's answer and compares with ground truth.
This function:
1. Extracts the final answer from the agent's response (after ###)
2. Compares it with the ground truth using mathematical equivalence
3. Returns a reward score (0.0 or 1.0) for RL training
The reward signal is critical - it directly influences what the model learns.
"""
# Check if agent provided any response
if len(result.messages) == 0:
print("No response from agent. Assuming incorrect.")
return 0.0
final_message = result.messages[-1].text
# Extract answer after ### marker (as specified in agent instructions)
answer = re.search(r"###\s*(.+?)(\s*###|$)", final_message)
if answer is None:
print("No answer can be extracted from agent's response. Assuming incorrect.")
return 0.0
answer = answer.group(1)
# Compare extracted answer with ground truth
reward = _is_result_correct(answer, ground_truth)
print(f"Reward: {reward}")
return reward
# Agent Logic
# Clear instructions are important for consistent agent behavior
# The ### format helps with reliable answer extraction during evaluation
AGENT_INSTRUCTION = """
Solve the following math problem. Use the calculator tool to help you calculate math expressions.
Output the answer when you are ready. The answer should be after three sharps (`###`), with no extra punctuations or texts. For example: ### 123
""".strip() # noqa: E501
# The @rollout decorator is the key integration point with agent-lightning.
# It tells the training system that this function defines a trainable agent.
@rollout
async def math_agent(task: MathProblem, llm: LLM) -> float:
"""This is your trainable agent function.
Key points:
1. Must be decorated with @rollout
2. Takes a task sample and LLM object as parameters
3. Returns a float reward score (0.0 to 1.0 typically)
4. The LLM object contains the model being trained and its configuration
During training:
- llm.model: The model checkpoint being trained
- llm.endpoint: vLLM server endpoint for inference
- llm.sampling_parameters: Temperature, etc.
"""
# Create the Agent Framework components
# MCPStdioTool provides calculator functionality via MCP protocol
async with (
MCPStdioTool(name="calculator", command="uvx", args=["mcp-server-calculator"]) as mcp_server,
ChatAgent(
chat_client=OpenAIChatClient(
model_id=llm.model, # This is the model being trained
api_key=os.getenv("OPENAI_API_KEY") or "dummy", # Can be dummy when connecting to training LLM
base_url=llm.endpoint, # vLLM server endpoint provided by agent-lightning
),
name="MathAgent",
instructions=AGENT_INSTRUCTION,
temperature=llm.sampling_parameters.get("temperature", 0.0),
) as agent,
):
print(f"Task: {task['question'][:10]}...")
# Run the agent on the task
result = await agent.run(task["question"], tools=mcp_server)
print(f"Agent responses: {result}")
# Evaluate and return reward - this is what drives RL training
return evaluate(result, task["result"])
def main():
"""Main entrypoint."""
# Configure RL training
# This configuration controls all aspects of the RL training process.
# Key sections: algorithm, data, rollout, actor, trainer
rl_training_config = {
"algorithm": {
# Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc.
"adv_estimator": "grpo"
},
"data": {
# Uses this many tasks from the dataset to perform rollouts
"train_batch_size": 8,
# Used to filter out the over-long prompt-response pairs
"max_prompt_length": 4096,
"max_response_length": 1024,
},
"actor_rollout_ref": {
# Controls the rollout process
"rollout": {
# Set to 1 unless you want to use TP in multiple GPUs
"tensor_model_parallel_size": 1,
# Repeat each task N many times. Required by G(rouped)RPO
"n": 4,
# Controls the batch size per GPU when computing the log-prob
"log_prob_micro_batch_size_per_gpu": 2,
# Controls the multi-turn format (this is binded to the LLM used)
# See https://docs.vllm.ai/en/stable/features/tool_calling.html
"multi_turn": {"format": "hermes"},
# Only vllm is supported for now
"name": "vllm",
# Controls the GPU memory utilization of vLLM
# You might want to set this to under 0.8 to prevent OOM
"gpu_memory_utilization": 0.7,
},
"actor": {
# Split each sample into sub-batches of this size for PPO
"ppo_mini_batch_size": 8,
# Local per-GPU micro batch size
"ppo_micro_batch_size_per_gpu": 2,
# Optimizer configuration
"optim": {"lr": 1e-6},
# Whether to use KL loss during training
"use_kl_loss": False,
# PPO clipping ratios for policy updates
"clip_ratio_low": 0.2,
"clip_ratio_high": 0.3,
# FSDP (Fully Sharded Data Parallel) configuration for memory efficiency
# Useful when you don't have enough GPU memory
"fsdp_config": {
# Whether to offload parameters to CPU
"param_offload": True,
# Whether to offload optimizer state to CPU
"optimizer_offload": True,
},
},
# Reference model config
"ref": {
# Controls the batch size per GPU when computing log-prob for reference model
"log_prob_micro_batch_size_per_gpu": 2,
"fsdp_config": {"param_offload": True},
},
# Common configs for the model
"model": {
# Huggingface model path.
# If you want to train a different model, change the path here.
"path": "Qwen/Qwen2.5-1.5B-Instruct",
# Whether to remove padding tokens in inputs during training
"use_remove_padding": True,
# Enable gradient checkpointing for memory efficiency
"enable_gradient_checkpointing": True,
},
},
# Config for the trainer
"trainer": {
# Number of GPUs per node
"n_gpus_per_node": 1,
# Whether to run validation before training begins
"val_before_train": True,
# Logging backends to use: "console", "wandb", etc.
"logger": ["console"],
# Number of nodes used in the training
"nnodes": 1,
# Validation frequency (in training iterations)
"test_freq": 4,
# Number of epochs in training
"total_epochs": 2,
},
}
# Load your datasets
train_dataset = _load_jsonl("data/math/train.jsonl")
val_dataset = _load_jsonl("data/math/test.jsonl")
# Preview the data to ensure it's loaded correctly
print("First 5 rows of train dataset:")
for i in range(5):
print(train_dataset[i])
print("First 5 rows of val dataset:")
for i in range(5):
print(val_dataset[i])
# Create trainer with VERL algorithm and start training
# n_workers: Number of rollout workers (processes) for parallel data collection
trainer = Trainer(algorithm=VERL(rl_training_config), tracer=AgentFrameworkTracer(), n_workers=2)
# This starts the actual RL training loop:
# 1. Collect rollouts using current model
# 2. Compute advantages and train the model
# 3. Repeat for specified number of epochs
trainer.fit(math_agent, train_dataset, val_dataset=val_dataset)
def debug():
"""Debug mode allows you to test your agent function before training.
Always run debug mode first before starting expensive RL training!
"""
train_dataset = _load_jsonl("data/math/train.jsonl")
train_sample = train_dataset[0]
# Use a known good model for debugging (not the one being trained)
model = "gpt-4o-mini"
base_url = os.getenv("OPENAI_BASE_URL")
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
raise ValueError("OPENAI_API_KEY must be set")
if base_url is None:
raise ValueError("OPENAI_BASE_URL must be set")
# Test your agent function with a sample task
asyncio.run(math_agent(train_sample, LLM(model=model, endpoint=base_url))) # type: ignore
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
if args.debug:
debug()
else:
main()