mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Add tests
This commit is contained in:
@@ -97,7 +97,7 @@ class Runner:
|
||||
yield event
|
||||
|
||||
# Create a checkpoint before a run starts. Checkpoints are usually considered to be created at the
|
||||
# end of an iteration, we can think of this checkpoint as being created at the end of a "superstep 0"
|
||||
# end of an iteration, we can think of this checkpoint as being created at the end of "superstep 0"
|
||||
# which captures the state after the start executor has run. Note that we execute the start
|
||||
# executor outside of the main iteration loop.
|
||||
if await self._ctx.has_messages() and not self._resumed_from_checkpoint:
|
||||
|
||||
@@ -788,8 +788,7 @@ class Workflow(DictConvertible):
|
||||
# fully consumed, ``_run_core``'s finally clears the attribute. When the
|
||||
# caller drops the stream without iterating, garbage collection invalidates
|
||||
# the weakref, so a subsequent ``run`` is permitted.
|
||||
existing_stream = self._active_run() if self._active_run is not None else None
|
||||
if existing_stream is not None:
|
||||
if self._is_run_active():
|
||||
raise WorkflowException(
|
||||
"Workflow is already running; concurrent runs are not allowed on the same instance."
|
||||
)
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for `InProcRunnerContext`."""
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
InProcRunnerContext,
|
||||
WorkflowEvent,
|
||||
WorkflowMessage,
|
||||
)
|
||||
|
||||
|
||||
def _make_request_info_event(request_id: str, source_executor_id: str = "executor") -> WorkflowEvent[str]:
    """Build a ``request_info`` workflow event to use as a test fixture.

    Args:
        request_id: Identifier assigned to the request carried by the event.
        source_executor_id: Id of the executor reported as the event source.

    Returns:
        The event produced by ``WorkflowEvent.request_info`` with a fixed
        request payload and ``str`` as the expected response type.
    """
    event = WorkflowEvent.request_info(
        request_id=request_id,
        source_executor_id=source_executor_id,
        request_data="please respond",
        response_type=str,
    )
    return event
|
||||
|
||||
|
||||
class TestInProcRunnerContextResetForNewRun:
    """Exercise `InProcRunnerContext.reset_for_new_run`.

    Every test populates some per-run state on a fresh context, calls
    `reset_for_new_run`, and asserts that the state is gone afterwards.
    """

    async def test_reset_clears_pending_request_info_events(self) -> None:
        """Request_info events registered before the reset are no longer pending after it."""
        runner_ctx = InProcRunnerContext()

        for request_id in ("req-1", "req-2"):
            await runner_ctx.add_request_info_event(_make_request_info_event(request_id))

        pending_before = await runner_ctx.get_pending_request_info_events()
        assert set(pending_before.keys()) == {"req-1", "req-2"}

        runner_ctx.reset_for_new_run()

        pending_after = await runner_ctx.get_pending_request_info_events()
        assert pending_after == {}

    async def test_reset_clears_pending_request_info_events_when_already_empty(self) -> None:
        """Resetting a context with nothing pending leaves the pending map empty."""
        runner_ctx = InProcRunnerContext()

        pending_before = await runner_ctx.get_pending_request_info_events()
        assert pending_before == {}

        runner_ctx.reset_for_new_run()

        pending_after = await runner_ctx.get_pending_request_info_events()
        assert pending_after == {}

    async def test_reset_after_pending_event_blocks_response_correlation(self) -> None:
        """A response for a request id registered before the reset is rejected with ``ValueError``."""
        runner_ctx = InProcRunnerContext()
        stale_event = _make_request_info_event("req-1")
        await runner_ctx.add_request_info_event(stale_event)

        runner_ctx.reset_for_new_run()

        with pytest.raises(ValueError, match="No pending request found for request_id: req-1"):
            await runner_ctx.send_request_info_response("req-1", "answer")

    async def test_reset_clears_messages_events_and_streaming_flag(self) -> None:
        """Queued messages, queued events and the streaming flag are all cleared by the reset."""
        runner_ctx = InProcRunnerContext()
        outgoing = WorkflowMessage(data="hello", source_id="executor")
        await runner_ctx.send_message(outgoing)
        await runner_ctx.add_event(WorkflowEvent("status", data="running"))
        runner_ctx.set_streaming(True)

        # Precondition: all three pieces of state are populated before the reset.
        had_messages = await runner_ctx.has_messages()
        had_events = await runner_ctx.has_events()
        assert had_messages is True
        assert had_events is True
        assert runner_ctx.is_streaming() is True

        runner_ctx.reset_for_new_run()

        has_messages = await runner_ctx.has_messages()
        has_events = await runner_ctx.has_events()
        assert has_messages is False
        assert has_events is False
        assert runner_ctx.is_streaming() is False
|
||||
@@ -20,6 +20,7 @@ from agent_framework import (
|
||||
Content,
|
||||
Executor,
|
||||
FileCheckpointStorage,
|
||||
InMemoryCheckpointStorage,
|
||||
Message,
|
||||
ResponseStream,
|
||||
WorkflowBuilder,
|
||||
@@ -1353,3 +1354,112 @@ async def test_output_executors_filtering_with_run_responses_streaming() -> None
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Workflow.create_checkpoint
|
||||
|
||||
|
||||
class TestWorkflowCreateCheckpoint:
    """Tests for :meth:`Workflow.create_checkpoint`.

    All tests build a minimal workflow whose start executor has a single
    self-edge, then call ``create_checkpoint`` with either an explicit
    (runtime) storage or ``None``.
    """

    async def test_returns_checkpoint_id_with_runtime_storage(self, simple_executor: Executor) -> None:
        """A runtime storage passed to `create_checkpoint` receives the checkpoint, whose id is returned."""
        runtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        new_id = await workflow.create_checkpoint(runtime_storage)

        assert new_id
        persisted = await runtime_storage.load(new_id)
        assert persisted is not None
        # The stored checkpoint must identify both itself and the workflow it came from.
        assert persisted.checkpoint_id == new_id
        assert persisted.workflow_name == workflow.name
        assert persisted.graph_signature_hash == workflow.graph_signature_hash

    async def test_uses_buildtime_storage_when_none_provided(self, simple_executor: Executor) -> None:
        """Passing `None` makes `create_checkpoint` write to the storage given at build time."""
        buildtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor, checkpoint_storage=buildtime_storage)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        new_id = await workflow.create_checkpoint(None)

        persisted = await buildtime_storage.load(new_id)
        assert persisted is not None
        assert persisted.checkpoint_id == new_id

    async def test_raises_when_no_storage_available(self, simple_executor: Executor) -> None:
        """`create_checkpoint(None)` raises when neither build-time nor runtime storage exists."""
        builder = WorkflowBuilder(start_executor=simple_executor)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        with pytest.raises(WorkflowCheckpointException, match="Checkpoint storage must be provided"):
            await workflow.create_checkpoint(None)

    async def test_raises_while_run_active(self, simple_executor: Executor) -> None:
        """`create_checkpoint` is rejected while a workflow run is still active."""
        runtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        # Start a streaming run but do not iterate it yet: keeping a live
        # reference means the active-run weakref still resolves, so
        # ``_is_run_active`` stays True for the duration of the check below.
        kickoff = WorkflowMessage(data="hi", source_id="test")
        pending_run = workflow.run(kickoff, stream=True)
        try:
            with pytest.raises(WorkflowException, match="Cannot create checkpoint while a workflow run is active"):
                await workflow.create_checkpoint(runtime_storage)
        finally:
            # Consume the stream so the run finishes cleanly and the active-run
            # weakref is cleared; otherwise pytest's asyncio teardown can leak
            # the unconsumed generator.
            async for _ in pending_run:
                pass

    async def test_clears_runtime_storage_after_call(self, simple_executor: Executor) -> None:
        """After a successful call the runtime storage override is no longer set on the context."""
        runtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        await workflow.create_checkpoint(runtime_storage)

        assert workflow._runner.context.has_checkpointing() is False
        assert workflow._runner.context._runtime_checkpoint_storage is None  # type: ignore[attr-defined]

    async def test_clears_runtime_storage_after_failure(self, simple_executor: Executor) -> None:
        """The runtime storage override is cleared even when checkpoint creation fails."""
        from unittest.mock import AsyncMock

        runtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        # The runner logs-and-swallows storage save errors, so a failed save
        # surfaces as the "Failed to create checkpoint." path when
        # ``previous_checkpoint_id`` remains ``None``. Either way, the
        # ``finally`` cleanup must still clear the runtime override.
        failing_save = AsyncMock(side_effect=RuntimeError("boom"))
        runtime_storage.save = failing_save  # type: ignore[method-assign]

        with pytest.raises(WorkflowCheckpointException, match="Failed to create checkpoint"):
            await workflow.create_checkpoint(runtime_storage)

        assert workflow._runner.context._runtime_checkpoint_storage is None  # type: ignore[attr-defined]

    async def test_alters_lineage_for_next_checkpoint(self, simple_executor: Executor) -> None:
        """A manually created checkpoint is recorded as the parent of the next one."""
        buildtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor, checkpoint_storage=buildtime_storage)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        parent_id = await workflow.create_checkpoint(None)
        child_id = await workflow.create_checkpoint(None)

        assert parent_id != child_id
        child = await buildtime_storage.load(child_id)
        assert child is not None
        assert child.previous_checkpoint_id == parent_id
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -3063,6 +3063,121 @@ class TestCheckpointContextPathValidation:
|
||||
assert new_turn_messages[0].text == "next turn"
|
||||
assert new_turn_call.kwargs["checkpoint_storage"].storage_path == (root / response_id).resolve()
|
||||
|
||||
async def test_handle_inner_workflow_restores_initial_checkpoint_when_no_context_id(self, tmp_path: Any) -> None:
    """A request carrying neither previous_response_id nor conversation_id restores the
    workflow from the initial checkpoint first, so state cannot bleed between requests.
    """
    from agent_framework import WorkflowAgent
    from azure.ai.agentserver.responses import ResponseContext
    from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage

    response_id = "resp_current"
    root = tmp_path / "root"
    root.mkdir()

    # Mocked workflow agent: no build-time checkpointing, and a stubbed
    # ``create_checkpoint`` that always reports the id "cp_initial".
    agent = MagicMock(spec=WorkflowAgent)
    agent.id = "wf-agent"
    agent.name = "wf"
    agent.description = ""
    agent.context_providers = []
    agent.workflow = MagicMock()
    agent.workflow.name = "wf"
    agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
    agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial")
    # One canned response per expected ``run()`` call, in call order.
    restore_response = AgentResponse(messages=[])
    turn_response = AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])])
    agent.run = AsyncMock(side_effect=[restore_response, turn_response])

    server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
    server._checkpoint_storage_path = str(root)  # pyright: ignore[reportPrivateUsage]

    # No previous_response_id and no conversation_id.
    request = CreateResponse(model="m", input="hi")
    context = ResponseContext(response_id=response_id, mode_flags=MagicMock())
    input_item = ItemMessage({"type": "message", "role": "user", "content": "fresh turn"})

    fake_get_input_items = AsyncMock(return_value=[input_item])
    with patch.object(ResponseContext, "get_input_items", new=fake_get_input_items):
        async for _ in server._handle_inner_workflow(request, context):  # pyright: ignore[reportPrivateUsage]
            pass

    # The initial checkpoint must have been created exactly once, against the
    # initial checkpoint storage owned by the server.
    create_checkpoint_mock = agent.workflow.create_checkpoint
    assert create_checkpoint_mock.await_count == 1
    (initial_storage_arg,) = create_checkpoint_mock.await_args.args
    assert initial_storage_arg is server._initial_checkpoint_storage  # pyright: ignore[reportPrivateUsage]

    # First run() call is the restoration: no positional input, restored from
    # the initial checkpoint id, using the initial checkpoint storage (NOT a
    # per-context directory).
    assert agent.run.call_count == 2
    restore_call, new_turn_call = agent.run.call_args_list[0], agent.run.call_args_list[1]
    assert restore_call.args == ()
    assert restore_call.kwargs["checkpoint_id"] == "cp_initial"
    assert restore_call.kwargs["checkpoint_storage"] is server._initial_checkpoint_storage  # pyright: ignore[reportPrivateUsage]

    # Second run() call delivers the new input; checkpoints land under response_id
    # (the write-sink directory keyed by the current response id).
    new_turn_messages = new_turn_call.args[0]
    assert len(new_turn_messages) == 1
    assert new_turn_messages[0].text == "fresh turn"
    expected_sink = (root / response_id).resolve()
    assert new_turn_call.kwargs["checkpoint_storage"].storage_path == expected_sink
|
||||
|
||||
async def test_handle_inner_workflow_creates_initial_checkpoint_once_across_requests(self, tmp_path: Any) -> None:
    """The initial checkpoint must be created exactly once and reused across
    subsequent requests, regardless of whether the requests carry a context id.

    Two context-less requests are sent through the same server. The test checks
    that ``create_checkpoint`` was awaited once, and that both restoration calls
    to ``agent.run`` received the same checkpoint id and the same storage object.
    """
    from agent_framework import WorkflowAgent
    from azure.ai.agentserver.responses import ResponseContext
    from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage

    root = tmp_path / "root"
    root.mkdir()

    agent = MagicMock(spec=WorkflowAgent)
    agent.id = "wf-agent"
    agent.name = "wf"
    agent.description = ""
    agent.context_providers = []
    agent.workflow = MagicMock()
    agent.workflow.name = "wf"
    agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
    agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial")
    # Four run() calls total: restore + new turn for each of the two requests.
    agent.run = AsyncMock(return_value=AgentResponse(messages=[]))

    server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
    server._checkpoint_storage_path = str(root)  # pyright: ignore[reportPrivateUsage]

    request1 = CreateResponse(model="m", input="hi")
    context1 = ResponseContext(response_id="resp_first", mode_flags=MagicMock())
    request2 = CreateResponse(model="m", input="hi again")
    context2 = ResponseContext(response_id="resp_second", mode_flags=MagicMock())
    input_item = ItemMessage({"type": "message", "role": "user", "content": "turn"})

    with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[input_item])):
        async for _ in server._handle_inner_workflow(request1, context1):  # pyright: ignore[reportPrivateUsage]
            pass
        async for _ in server._handle_inner_workflow(request2, context2):  # pyright: ignore[reportPrivateUsage]
            pass

    # Initial checkpoint creation must not be repeated on the second request.
    assert agent.workflow.create_checkpoint.await_count == 1

    # Pin the call count before indexing into ``call_args_list``: without this,
    # a change in the number of run() calls would show up as an IndexError or,
    # worse, the indices below would silently point at the wrong calls.
    assert agent.run.call_count == 4

    # Both requests' restoration calls must use the same initial checkpoint id
    # and the same initial checkpoint storage instance.
    restore_call_1 = agent.run.call_args_list[0]
    restore_call_2 = agent.run.call_args_list[2]
    assert restore_call_1.kwargs["checkpoint_id"] == "cp_initial"
    assert restore_call_2.kwargs["checkpoint_id"] == "cp_initial"
    assert (
        restore_call_1.kwargs["checkpoint_storage"]
        is restore_call_2.kwargs["checkpoint_storage"]
        is server._initial_checkpoint_storage  # pyright: ignore[reportPrivateUsage]
    )
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad_id",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user