Fix WorkflowAgent pending request resume after session restore

This commit is contained in:
copilot-swe-agent[bot]
2026-05-27 18:37:40 +00:00
committed by GitHub
Unverified
parent 9361cda413
commit 02e5600ede
2 changed files with 156 additions and 2 deletions
@@ -54,6 +54,8 @@ class WorkflowAgent(BaseAgent):
# Class variable for the request info function name
REQUEST_INFO_FUNCTION_NAME: ClassVar[str] = "request_info"
_SESSION_STATE_KEY: ClassVar[str] = "workflow_agent"
_PENDING_REQUESTS_STATE_KEY: ClassVar[str] = "pending_request_info_events"
@dataclass
class RequestInfoFunctionArgs:
@@ -258,6 +260,7 @@ class WorkflowAgent(BaseAgent):
An AgentResponse representing the workflow execution results.
"""
input_messages = normalize_messages_input(messages)
self._restore_pending_requests_from_session(session)
if (
not any(
@@ -291,10 +294,11 @@ class WorkflowAgent(BaseAgent):
)
# combine the messages
session_messages: list[Message] = session_context.get_messages(include_input=True)
workflow_input_messages = input_messages if bool(self.pending_requests) else session_messages
output_events: list[WorkflowEvent[Any]] = []
async for event in self._run_core(
session_messages,
workflow_input_messages,
checkpoint_id,
checkpoint_storage,
streaming=False,
@@ -311,6 +315,7 @@ class WorkflowAgent(BaseAgent):
session_context._response = result # type: ignore[assignment]
await self._run_after_providers(session=provider_session, context=session_context)
self._persist_pending_requests_to_session(session)
return result
async def _run_stream_impl(
@@ -338,6 +343,7 @@ class WorkflowAgent(BaseAgent):
AgentResponseUpdate objects representing the workflow execution progress.
"""
input_messages = normalize_messages_input(messages)
self._restore_pending_requests_from_session(session)
if (
not any(
@@ -372,9 +378,10 @@ class WorkflowAgent(BaseAgent):
# combine the messages
session_messages: list[Message] = session_context.get_messages(include_input=True)
workflow_input_messages = input_messages if bool(self.pending_requests) else session_messages
all_updates: list[AgentResponseUpdate] = []
async for event in self._run_core(
session_messages,
workflow_input_messages,
checkpoint_id,
checkpoint_storage,
streaming=True,
@@ -392,6 +399,7 @@ class WorkflowAgent(BaseAgent):
session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment]
await self._run_after_providers(session=provider_session, context=session_context)
self._persist_pending_requests_to_session(session)
async def _run_core(
self,
@@ -425,6 +433,8 @@ class WorkflowAgent(BaseAgent):
async for event in self.workflow.run(
responses=function_responses,
stream=True,
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
@@ -432,6 +442,8 @@ class WorkflowAgent(BaseAgent):
else:
for event in await self.workflow.run(
responses=function_responses,
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
@@ -484,6 +496,60 @@ class WorkflowAgent(BaseAgent):
# endregion Run Methods
def _restore_pending_requests_from_session(self, session: AgentSession | None) -> None:
"""Load pending request-info events from the session state."""
if session is None:
return
agent_state = session.state.get(self._SESSION_STATE_KEY)
if not isinstance(agent_state, dict):
self.pending_requests.clear()
return
pending_requests_payload = agent_state.get(self._PENDING_REQUESTS_STATE_KEY)
if not isinstance(pending_requests_payload, dict):
self.pending_requests.clear()
return
restored_pending: dict[str, WorkflowEvent[Any]] = {}
for request_id, request_payload in pending_requests_payload.items():
if isinstance(request_payload, WorkflowEvent):
restored_pending[request_id] = request_payload
continue
if not isinstance(request_payload, dict):
logger.warning("Skipping malformed pending request payload for request_id '%s'.", request_id)
continue
try:
restored_pending[request_id] = WorkflowEvent.from_dict(request_payload)
except Exception as exc: # pragma: no cover - defensive
logger.warning(
"Failed to restore pending request payload for request_id '%s': %s",
request_id,
exc,
)
self.pending_requests.clear()
self.pending_requests.update(restored_pending)
def _persist_pending_requests_to_session(self, session: AgentSession | None) -> None:
"""Persist pending request-info events to the session state."""
if session is None:
return
agent_state = session.state.setdefault(self._SESSION_STATE_KEY, {})
if not isinstance(agent_state, dict):
logger.warning(
"Skipping pending request persistence because '%s' is not a mapping.",
self._SESSION_STATE_KEY,
)
return
agent_state[self._PENDING_REQUESTS_STATE_KEY] = {
request_id: event.to_dict() for request_id, event in self.pending_requests.items()
}
def _process_pending_requests(self, input_messages: Sequence[Message]) -> dict[str, Any]:
"""Process pending requests by extracting function responses and updating state.
@@ -293,6 +293,94 @@ class TestWorkflowAgent:
# Verify cleanup - pending requests should be cleared after function response handling
assert len(agent.pending_requests) == 0
async def test_request_info_resume_after_session_restore_with_checkpoint(self):
"""Pending request metadata in AgentSession should resume the same request_id after restore."""
from agent_framework import InMemoryCheckpointStorage
simple_executor = SimpleExecutor(id="simple", response_text="SimpleResponse", streaming=False)
requesting_executor = RequestingExecutor(id="requester", streaming=False)
checkpoint_storage = InMemoryCheckpointStorage()
workflow = (
WorkflowBuilder(start_executor=simple_executor, checkpoint_storage=checkpoint_storage)
.add_edge(simple_executor, requesting_executor)
.build()
)
agent = WorkflowAgent(workflow=workflow, name="Request Restore Test Agent")
session = AgentSession()
updates: list[AgentResponseUpdate] = []
async for update in agent.run("Start request", stream=True, session=session):
updates.append(update)
approval_update = next(
(
update
for update in updates
if any(content.type == "function_approval_request" for content in update.contents)
),
None,
)
assert approval_update is not None, "Should have received a request_info approval request"
function_call = next(content for content in approval_update.contents if content.type == "function_call")
approval_request = next(
content for content in approval_update.contents if content.type == "function_approval_request"
)
request_id = approval_request.id
assert request_id is not None
assert function_call.call_id == request_id
checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name)
checkpoint_with_request = next(
(checkpoint for checkpoint in checkpoints if request_id in checkpoint.pending_request_info_events),
None,
)
assert checkpoint_with_request is not None
serialized_session = session.to_dict()
workflow_agent_state = serialized_session["state"].get("workflow_agent", {})
pending_state = workflow_agent_state.get("pending_request_info_events", {})
assert request_id in pending_state
restored_session = AgentSession.from_dict(serialized_session)
restored_simple_executor = SimpleExecutor(id="simple", response_text="SimpleResponse", streaming=False)
restored_requesting_executor = RequestingExecutor(id="requester", streaming=False)
restored_workflow = (
WorkflowBuilder(start_executor=restored_simple_executor, checkpoint_storage=checkpoint_storage)
.add_edge(restored_simple_executor, restored_requesting_executor)
.build()
)
restored_agent = WorkflowAgent(workflow=restored_workflow, name="Request Restore Test Agent")
response_args = WorkflowAgent.RequestInfoFunctionArgs(
request_id=request_id,
data="User provided answer",
).to_dict()
approval_response = Content.from_function_approval_response(
approved=True,
id=request_id,
function_call=Content.from_function_call(
call_id=request_id,
name=WorkflowAgent.REQUEST_INFO_FUNCTION_NAME,
arguments=response_args,
),
)
response_message = Message(role="user", contents=[approval_response])
continuation_result = await restored_agent.run(
response_message,
session=restored_session,
checkpoint_id=checkpoint_with_request.checkpoint_id,
checkpoint_storage=checkpoint_storage,
)
assert isinstance(continuation_result, AgentResponse)
response_texts = [message.text for message in continuation_result.messages if message.text]
assert any("Request completed with response: User provided answer" in text for text in response_texts)
assert len(restored_agent.pending_requests) == 0
def test_workflow_as_agent_method(self) -> None:
"""Test that Workflow.as_agent() creates a properly configured WorkflowAgent."""
# Create a simple workflow