Add tests

This commit is contained in:
Tao Chen
2026-06-11 11:43:41 -07:00
Unverified
parent 9da83347c8
commit ed27241543
5 changed files with 299 additions and 3 deletions
@@ -97,7 +97,7 @@ class Runner:
yield event
# Create a checkpoint before a run starts. Checkpoints are usually considered to be created at the
# end of an iteration, we can think of this checkpoint as being created at the end of a "superstep 0"
# end of an iteration, we can think of this checkpoint as being created at the end of "superstep 0"
# which captures the states after which the start executor has run. Note that we execute the start
# executor outside of the main iteration loop.
if await self._ctx.has_messages() and not self._resumed_from_checkpoint:
@@ -788,8 +788,7 @@ class Workflow(DictConvertible):
# fully consumed, ``_run_core``'s finally clears the attribute. When the
# caller drops the stream without iterating, garbage collection invalidates
# the weakref, so a subsequent ``run`` is permitted.
existing_stream = self._active_run() if self._active_run is not None else None
if existing_stream is not None:
if self._is_run_active():
raise WorkflowException(
"Workflow is already running; concurrent runs are not allowed on the same instance."
)
@@ -0,0 +1,72 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for `InProcRunnerContext`."""
import pytest
from agent_framework import (
InProcRunnerContext,
WorkflowEvent,
WorkflowMessage,
)
def _make_request_info_event(request_id: str, source_executor_id: str = "executor") -> WorkflowEvent[str]:
    """Build a `request_info` workflow event for the reset tests.

    Args:
        request_id: Identifier assigned to the request.
        source_executor_id: Identifier of the executor recorded as the source of the request.

    Returns:
        The event produced by `WorkflowEvent.request_info`, with a fixed
        ``"please respond"`` payload and ``str`` as the response type.
    """
    request_event = WorkflowEvent.request_info(
        request_id=request_id,
        source_executor_id=source_executor_id,
        request_data="please respond",
        response_type=str,
    )
    return request_event
class TestInProcRunnerContextResetForNewRun:
    """Check that `reset_for_new_run` wipes per-run state, pending request_info events included."""

    async def test_reset_clears_pending_request_info_events(self) -> None:
        """Two registered request_info events are no longer pending after a reset."""
        runner_ctx = InProcRunnerContext()
        await runner_ctx.add_request_info_event(_make_request_info_event("req-1"))
        await runner_ctx.add_request_info_event(_make_request_info_event("req-2"))

        pending_before = await runner_ctx.get_pending_request_info_events()
        assert set(pending_before.keys()) == {"req-1", "req-2"}

        runner_ctx.reset_for_new_run()

        pending_after = await runner_ctx.get_pending_request_info_events()
        assert pending_after == {}

    async def test_reset_clears_pending_request_info_events_when_already_empty(self) -> None:
        """Resetting a context with nothing pending leaves the pending set empty."""
        runner_ctx = InProcRunnerContext()

        pending_before = await runner_ctx.get_pending_request_info_events()
        assert pending_before == {}

        runner_ctx.reset_for_new_run()

        pending_after = await runner_ctx.get_pending_request_info_events()
        assert pending_after == {}

    async def test_reset_after_pending_event_blocks_response_correlation(self) -> None:
        """A request id registered before the reset must not accept a response afterwards."""
        stale_request_id = "req-1"
        runner_ctx = InProcRunnerContext()
        await runner_ctx.add_request_info_event(_make_request_info_event(stale_request_id))

        runner_ctx.reset_for_new_run()

        with pytest.raises(ValueError, match="No pending request found for request_id: req-1"):
            await runner_ctx.send_request_info_response(stale_request_id, "answer")

    async def test_reset_clears_messages_events_and_streaming_flag(self) -> None:
        """Sanity check for the remaining state that `reset_for_new_run` clears."""
        runner_ctx = InProcRunnerContext()

        # Populate every piece of state the reset is expected to clear.
        queued_message = WorkflowMessage(data="hello", source_id="executor")
        await runner_ctx.send_message(queued_message)
        status_event = WorkflowEvent("status", data="running")
        await runner_ctx.add_event(status_event)
        runner_ctx.set_streaming(True)

        # Confirm the state is observable before the reset.
        had_messages = await runner_ctx.has_messages()
        assert had_messages is True
        had_events = await runner_ctx.has_events()
        assert had_events is True
        assert runner_ctx.is_streaming() is True

        runner_ctx.reset_for_new_run()

        # Everything populated above must be gone.
        has_messages = await runner_ctx.has_messages()
        assert has_messages is False
        has_events = await runner_ctx.has_events()
        assert has_events is False
        assert runner_ctx.is_streaming() is False
@@ -20,6 +20,7 @@ from agent_framework import (
Content,
Executor,
FileCheckpointStorage,
InMemoryCheckpointStorage,
Message,
ResponseStream,
WorkflowBuilder,
@@ -1353,3 +1354,112 @@ async def test_output_executors_filtering_with_run_responses_streaming() -> None
# endregion
# region Workflow.create_checkpoint
class TestWorkflowCreateCheckpoint:
    """Exercise :meth:`Workflow.create_checkpoint` with runtime and build-time storage."""

    async def test_returns_checkpoint_id_with_runtime_storage(self, simple_executor: Executor) -> None:
        """A runtime storage passed to `create_checkpoint` receives the checkpoint whose id is returned."""
        runtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        checkpoint_id = await workflow.create_checkpoint(runtime_storage)
        assert checkpoint_id

        stored = await runtime_storage.load(checkpoint_id)
        assert stored is not None
        assert stored.checkpoint_id == checkpoint_id
        assert stored.workflow_name == workflow.name
        assert stored.graph_signature_hash == workflow.graph_signature_hash

    async def test_uses_buildtime_storage_when_none_provided(self, simple_executor: Executor) -> None:
        """Passing `None` falls back to the storage configured on the builder."""
        buildtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor, checkpoint_storage=buildtime_storage)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        checkpoint_id = await workflow.create_checkpoint(None)

        stored = await buildtime_storage.load(checkpoint_id)
        assert stored is not None
        assert stored.checkpoint_id == checkpoint_id

    async def test_raises_when_no_storage_available(self, simple_executor: Executor) -> None:
        """`create_checkpoint(None)` raises when neither build-time nor runtime storage exists."""
        builder = WorkflowBuilder(start_executor=simple_executor)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        with pytest.raises(WorkflowCheckpointException, match="Checkpoint storage must be provided"):
            await workflow.create_checkpoint(None)

    async def test_raises_while_run_active(self, simple_executor: Executor) -> None:
        """`create_checkpoint` is rejected while a workflow run is still active."""
        runtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        # Keep a live, un-iterated streaming run around so the workflow still
        # considers a run active (its active-run weakref keeps resolving).
        start_message = WorkflowMessage(data="hi", source_id="test")
        pending_stream = workflow.run(start_message, stream=True)
        try:
            with pytest.raises(WorkflowException, match="Cannot create checkpoint while a workflow run is active"):
                await workflow.create_checkpoint(runtime_storage)
        finally:
            # Consume the stream so the run finishes and the active-run weakref
            # is cleared; an unconsumed generator can otherwise leak during
            # pytest's asyncio teardown.
            async for _ in pending_stream:
                pass

    async def test_clears_runtime_storage_after_call(self, simple_executor: Executor) -> None:
        """The runtime storage override does not outlive the `create_checkpoint` call."""
        runtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        await workflow.create_checkpoint(runtime_storage)

        assert workflow._runner.context.has_checkpointing() is False
        assert workflow._runner.context._runtime_checkpoint_storage is None  # type: ignore[attr-defined]

    async def test_clears_runtime_storage_after_failure(self, simple_executor: Executor) -> None:
        """The runtime storage override is cleared even when creating the checkpoint fails."""
        from unittest.mock import AsyncMock

        runtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        # The runner logs and swallows storage save errors, so a failing save
        # shows up as the "Failed to create checkpoint." path because
        # ``previous_checkpoint_id`` stays ``None``. In either case the
        # ``finally`` cleanup still has to drop the runtime override.
        failing_save = AsyncMock(side_effect=RuntimeError("boom"))
        runtime_storage.save = failing_save  # type: ignore[method-assign]

        with pytest.raises(WorkflowCheckpointException, match="Failed to create checkpoint"):
            await workflow.create_checkpoint(runtime_storage)

        assert workflow._runner.context._runtime_checkpoint_storage is None  # type: ignore[attr-defined]

    async def test_alters_lineage_for_next_checkpoint(self, simple_executor: Executor) -> None:
        """A manually created checkpoint is recorded as the parent of the next one."""
        buildtime_storage = InMemoryCheckpointStorage()
        builder = WorkflowBuilder(start_executor=simple_executor, checkpoint_storage=buildtime_storage)
        workflow = builder.add_edge(simple_executor, simple_executor).build()

        parent_id = await workflow.create_checkpoint(None)
        child_id = await workflow.create_checkpoint(None)
        assert parent_id != child_id

        child = await buildtime_storage.load(child_id)
        assert child is not None
        assert child.previous_checkpoint_id == parent_id
# endregion
@@ -3063,6 +3063,121 @@ class TestCheckpointContextPathValidation:
assert new_turn_messages[0].text == "next turn"
assert new_turn_call.kwargs["checkpoint_storage"].storage_path == (root / response_id).resolve()
async def test_handle_inner_workflow_restores_initial_checkpoint_when_no_context_id(self, tmp_path: Any) -> None:
    """A request that carries no context id is served from the initial checkpoint.

    With neither ``previous_response_id`` nor ``conversation_id`` on the request,
    the workflow has to be restored from the initial checkpoint first, so that
    state from earlier requests cannot bleed into this one.
    """
    from agent_framework import WorkflowAgent
    from azure.ai.agentserver.responses import ResponseContext
    from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage

    current_response_id = "resp_current"
    storage_root = tmp_path / "root"
    storage_root.mkdir()

    # Mocked workflow agent without build-time checkpointing.
    agent = MagicMock(spec=WorkflowAgent)
    agent.id = "wf-agent"
    agent.name = "wf"
    agent.description = ""
    agent.context_providers = []
    agent.workflow = MagicMock()
    workflow_mock = agent.workflow
    workflow_mock.name = "wf"
    workflow_mock._runner_context.has_checkpointing = MagicMock(return_value=False)
    workflow_mock.create_checkpoint = AsyncMock(return_value="cp_initial")

    # run() is awaited twice: once for the restore, once for the new turn.
    restore_response = AgentResponse(messages=[])
    turn_response = AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])])
    agent.run = AsyncMock(side_effect=[restore_response, turn_response])

    server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
    server._checkpoint_storage_path = str(storage_root)  # pyright: ignore[reportPrivateUsage]

    # Neither previous_response_id nor conversation_id is set on this request.
    request = CreateResponse(model="m", input="hi")
    context = ResponseContext(response_id=current_response_id, mode_flags=MagicMock())
    user_item = ItemMessage({"type": "message", "role": "user", "content": "fresh turn"})

    with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[user_item])):
        async for _ in server._handle_inner_workflow(request, context):  # pyright: ignore[reportPrivateUsage]
            pass

    # Exactly one initial checkpoint is created, and it goes to the initial
    # checkpoint storage owned by the server.
    assert workflow_mock.create_checkpoint.await_count == 1
    [created_with_storage] = workflow_mock.create_checkpoint.await_args.args
    assert created_with_storage is server._initial_checkpoint_storage  # pyright: ignore[reportPrivateUsage]

    # The first run() call is the restore: no positional input, the initial
    # checkpoint id, and the initial checkpoint storage rather than a
    # per-context directory.
    assert agent.run.call_count == 2
    restore_call, new_turn_call = agent.run.call_args_list[0], agent.run.call_args_list[1]
    assert restore_call.args == ()
    assert restore_call.kwargs["checkpoint_id"] == "cp_initial"
    assert restore_call.kwargs["checkpoint_storage"] is server._initial_checkpoint_storage  # pyright: ignore[reportPrivateUsage]

    # The second run() call carries the new input; its checkpoints are written
    # under the directory keyed by the current response id.
    delivered_messages = new_turn_call.args[0]
    assert len(delivered_messages) == 1
    assert delivered_messages[0].text == "fresh turn"
    expected_sink = (storage_root / current_response_id).resolve()
    assert new_turn_call.kwargs["checkpoint_storage"].storage_path == expected_sink
async def test_handle_inner_workflow_creates_initial_checkpoint_once_across_requests(self, tmp_path: Any) -> None:
    """The initial checkpoint is created once and reused by later requests.

    Two requests without any context id are handled back to back. The initial
    checkpoint must be created only for the first one, and both restore calls
    must use the same checkpoint id and the same storage instance.
    """
    from agent_framework import WorkflowAgent
    from azure.ai.agentserver.responses import ResponseContext
    from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage

    root = tmp_path / "root"
    root.mkdir()

    agent = MagicMock(spec=WorkflowAgent)
    agent.id = "wf-agent"
    agent.name = "wf"
    agent.description = ""
    agent.context_providers = []
    agent.workflow = MagicMock()
    agent.workflow.name = "wf"
    agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
    agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial")
    # Four run() calls total: restore + new turn for each of the two requests.
    agent.run = AsyncMock(return_value=AgentResponse(messages=[]))

    server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
    server._checkpoint_storage_path = str(root)  # pyright: ignore[reportPrivateUsage]

    request1 = CreateResponse(model="m", input="hi")
    context1 = ResponseContext(response_id="resp_first", mode_flags=MagicMock())
    request2 = CreateResponse(model="m", input="hi again")
    context2 = ResponseContext(response_id="resp_second", mode_flags=MagicMock())

    input_item = ItemMessage({"type": "message", "role": "user", "content": "turn"})
    with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[input_item])):
        async for _ in server._handle_inner_workflow(request1, context1):  # pyright: ignore[reportPrivateUsage]
            pass
        async for _ in server._handle_inner_workflow(request2, context2):  # pyright: ignore[reportPrivateUsage]
            pass

    # Initial checkpoint creation must not be repeated on the second request.
    assert agent.workflow.create_checkpoint.await_count == 1

    # Pin the expected call sequence (restore, turn, restore, turn) before
    # indexing into it, so a change in the number of run() calls fails with a
    # clear assertion instead of an IndexError or a silently shifted index.
    assert agent.run.call_count == 4

    # Both requests' restoration calls must use the same initial checkpoint id
    # and the same initial checkpoint storage instance.
    restore_call_1 = agent.run.call_args_list[0]
    restore_call_2 = agent.run.call_args_list[2]
    assert restore_call_1.kwargs["checkpoint_id"] == "cp_initial"
    assert restore_call_2.kwargs["checkpoint_id"] == "cp_initial"
    assert (
        restore_call_1.kwargs["checkpoint_storage"]
        is restore_call_2.kwargs["checkpoint_storage"]
        is server._initial_checkpoint_storage  # pyright: ignore[reportPrivateUsage]
    )
@pytest.mark.parametrize(
"bad_id",
[