From 9d8e5ca4f52a753c7fa3a49873ba466f1744c995 Mon Sep 17 00:00:00 2001 From: Baidar Date: Thu, 28 May 2026 11:13:30 +0200 Subject: [PATCH] Python: Allow hosted checkpoints to restore MessageRole (#6049) * Python: Allow hosted checkpoints to restore MessageRole Allow Responses hosting checkpoint storage to deserialize the Azure Responses MessageRole enum that hosted workflows can persist inside Agent Framework Message objects. Add regression coverage for both direct load() and the hosted get_latest() restore path, including the plain-storage failure mode where list_checkpoints logs the blocked type and get_latest() returns None. Ruff also normalizes a duplicate contextlib import in the touched hosting module. * Address MessageRole checkpoint review comments * Cover hosted MessageRole checkpoint restore path --- .../_responses.py | 10 +- .../foundry_hosting/tests/test_responses.py | 141 +++++++++++++++++- 2 files changed, 149 insertions(+), 2 deletions(-) diff --git a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py index 459705ca97..8940365930 100644 --- a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py +++ b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py @@ -72,6 +72,7 @@ from azure.ai.agentserver.responses.models import ( MessageContentOutputTextContent, MessageContentReasoningTextContent, MessageContentRefusalContent, + MessageRole, OAuthConsentRequestOutputItem, OutputItem, OutputItemApplyPatchToolCall, @@ -116,6 +117,8 @@ from typing_extensions import Any logger = logging.getLogger(__name__) +_AZURE_RESPONSES_MESSAGE_ROLE_TYPE = f"{MessageRole.__module__}:{MessageRole.__qualname__}" + # region Approval Storage class ApprovalStorage(Protocol): @@ -249,7 +252,12 @@ def _checkpoint_storage_for_context(root: str, context_id: str) -> FileCheckpoin storage_path = (root_path / context_id).resolve() if not storage_path.is_relative_to(root_path): raise RuntimeError(f"Invalid checkpoint context id: {context_id!r}") - return FileCheckpointStorage(storage_path) + return FileCheckpointStorage( + storage_path, + # Keep this provider-specific allowlist narrow. Hosted workflow + # checkpoints can persist Azure's role enum inside Message objects. + allowed_checkpoint_types=[_AZURE_RESPONSES_MESSAGE_ROLE_TYPE], + ) # endregion Approval Storage diff --git a/python/packages/foundry_hosting/tests/test_responses.py b/python/packages/foundry_hosting/tests/test_responses.py index 1503b3c9ba..9358549a86 100644 --- a/python/packages/foundry_hosting/tests/test_responses.py +++ b/python/packages/foundry_hosting/tests/test_responses.py @@ -13,7 +13,7 @@ from __future__ import annotations import json from collections.abc import AsyncIterator, Callable from dataclasses import dataclass -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -26,6 +26,9 @@ from agent_framework import ( Message, RawAgent, ResponseStream, + WorkflowCheckpoint, + WorkflowCheckpointException, + WorkflowMessage, ) from azure.ai.agentserver.responses import InMemoryResponseProvider from mcp import McpError @@ -34,6 +37,7 @@ from typing_extensions import Any from agent_framework_foundry_hosting import ResponsesHostServer from agent_framework_foundry_hosting._responses import ( + _AZURE_RESPONSES_MESSAGE_ROLE_TYPE, # pyright: ignore[reportPrivateUsage] CONSENT_ERROR_CODE, FileBasedFunctionApprovalStorage, # pyright: ignore[reportPrivateUsage] InMemoryFunctionApprovalStorage, # pyright: ignore[reportPrivateUsage] @@ -2712,6 +2716,23 @@ class TestCheckpointContextPathValidation: return _checkpoint_storage_for_context + @staticmethod + def _checkpoint_with_azure_message_role() -> WorkflowCheckpoint: + from azure.ai.agentserver.responses.models import MessageRole + + return WorkflowCheckpoint( + workflow_name="wf", + graph_signature_hash="hash", + messages={ + "executor": [ + WorkflowMessage( + data=Message(role=MessageRole.USER, contents=[Content.from_text("hello")]), + source_id="source", + ) + ] + }, + ) + def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None: helper = self._helper() root = tmp_path / "root" @@ -2720,6 +2741,124 @@ class TestCheckpointContextPathValidation: assert storage.storage_path.is_dir() assert storage.storage_path.parent == root.resolve() + def test_azure_message_role_allowlist_type_matches_generated_sdk_path(self) -> None: + assert ( + _AZURE_RESPONSES_MESSAGE_ROLE_TYPE + == "azure.ai.agentserver.responses.models._generated.sdk.models.models._enums:MessageRole" + ) + + async def test_storage_allows_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None: + from azure.ai.agentserver.responses.models import MessageRole + + helper = self._helper() + root = tmp_path / "root" + root.mkdir() + storage = helper(str(root), "resp_abc123") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + loaded = await storage.load(checkpoint.checkpoint_id) + + loaded_message = loaded.messages["executor"][0].data + assert isinstance(loaded_message, Message) + assert type(loaded_message.role) is MessageRole + assert loaded_message.role == MessageRole.USER + assert loaded_message.text == "hello" + + async def test_plain_storage_blocks_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None: + storage = FileCheckpointStorage(tmp_path / "plain") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + with pytest.raises(WorkflowCheckpointException, match="MessageRole"): + await storage.load(checkpoint.checkpoint_id) + + async def test_get_latest_restores_azure_message_role(self, tmp_path: Any) -> None: + from azure.ai.agentserver.responses.models import MessageRole + + helper = self._helper() + root = tmp_path / "root" + root.mkdir() + storage = helper(str(root), "resp_abc123") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + latest = await storage.get_latest(workflow_name="wf") + + assert latest is not None + assert latest.checkpoint_id == checkpoint.checkpoint_id + latest_message = latest.messages["executor"][0].data + assert isinstance(latest_message, Message) + assert type(latest_message.role) is MessageRole + + async def test_get_latest_silently_skips_without_allowlist( + self, tmp_path: Any, caplog: pytest.LogCaptureFixture + ) -> None: + import logging + + storage = FileCheckpointStorage(tmp_path / "plain") + checkpoint = self._checkpoint_with_azure_message_role() + + await storage.save(checkpoint) + with caplog.at_level(logging.WARNING, logger="agent_framework"): + latest = await storage.get_latest(workflow_name="wf") + + assert latest is None + assert any("MessageRole" in message for message in caplog.messages) + + async def test_handle_inner_workflow_restores_message_role_checkpoint_from_previous_response( + self, tmp_path: Any + ) -> None: + from agent_framework import WorkflowAgent + from azure.ai.agentserver.responses import ResponseContext + from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage + + previous_response_id = "resp_previous" + response_id = "resp_current" + root = tmp_path / "root" + root.mkdir() + checkpoint_storage = self._helper()(str(root), previous_response_id) + checkpoint = self._checkpoint_with_azure_message_role() + await checkpoint_storage.save(checkpoint) + + agent = MagicMock(spec=WorkflowAgent) + agent.id = "wf-agent" + agent.name = "wf" + agent.description = "" + agent.context_providers = [] + agent.workflow = MagicMock() + agent.workflow.name = "wf" + agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False) + agent.run = AsyncMock( + side_effect=[ + AgentResponse(messages=[]), + AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]), + ] + ) + server = ResponsesHostServer(agent, store=InMemoryResponseProvider()) + server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage] + + request = CreateResponse(model="m", input="hi", previous_response_id=previous_response_id) + context = ResponseContext( + response_id=response_id, previous_response_id=previous_response_id, mode_flags=MagicMock() + ) + input_item = ItemMessage({"type": "message", "role": "user", "content": "next turn"}) + + with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[input_item])): + async for _ in server._handle_inner_workflow(request, context): # pyright: ignore[reportPrivateUsage] + pass + + assert agent.run.call_count == 2 + restore_call = agent.run.call_args_list[0] + assert restore_call.kwargs["checkpoint_id"] == checkpoint.checkpoint_id + assert restore_call.kwargs["checkpoint_storage"].storage_path == (root / previous_response_id).resolve() + + new_turn_call = agent.run.call_args_list[1] + new_turn_messages = new_turn_call.args[0] + assert len(new_turn_messages) == 1 + assert new_turn_messages[0].text == "next turn" + assert new_turn_call.kwargs["checkpoint_storage"].storage_path == (root / response_id).resolve() + @pytest.mark.parametrize( "bad_id", [