Files
agent-framework/python/packages/lab/lightning/tests/test_lightning.py
T
Yuge Zhang b8df0cd03f Python: Add Agent Framework Lab Lightning package with RL training examples (#937)
* add math agent

* .

* update

* update debug mode

* add tau2 training

* .

* .

* .

* .

* add tests

* .

* revert observability

* update readme

* fix task serialization issue

* fix exception

* add inline docs

* update readme

* update pyproject toml

* minor fix

* update and use git lfs

* update

* update ignore file to use lab specific

* fix type

* update dependency

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
2025-09-30 01:18:49 +00:00

164 lines
5.2 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Tests for lightning module."""
from unittest.mock import AsyncMock, patch
import pytest
from agent_framework import (
AgentExecutor,
ChatAgent,
WorkflowBuilder,
)
from agent_framework.openai import OpenAIChatClient
from agent_framework_lab_lightning import init
from agentlightning.adapter import TraceTripletAdapter
from agentlightning.tracer import AgentOpsTracer
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
def test_import():
    """Smoke-test that the lightning package imported and exposes ``init``.

    The module-level ``from agent_framework_lab_lightning import init`` already
    runs when this file is collected; the explicit assertion keeps this test
    from passing with an empty body.
    """
    assert callable(init)
@pytest.fixture
def workflow_two_agents():
    """Yield an analyzer -> advisor workflow whose OpenAI calls are mocked.

    Two ``OpenAIChatClient`` instances are created under a patched environment,
    and each client's ``chat.completions.create`` is replaced with an
    ``AsyncMock`` that returns a canned ``ChatCompletion``. The workflow is
    yielded while those mocks are still installed so the requesting test runs
    against them.
    """

    def canned_completion(completion_id, created_at, text):
        # Single-choice assistant reply carrying the given text.
        return ChatCompletion(
            id=completion_id,
            object="chat.completion",
            created=created_at,
            model="gpt-4o",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(role="assistant", content=text),
                    finish_reason="stop",
                )
            ],
        )

    # Canned replies, one per agent.
    analyzer_reply = canned_completion(
        "chatcmpl-123",
        1677652288,
        "Analyzed data shows trend upward",
    )
    advisor_reply = canned_completion(
        "chatcmpl-456",
        1677652289,
        "Based on the analysis 'Analyzed data shows trend upward', I recommend investing",
    )

    # Environment values set while the clients are constructed.
    fake_env = {
        "OPENAI_API_KEY": "test-key",
        "OPENAI_CHAT_MODEL_ID": "gpt-4o",
    }

    # NOTE(review): the env override is kept active through the yield together
    # with the API mocks; confirm this matches the intended nesting.
    with patch.dict("os.environ", fake_env):
        analyzer_client = OpenAIChatClient()
        advisor_client = OpenAIChatClient()

        # Replace the OpenAI completion endpoint on each client.
        analyzer_api_stub = patch.object(
            analyzer_client.client.chat.completions,
            "create",
            new_callable=AsyncMock,
            return_value=analyzer_reply,
        )
        advisor_api_stub = patch.object(
            advisor_client.client.chat.completions,
            "create",
            new_callable=AsyncMock,
            return_value=advisor_reply,
        )

        with analyzer_api_stub, advisor_api_stub:
            analyzer_step = AgentExecutor(
                id="analyzer",
                agent=ChatAgent(
                    chat_client=analyzer_client,
                    name="DataAnalyzer",
                    instructions="You are a data analyst. Analyze the given data and provide insights.",
                ),
            )
            advisor_step = AgentExecutor(
                id="advisor",
                agent=ChatAgent(
                    chat_client=advisor_client,
                    name="InvestmentAdvisor",
                    instructions="You are an investment advisor. Based on analysis results, provide recommendations.",
                ),
            )

            # Wire the pipeline: analyzer -> advisor.
            builder = WorkflowBuilder()
            builder = builder.set_start_executor(analyzer_step)
            builder = builder.add_edge(analyzer_step, advisor_step)

            yield builder.build()
@pytest.mark.asyncio
async def test_openai_workflow_two_agents(workflow_two_agents):
    """Run the two-agent workflow and check the expected reply is among its outputs."""
    expected_reply = "Based on the analysis 'Analyzed data shows trend upward', I recommend investing"

    run_result = await workflow_two_agents.run("Please analyze the quarterly sales data")
    outputs = run_result.get_outputs()

    assert expected_reply in outputs
@pytest.mark.asyncio
async def test_observability(workflow_two_agents):
    r"""Check how many triplets the adapter extracts for several agent filters.

    Expected trace tree:

                       [workflow.run]
                      /              \
             [analyzer]              [advisor]
             /        \              /       \
      [DataAnalyzer] [send]  [Investment]   [send]
            |                      |
      [chat gpt-4o]          [chat gpt-4o]
    """
    # (agent_match pattern, expected number of triplets)
    expectations = [
        (None, 2),
        ("analyzer", 1),
        ("advisor", 1),
        # Parent agent is not matched
        ("DataAnalyzer", 0),
        ("InvestmentAdvisor|advisor", 1),
    ]

    init()
    tracer = AgentOpsTracer()
    try:
        tracer.init()
        tracer.init_worker(0)

        with tracer.trace_context():
            await workflow_two_agents.run("Please analyze the quarterly sales data")

        # NOTE(review): the trace is read after the trace context has exited;
        # confirm this matches the intended placement of the assertions.
        for agent_pattern, expected_count in expectations:
            adapter = TraceTripletAdapter(agent_match=agent_pattern, llm_call_match="chat")
            triplets = adapter.adapt(tracer.get_last_trace())
            assert len(triplets) == expected_count
    finally:
        # Always release the tracer, even when an assertion above fails.
        tracer.teardown_worker(0)
        tracer.teardown()