mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Address comments
This commit is contained in:
@@ -403,8 +403,9 @@ class InProcRunnerContext:
|
||||
def reset_for_new_run(self) -> None:
|
||||
"""Reset the context for a new workflow run.
|
||||
|
||||
This clears messages, events, and resets streaming flag.
|
||||
Runtime checkpoint storage is NOT cleared here as it's managed at the workflow level.
|
||||
Clears messages, the pending event queue, the pending request_info
|
||||
correlation map, and the streaming flag. Runtime checkpoint storage is
|
||||
NOT cleared here as it's managed at the workflow level.
|
||||
"""
|
||||
self._messages.clear()
|
||||
# Clear any pending events (best-effort) by recreating the queue
|
||||
|
||||
@@ -503,11 +503,18 @@ class Workflow(DictConvertible):
|
||||
if checkpoint_storage is not None:
|
||||
self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage)
|
||||
|
||||
# Capture the runner's checkpoint id before attempting to save. The runner
|
||||
# logs and swallows storage save errors and only updates
|
||||
# ``previous_checkpoint_id`` on success, so a failed save would otherwise
|
||||
# leave the prior id in place and we'd return it as if a fresh checkpoint
|
||||
# had been created.
|
||||
previous_id_before = self._runner.previous_checkpoint_id
|
||||
try:
|
||||
await self._runner.create_checkpoint_if_enabled()
|
||||
if self._runner.previous_checkpoint_id is None:
|
||||
new_id = self._runner.previous_checkpoint_id
|
||||
if new_id is None or new_id == previous_id_before:
|
||||
raise WorkflowCheckpointException("Failed to create checkpoint.")
|
||||
return self._runner.previous_checkpoint_id
|
||||
return new_id
|
||||
finally:
|
||||
self._runner.context.clear_runtime_checkpoint_storage()
|
||||
|
||||
|
||||
@@ -1461,5 +1461,36 @@ class TestWorkflowCreateCheckpoint:
|
||||
assert second is not None
|
||||
assert second.previous_checkpoint_id == first_id
|
||||
|
||||
async def test_raises_when_save_fails_after_prior_success(self, simple_executor: Executor) -> None:
    """A failed save after an earlier successful checkpoint must not return the stale id.

    The runner logs and swallows storage save errors and only updates
    ``previous_checkpoint_id`` on success. Without an explicit transition check,
    ``create_checkpoint`` would silently return the previously stored id as if a
    new checkpoint had been created.
    """
    from unittest.mock import AsyncMock, patch

    storage = InMemoryCheckpointStorage()
    workflow = WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build()

    # First call succeeds and seeds ``previous_checkpoint_id``.
    first_id = await workflow.create_checkpoint(storage)
    assert first_id

    # Second call fails to save, so the runner leaves ``previous_checkpoint_id``
    # pointing at ``first_id``. The method must detect that the id did not
    # transition and raise instead of returning the stale value.
    # ``patch.object`` undoes the override on exit, even when the assertion
    # inside the block fails, so no manual save/restore of ``storage.save``
    # is needed.
    with patch.object(storage, "save", new=AsyncMock(side_effect=RuntimeError("boom"))):
        with pytest.raises(WorkflowCheckpointException, match="Failed to create checkpoint"):
            await workflow.create_checkpoint(storage)

    # The runner's bookkeeping is unchanged after the failed call.
    assert workflow._runner.previous_checkpoint_id == first_id  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -616,10 +616,15 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
latest_checkpoint_id: str = self._initial_checkpoint_id
|
||||
restore_storage: FileCheckpointStorage = self._initial_checkpoint_storage
|
||||
if context_id is not None:
|
||||
restore_storage = _checkpoint_storage_for_context(self._checkpoint_storage_path, context_id)
|
||||
latest_checkpoint = await restore_storage.get_latest(workflow_name=self._agent.workflow.name)
|
||||
context_storage = _checkpoint_storage_for_context(self._checkpoint_storage_path, context_id)
|
||||
latest_checkpoint = await context_storage.get_latest(workflow_name=self._agent.workflow.name)
|
||||
if latest_checkpoint is not None:
|
||||
# Only switch the restore storage when a checkpoint was actually
|
||||
# found under the per-context directory. Otherwise the initial
|
||||
# checkpoint id would not resolve in `context_storage` and the
|
||||
# restore call below would fail.
|
||||
latest_checkpoint_id = latest_checkpoint.checkpoint_id
|
||||
restore_storage = context_storage
|
||||
|
||||
# Restore the workflow to the latest checkpoint and run it with the
|
||||
# new input. Events (including request info events) will not be emitted
|
||||
|
||||
@@ -3178,6 +3178,66 @@ class TestCheckpointContextPathValidation:
|
||||
is server._initial_checkpoint_storage # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
async def test_handle_inner_workflow_falls_back_to_initial_storage_when_context_dir_is_empty(
    self, tmp_path: Any
) -> None:
    """Restoring from an empty per-context directory must use the initial id and storage together.

    A request carries ``previous_response_id`` whose checkpoint directory exists
    but holds no checkpoints. The restore call is then expected to receive the
    initial checkpoint id paired with the initial checkpoint storage; pairing the
    initial id with the empty per-context storage would look the id up where it
    does not exist and the restore would fail.
    """
    from agent_framework import WorkflowAgent
    from azure.ai.agentserver.responses import ResponseContext
    from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage

    prior_id = "resp_previous"
    current_id = "resp_current"

    # Storage root with a per-context directory that is present but empty.
    storage_root = tmp_path / "root"
    storage_root.mkdir()
    (storage_root / prior_id).mkdir()

    # Workflow double: checkpointing reported as disabled, initial checkpoint id fixed.
    workflow = MagicMock()
    workflow.name = "wf"
    workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
    workflow.create_checkpoint = AsyncMock(return_value="cp_initial")

    # Agent double: the first ``run`` result serves the restore call, the second the new turn.
    restore_reply = AgentResponse(messages=[])
    turn_reply = AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])])
    agent = MagicMock(spec=WorkflowAgent)
    agent.id = "wf-agent"
    agent.name = "wf"
    agent.description = ""
    agent.context_providers = []
    agent.workflow = workflow
    agent.run = AsyncMock(side_effect=[restore_reply, turn_reply])

    server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
    server._checkpoint_storage_path = str(storage_root)  # pyright: ignore[reportPrivateUsage]

    request = CreateResponse(model="m", input="hi", previous_response_id=prior_id)
    context = ResponseContext(response_id=current_id, previous_response_id=prior_id, mode_flags=MagicMock())
    user_item = ItemMessage({"type": "message", "role": "user", "content": "next turn"})

    # Drain the handler's event stream; only the recorded ``agent.run`` calls matter here.
    with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[user_item])):
        async for _ in server._handle_inner_workflow(request, context):  # pyright: ignore[reportPrivateUsage]
            pass

    assert agent.run.call_count == 2
    restore_call, new_turn_call = agent.run.call_args_list

    # The restore must pair the initial id with the initial storage rather than
    # the empty per-context storage, which does not contain ``cp_initial``.
    assert restore_call.kwargs["checkpoint_id"] == "cp_initial"
    assert restore_call.kwargs["checkpoint_storage"] is server._initial_checkpoint_storage  # pyright: ignore[reportPrivateUsage]

    # The new turn still writes checkpoints under the current response id.
    expected_turn_path = (storage_root / current_id).resolve()
    assert new_turn_call.kwargs["checkpoint_storage"].storage_path == expected_turn_path
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad_id",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user