Python: Allow hosted checkpoints to restore MessageRole (#6049)

* Python: Allow hosted checkpoints to restore MessageRole

Allow Responses hosting checkpoint storage to deserialize the Azure Responses MessageRole enum that hosted workflows can persist inside Agent Framework Message objects.

Add regression coverage for both direct load() and the hosted get_latest() restore path, including the plain-storage failure mode where list_checkpoints logs the blocked type and get_latest() returns None.

Ruff also normalizes a duplicate contextlib import in the touched hosting module.

* Address MessageRole checkpoint review comments

* Cover hosted MessageRole checkpoint restore path
This commit is contained in:
Baidar
2026-05-28 11:13:30 +02:00
committed by GitHub
Unverified
parent af787569b3
commit 9d8e5ca4f5
2 changed files with 149 additions and 2 deletions
@@ -72,6 +72,7 @@ from azure.ai.agentserver.responses.models import (
MessageContentOutputTextContent,
MessageContentReasoningTextContent,
MessageContentRefusalContent,
MessageRole,
OAuthConsentRequestOutputItem,
OutputItem,
OutputItemApplyPatchToolCall,
@@ -116,6 +117,8 @@ from typing_extensions import Any
logger = logging.getLogger(__name__)
_AZURE_RESPONSES_MESSAGE_ROLE_TYPE = f"{MessageRole.__module__}:{MessageRole.__qualname__}"
# region Approval Storage
class ApprovalStorage(Protocol):
@@ -249,7 +252,12 @@ def _checkpoint_storage_for_context(root: str, context_id: str) -> FileCheckpoin
storage_path = (root_path / context_id).resolve()
if not storage_path.is_relative_to(root_path):
raise RuntimeError(f"Invalid checkpoint context id: {context_id!r}")
return FileCheckpointStorage(storage_path)
return FileCheckpointStorage(
storage_path,
# Keep this provider-specific allowlist narrow. Hosted workflow
# checkpoints can persist Azure's role enum inside Message objects.
allowed_checkpoint_types=[_AZURE_RESPONSES_MESSAGE_ROLE_TYPE],
)
# endregion Approval Storage
@@ -13,7 +13,7 @@ from __future__ import annotations
import json
from collections.abc import AsyncIterator, Callable
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
@@ -26,6 +26,9 @@ from agent_framework import (
Message,
RawAgent,
ResponseStream,
WorkflowCheckpoint,
WorkflowCheckpointException,
WorkflowMessage,
)
from azure.ai.agentserver.responses import InMemoryResponseProvider
from mcp import McpError
@@ -34,6 +37,7 @@ from typing_extensions import Any
from agent_framework_foundry_hosting import ResponsesHostServer
from agent_framework_foundry_hosting._responses import (
_AZURE_RESPONSES_MESSAGE_ROLE_TYPE, # pyright: ignore[reportPrivateUsage]
CONSENT_ERROR_CODE,
FileBasedFunctionApprovalStorage, # pyright: ignore[reportPrivateUsage]
InMemoryFunctionApprovalStorage, # pyright: ignore[reportPrivateUsage]
@@ -2712,6 +2716,23 @@ class TestCheckpointContextPathValidation:
return _checkpoint_storage_for_context
@staticmethod
def _checkpoint_with_azure_message_role() -> WorkflowCheckpoint:
from azure.ai.agentserver.responses.models import MessageRole
return WorkflowCheckpoint(
workflow_name="wf",
graph_signature_hash="hash",
messages={
"executor": [
WorkflowMessage(
data=Message(role=MessageRole.USER, contents=[Content.from_text("hello")]),
source_id="source",
)
]
},
)
def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None:
helper = self._helper()
root = tmp_path / "root"
@@ -2720,6 +2741,124 @@ class TestCheckpointContextPathValidation:
assert storage.storage_path.is_dir()
assert storage.storage_path.parent == root.resolve()
def test_azure_message_role_allowlist_type_matches_generated_sdk_path(self) -> None:
assert (
_AZURE_RESPONSES_MESSAGE_ROLE_TYPE
== "azure.ai.agentserver.responses.models._generated.sdk.models.models._enums:MessageRole"
)
async def test_storage_allows_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None:
from azure.ai.agentserver.responses.models import MessageRole
helper = self._helper()
root = tmp_path / "root"
root.mkdir()
storage = helper(str(root), "resp_abc123")
checkpoint = self._checkpoint_with_azure_message_role()
await storage.save(checkpoint)
loaded = await storage.load(checkpoint.checkpoint_id)
loaded_message = loaded.messages["executor"][0].data
assert isinstance(loaded_message, Message)
assert type(loaded_message.role) is MessageRole
assert loaded_message.role == MessageRole.USER
assert loaded_message.text == "hello"
async def test_plain_storage_blocks_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None:
storage = FileCheckpointStorage(tmp_path / "plain")
checkpoint = self._checkpoint_with_azure_message_role()
await storage.save(checkpoint)
with pytest.raises(WorkflowCheckpointException, match="MessageRole"):
await storage.load(checkpoint.checkpoint_id)
async def test_get_latest_restores_azure_message_role(self, tmp_path: Any) -> None:
from azure.ai.agentserver.responses.models import MessageRole
helper = self._helper()
root = tmp_path / "root"
root.mkdir()
storage = helper(str(root), "resp_abc123")
checkpoint = self._checkpoint_with_azure_message_role()
await storage.save(checkpoint)
latest = await storage.get_latest(workflow_name="wf")
assert latest is not None
assert latest.checkpoint_id == checkpoint.checkpoint_id
latest_message = latest.messages["executor"][0].data
assert isinstance(latest_message, Message)
assert type(latest_message.role) is MessageRole
async def test_get_latest_silently_skips_without_allowlist(
self, tmp_path: Any, caplog: pytest.LogCaptureFixture
) -> None:
import logging
storage = FileCheckpointStorage(tmp_path / "plain")
checkpoint = self._checkpoint_with_azure_message_role()
await storage.save(checkpoint)
with caplog.at_level(logging.WARNING, logger="agent_framework"):
latest = await storage.get_latest(workflow_name="wf")
assert latest is None
assert any("MessageRole" in message for message in caplog.messages)
async def test_handle_inner_workflow_restores_message_role_checkpoint_from_previous_response(
self, tmp_path: Any
) -> None:
from agent_framework import WorkflowAgent
from azure.ai.agentserver.responses import ResponseContext
from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage
previous_response_id = "resp_previous"
response_id = "resp_current"
root = tmp_path / "root"
root.mkdir()
checkpoint_storage = self._helper()(str(root), previous_response_id)
checkpoint = self._checkpoint_with_azure_message_role()
await checkpoint_storage.save(checkpoint)
agent = MagicMock(spec=WorkflowAgent)
agent.id = "wf-agent"
agent.name = "wf"
agent.description = ""
agent.context_providers = []
agent.workflow = MagicMock()
agent.workflow.name = "wf"
agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
agent.run = AsyncMock(
side_effect=[
AgentResponse(messages=[]),
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]),
]
)
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage]
request = CreateResponse(model="m", input="hi", previous_response_id=previous_response_id)
context = ResponseContext(
response_id=response_id, previous_response_id=previous_response_id, mode_flags=MagicMock()
)
input_item = ItemMessage({"type": "message", "role": "user", "content": "next turn"})
with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[input_item])):
async for _ in server._handle_inner_workflow(request, context): # pyright: ignore[reportPrivateUsage]
pass
assert agent.run.call_count == 2
restore_call = agent.run.call_args_list[0]
assert restore_call.kwargs["checkpoint_id"] == checkpoint.checkpoint_id
assert restore_call.kwargs["checkpoint_storage"].storage_path == (root / previous_response_id).resolve()
new_turn_call = agent.run.call_args_list[1]
new_turn_messages = new_turn_call.args[0]
assert len(new_turn_messages) == 1
assert new_turn_messages[0].text == "next turn"
assert new_turn_call.kwargs["checkpoint_storage"].storage_path == (root / response_id).resolve()
@pytest.mark.parametrize(
"bad_id",
[