Address comments

This commit is contained in:
Tao Chen
2026-06-11 15:27:36 -07:00
Unverified
parent ed27241543
commit b0d0224ed4
5 changed files with 110 additions and 6 deletions
@@ -403,8 +403,9 @@ class InProcRunnerContext:
def reset_for_new_run(self) -> None:
"""Reset the context for a new workflow run.
This clears messages, events, and resets streaming flag.
Runtime checkpoint storage is NOT cleared here as it's managed at the workflow level.
Clears messages, the pending event queue, the pending request_info
correlation map, and the streaming flag. Runtime checkpoint storage is
NOT cleared here as it's managed at the workflow level.
"""
self._messages.clear()
# Clear any pending events (best-effort) by recreating the queue
@@ -503,11 +503,18 @@ class Workflow(DictConvertible):
if checkpoint_storage is not None:
self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage)
# Capture the runner's checkpoint id before attempting to save. The runner
# logs and swallows storage save errors and only updates
# ``previous_checkpoint_id`` on success, so a failed save would otherwise
# leave the prior id in place and we'd return it as if a fresh checkpoint
# had been created.
previous_id_before = self._runner.previous_checkpoint_id
try:
await self._runner.create_checkpoint_if_enabled()
if self._runner.previous_checkpoint_id is None:
new_id = self._runner.previous_checkpoint_id
if new_id is None or new_id == previous_id_before:
raise WorkflowCheckpointException("Failed to create checkpoint.")
return self._runner.previous_checkpoint_id
return new_id
finally:
self._runner.context.clear_runtime_checkpoint_storage()
@@ -1461,5 +1461,36 @@ class TestWorkflowCreateCheckpoint:
assert second is not None
assert second.previous_checkpoint_id == first_id
async def test_raises_when_save_fails_after_prior_success(self, simple_executor: Executor) -> None:
    """Raise instead of returning a stale id when a later checkpoint save fails.

    The runner logs and swallows storage save errors and only updates
    ``previous_checkpoint_id`` when a save succeeds. If ``create_checkpoint``
    did not verify that the id actually changed, a failed save following an
    earlier successful one would hand back the old id as though a new
    checkpoint had been written.
    """
    from unittest.mock import AsyncMock

    checkpoint_store = InMemoryCheckpointStorage()
    builder = WorkflowBuilder(start_executor=simple_executor)
    workflow = builder.add_edge(simple_executor, simple_executor).build()

    # Seed ``previous_checkpoint_id`` with one successful checkpoint.
    seeded_id = await workflow.create_checkpoint(checkpoint_store)
    assert seeded_id

    # Make the next save blow up. The runner then keeps pointing at
    # ``seeded_id``, so the call has to notice the missing transition and
    # raise rather than return that stale value.
    real_save = checkpoint_store.save
    checkpoint_store.save = AsyncMock(side_effect=RuntimeError("boom"))  # type: ignore[method-assign]
    try:
        with pytest.raises(WorkflowCheckpointException, match="Failed to create checkpoint"):
            await workflow.create_checkpoint(checkpoint_store)
    finally:
        # Always put the real ``save`` back, even if the expectation fails.
        checkpoint_store.save = real_save  # type: ignore[method-assign]

    # The failed attempt must leave the runner's bookkeeping untouched.
    assert workflow._runner.previous_checkpoint_id == seeded_id  # type: ignore[attr-defined]
