diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index 23a748f0be..30f8295524 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -403,8 +403,9 @@ class InProcRunnerContext: def reset_for_new_run(self) -> None: """Reset the context for a new workflow run. - This clears messages, events, and resets streaming flag. - Runtime checkpoint storage is NOT cleared here as it's managed at the workflow level. + Clears messages, the pending event queue, the pending request_info + correlation map, and the streaming flag. Runtime checkpoint storage is + NOT cleared here as it's managed at the workflow level. """ self._messages.clear() # Clear any pending events (best-effort) by recreating the queue diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index 8127e6dc01..09ff6c28c1 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -503,11 +503,18 @@ class Workflow(DictConvertible): if checkpoint_storage is not None: self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage) + # Capture the runner's checkpoint id before attempting to save. The runner + # log-and-swallows storage save errors and only updates + # ``previous_checkpoint_id`` on success, so a failed save would otherwise + # leave the prior id in place and we'd return it as if a fresh checkpoint + # had been created. + previous_id_before = self._runner.previous_checkpoint_id try: await self._runner.create_checkpoint_if_enabled() - if self._runner.previous_checkpoint_id is None: + new_id = self._runner.previous_checkpoint_id + if new_id is None or new_id == previous_id_before: raise WorkflowCheckpointException("Failed to create checkpoint.") - return self._runner.previous_checkpoint_id + return new_id finally: self._runner.context.clear_runtime_checkpoint_storage() diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index ce23cfda3d..9884088cdf 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -1461,5 +1461,36 @@ class TestWorkflowCreateCheckpoint: assert second is not None assert second.previous_checkpoint_id == first_id + async def test_raises_when_save_fails_after_prior_success(self, simple_executor: Executor) -> None: + """A failed save after an earlier successful checkpoint must not return the stale id. + + The runner log-and-swallows storage save errors and only updates + ``previous_checkpoint_id`` on success. Without an explicit transition check, + ``create_checkpoint`` would silently return the previously stored id as if a + new checkpoint had been created. + """ + from unittest.mock import AsyncMock + + storage = InMemoryCheckpointStorage() + workflow = WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build() + + # First call succeeds and seeds ``previous_checkpoint_id``. + first_id = await workflow.create_checkpoint(storage) + assert first_id + + # Second call fails to save, so the runner leaves ``previous_checkpoint_id`` + # pointing at ``first_id``. The method must detect that the id did not + # transition and raise instead of returning the stale value. + original_save = storage.save + storage.save = AsyncMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] + try: + with pytest.raises(WorkflowCheckpointException, match="Failed to create checkpoint"): + await workflow.create_checkpoint(storage) + finally: + storage.save = original_save # type: ignore[method-assign] + + # The runner's bookkeeping is unchanged after the failed call. + assert workflow._runner.previous_checkpoint_id == first_id # type: ignore[attr-defined] + # endregion diff --git a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py index 1a5b0e32d0..1f64d24eff 100644 --- a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py +++ b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py @@ -616,10 +616,15 @@ class ResponsesHostServer(ResponsesAgentServerHost): latest_checkpoint_id: str = self._initial_checkpoint_id restore_storage: FileCheckpointStorage = self._initial_checkpoint_storage if context_id is not None: - restore_storage = _checkpoint_storage_for_context(self._checkpoint_storage_path, context_id) - latest_checkpoint = await restore_storage.get_latest(workflow_name=self._agent.workflow.name) + context_storage = _checkpoint_storage_for_context(self._checkpoint_storage_path, context_id) + latest_checkpoint = await context_storage.get_latest(workflow_name=self._agent.workflow.name) if latest_checkpoint is not None: + # Only switch the restore storage when a checkpoint was actually + # found under the per-context directory. Otherwise the initial + # checkpoint id would not resolve in `context_storage` and the + # restore call below would fail. latest_checkpoint_id = latest_checkpoint.checkpoint_id + restore_storage = context_storage # Restore the workflow to the latest checkpoint and run it with the # new input. Events (including request info events) will not be emitted diff --git a/python/packages/foundry_hosting/tests/test_responses.py b/python/packages/foundry_hosting/tests/test_responses.py index d9c0512c4c..13923cc9d2 100644 --- a/python/packages/foundry_hosting/tests/test_responses.py +++ b/python/packages/foundry_hosting/tests/test_responses.py @@ -3178,6 +3178,66 @@ class TestCheckpointContextPathValidation: is server._initial_checkpoint_storage # pyright: ignore[reportPrivateUsage] ) + async def test_handle_inner_workflow_falls_back_to_initial_storage_when_context_dir_is_empty( + self, tmp_path: Any + ) -> None: + """When ``previous_response_id`` is supplied but its checkpoint directory has no + checkpoints, the restoration must fall back to BOTH the initial checkpoint id + and the initial checkpoint storage. Otherwise the initial id would be looked up + inside the per-context storage where it does not exist, and the restore would + fail. + """ + from agent_framework import WorkflowAgent + from azure.ai.agentserver.responses import ResponseContext + from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage + + previous_response_id = "resp_previous" + response_id = "resp_current" + root = tmp_path / "root" + root.mkdir() + # The per-context storage exists but contains no checkpoints. + (root / previous_response_id).mkdir() + + agent = MagicMock(spec=WorkflowAgent) + agent.id = "wf-agent" + agent.name = "wf" + agent.description = "" + agent.context_providers = [] + agent.workflow = MagicMock() + agent.workflow.name = "wf" + agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False) + agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial") + agent.run = AsyncMock( + side_effect=[ + AgentResponse(messages=[]), + AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]), + ] + ) + server = ResponsesHostServer(agent, store=InMemoryResponseProvider()) + server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage] + + request = CreateResponse(model="m", input="hi", previous_response_id=previous_response_id) + context = ResponseContext( + response_id=response_id, previous_response_id=previous_response_id, mode_flags=MagicMock() + ) + input_item = ItemMessage({"type": "message", "role": "user", "content": "next turn"}) + + with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[input_item])): + async for _ in server._handle_inner_workflow(request, context): # pyright: ignore[reportPrivateUsage] + pass + + # The restoration call must use the initial id AND the initial storage, + # not the empty per-context storage. Mismatching the two would attempt + # to load ``cp_initial`` from a directory that doesn't contain it. + assert agent.run.call_count == 2 + restore_call = agent.run.call_args_list[0] + assert restore_call.kwargs["checkpoint_id"] == "cp_initial" + assert restore_call.kwargs["checkpoint_storage"] is server._initial_checkpoint_storage # pyright: ignore[reportPrivateUsage] + + # The new turn still writes checkpoints under the current response id. + new_turn_call = agent.run.call_args_list[1] + assert new_turn_call.kwargs["checkpoint_storage"].storage_path == (root / response_id).resolve() + @pytest.mark.parametrize( "bad_id", [