# endregion
@@ -616,10 +616,15 @@ class ResponsesHostServer(ResponsesAgentServerHost):
latest_checkpoint_id: str = self._initial_checkpoint_id
restore_storage: FileCheckpointStorage = self._initial_checkpoint_storage
if context_id is not None:
restore_storage = _checkpoint_storage_for_context(self._checkpoint_storage_path, context_id)
latest_checkpoint = await restore_storage.get_latest(workflow_name=self._agent.workflow.name)
context_storage = _checkpoint_storage_for_context(self._checkpoint_storage_path, context_id)
latest_checkpoint = await context_storage.get_latest(workflow_name=self._agent.workflow.name)
if latest_checkpoint is not None:
# Only switch the restore storage when a checkpoint was actually
# found under the per-context directory. Otherwise the initial
# checkpoint id would not resolve in `context_storage` and the
# restore call below would fail.
latest_checkpoint_id = latest_checkpoint.checkpoint_id
restore_storage = context_storage
# Restore the workflow to the latest checkpoint and run it with the
# new input. Events (including request info events) will not be emitted
@@ -3178,6 +3178,66 @@ class TestCheckpointContextPathValidation:
is server._initial_checkpoint_storage # pyright: ignore[reportPrivateUsage]
)
async def test_handle_inner_workflow_falls_back_to_initial_storage_when_context_dir_is_empty(
    self, tmp_path: Any
) -> None:
    """Use the initial checkpoint id AND storage when the context directory is empty.

    The request carries a ``previous_response_id`` whose checkpoint directory
    exists but holds no checkpoints. The restore call must then receive the
    initial checkpoint id together with the initial checkpoint storage;
    pairing the initial id with the empty per-context storage would look the
    id up in a directory that does not contain it, and the restore would fail.
    """
    from agent_framework import WorkflowAgent
    from azure.ai.agentserver.responses import ResponseContext
    from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage

    prior_response_id = "resp_previous"
    current_response_id = "resp_current"

    storage_root = tmp_path / "root"
    storage_root.mkdir()
    # Create the per-context directory but leave it without any checkpoints.
    (storage_root / prior_response_id).mkdir()

    workflow_agent = MagicMock(spec=WorkflowAgent)
    workflow_agent.id = "wf-agent"
    workflow_agent.name = "wf"
    workflow_agent.description = ""
    workflow_agent.context_providers = []
    workflow_agent.workflow = MagicMock()
    workflow_agent.workflow.name = "wf"
    workflow_agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
    workflow_agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial")

    # First ``run`` call is the restore, second one is the new turn.
    restore_result = AgentResponse(messages=[])
    assistant_reply = Message(role="assistant", contents=[Content.from_text("ok")])
    turn_result = AgentResponse(messages=[assistant_reply])
    workflow_agent.run = AsyncMock(side_effect=[restore_result, turn_result])

    server = ResponsesHostServer(workflow_agent, store=InMemoryResponseProvider())
    server._checkpoint_storage_path = str(storage_root)  # pyright: ignore[reportPrivateUsage]

    request = CreateResponse(model="m", input="hi", previous_response_id=prior_response_id)
    context = ResponseContext(
        response_id=current_response_id,
        previous_response_id=prior_response_id,
        mode_flags=MagicMock(),
    )
    user_item = ItemMessage({"type": "message", "role": "user", "content": "next turn"})
    fake_get_input_items = AsyncMock(return_value=[user_item])

    with patch.object(ResponseContext, "get_input_items", new=fake_get_input_items):
        # Drain the stream; only the recorded ``run`` calls matter here.
        async for _ in server._handle_inner_workflow(request, context):  # pyright: ignore[reportPrivateUsage]
            pass

    assert workflow_agent.run.call_count == 2
    restore_call, new_turn_call = workflow_agent.run.call_args_list

    # Restoration has to pair the initial id with the initial storage rather
    # than the empty per-context storage, which does not hold ``cp_initial``.
    assert restore_call.kwargs["checkpoint_id"] == "cp_initial"
    assert (
        restore_call.kwargs["checkpoint_storage"]
        is server._initial_checkpoint_storage  # pyright: ignore[reportPrivateUsage]
    )

    # Checkpoints for the new turn are still written under the current response id.
    expected_turn_path = (storage_root / current_response_id).resolve()
    assert new_turn_call.kwargs["checkpoint_storage"].storage_path == expected_turn_path
@pytest.mark.parametrize(
"bad_id",
[