mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Fix WorkflowAgent pending request resume after session restore
This commit is contained in:
committed by
GitHub
Unverified
parent
9361cda413
commit
02e5600ede
@@ -54,6 +54,8 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
# Class variable for the request info function name
|
||||
REQUEST_INFO_FUNCTION_NAME: ClassVar[str] = "request_info"
|
||||
_SESSION_STATE_KEY: ClassVar[str] = "workflow_agent"
|
||||
_PENDING_REQUESTS_STATE_KEY: ClassVar[str] = "pending_request_info_events"
|
||||
|
||||
@dataclass
|
||||
class RequestInfoFunctionArgs:
|
||||
@@ -258,6 +260,7 @@ class WorkflowAgent(BaseAgent):
|
||||
An AgentResponse representing the workflow execution results.
|
||||
"""
|
||||
input_messages = normalize_messages_input(messages)
|
||||
self._restore_pending_requests_from_session(session)
|
||||
|
||||
if (
|
||||
not any(
|
||||
@@ -291,10 +294,11 @@ class WorkflowAgent(BaseAgent):
|
||||
)
|
||||
# combine the messages
|
||||
session_messages: list[Message] = session_context.get_messages(include_input=True)
|
||||
workflow_input_messages = input_messages if bool(self.pending_requests) else session_messages
|
||||
|
||||
output_events: list[WorkflowEvent[Any]] = []
|
||||
async for event in self._run_core(
|
||||
session_messages,
|
||||
workflow_input_messages,
|
||||
checkpoint_id,
|
||||
checkpoint_storage,
|
||||
streaming=False,
|
||||
@@ -311,6 +315,7 @@ class WorkflowAgent(BaseAgent):
|
||||
session_context._response = result # type: ignore[assignment]
|
||||
|
||||
await self._run_after_providers(session=provider_session, context=session_context)
|
||||
self._persist_pending_requests_to_session(session)
|
||||
return result
|
||||
|
||||
async def _run_stream_impl(
|
||||
@@ -338,6 +343,7 @@ class WorkflowAgent(BaseAgent):
|
||||
AgentResponseUpdate objects representing the workflow execution progress.
|
||||
"""
|
||||
input_messages = normalize_messages_input(messages)
|
||||
self._restore_pending_requests_from_session(session)
|
||||
|
||||
if (
|
||||
not any(
|
||||
@@ -372,9 +378,10 @@ class WorkflowAgent(BaseAgent):
|
||||
# combine the messages
|
||||
|
||||
session_messages: list[Message] = session_context.get_messages(include_input=True)
|
||||
workflow_input_messages = input_messages if bool(self.pending_requests) else session_messages
|
||||
all_updates: list[AgentResponseUpdate] = []
|
||||
async for event in self._run_core(
|
||||
session_messages,
|
||||
workflow_input_messages,
|
||||
checkpoint_id,
|
||||
checkpoint_storage,
|
||||
streaming=True,
|
||||
@@ -392,6 +399,7 @@ class WorkflowAgent(BaseAgent):
|
||||
session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment]
|
||||
|
||||
await self._run_after_providers(session=provider_session, context=session_context)
|
||||
self._persist_pending_requests_to_session(session)
|
||||
|
||||
async def _run_core(
|
||||
self,
|
||||
@@ -425,6 +433,8 @@ class WorkflowAgent(BaseAgent):
|
||||
async for event in self.workflow.run(
|
||||
responses=function_responses,
|
||||
stream=True,
|
||||
checkpoint_id=checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
@@ -432,6 +442,8 @@ class WorkflowAgent(BaseAgent):
|
||||
else:
|
||||
for event in await self.workflow.run(
|
||||
responses=function_responses,
|
||||
checkpoint_id=checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
@@ -484,6 +496,60 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
# endregion Run Methods
|
||||
|
||||
def _restore_pending_requests_from_session(self, session: AgentSession | None) -> None:
|
||||
"""Load pending request-info events from the session state."""
|
||||
if session is None:
|
||||
return
|
||||
|
||||
agent_state = session.state.get(self._SESSION_STATE_KEY)
|
||||
if not isinstance(agent_state, dict):
|
||||
self.pending_requests.clear()
|
||||
return
|
||||
|
||||
pending_requests_payload = agent_state.get(self._PENDING_REQUESTS_STATE_KEY)
|
||||
if not isinstance(pending_requests_payload, dict):
|
||||
self.pending_requests.clear()
|
||||
return
|
||||
|
||||
restored_pending: dict[str, WorkflowEvent[Any]] = {}
|
||||
for request_id, request_payload in pending_requests_payload.items():
|
||||
if isinstance(request_payload, WorkflowEvent):
|
||||
restored_pending[request_id] = request_payload
|
||||
continue
|
||||
|
||||
if not isinstance(request_payload, dict):
|
||||
logger.warning("Skipping malformed pending request payload for request_id '%s'.", request_id)
|
||||
continue
|
||||
|
||||
try:
|
||||
restored_pending[request_id] = WorkflowEvent.from_dict(request_payload)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.warning(
|
||||
"Failed to restore pending request payload for request_id '%s': %s",
|
||||
request_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
self.pending_requests.clear()
|
||||
self.pending_requests.update(restored_pending)
|
||||
|
||||
def _persist_pending_requests_to_session(self, session: AgentSession | None) -> None:
|
||||
"""Persist pending request-info events to the session state."""
|
||||
if session is None:
|
||||
return
|
||||
|
||||
agent_state = session.state.setdefault(self._SESSION_STATE_KEY, {})
|
||||
if not isinstance(agent_state, dict):
|
||||
logger.warning(
|
||||
"Skipping pending request persistence because '%s' is not a mapping.",
|
||||
self._SESSION_STATE_KEY,
|
||||
)
|
||||
return
|
||||
|
||||
agent_state[self._PENDING_REQUESTS_STATE_KEY] = {
|
||||
request_id: event.to_dict() for request_id, event in self.pending_requests.items()
|
||||
}
|
||||
|
||||
def _process_pending_requests(self, input_messages: Sequence[Message]) -> dict[str, Any]:
|
||||
"""Process pending requests by extracting function responses and updating state.
|
||||
|
||||
|
||||
@@ -293,6 +293,94 @@ class TestWorkflowAgent:
|
||||
# Verify cleanup - pending requests should be cleared after function response handling
|
||||
assert len(agent.pending_requests) == 0
|
||||
|
||||
async def test_request_info_resume_after_session_restore_with_checkpoint(self):
|
||||
"""Pending request metadata in AgentSession should resume the same request_id after restore."""
|
||||
from agent_framework import InMemoryCheckpointStorage
|
||||
|
||||
simple_executor = SimpleExecutor(id="simple", response_text="SimpleResponse", streaming=False)
|
||||
requesting_executor = RequestingExecutor(id="requester", streaming=False)
|
||||
checkpoint_storage = InMemoryCheckpointStorage()
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor=simple_executor, checkpoint_storage=checkpoint_storage)
|
||||
.add_edge(simple_executor, requesting_executor)
|
||||
.build()
|
||||
)
|
||||
agent = WorkflowAgent(workflow=workflow, name="Request Restore Test Agent")
|
||||
session = AgentSession()
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in agent.run("Start request", stream=True, session=session):
|
||||
updates.append(update)
|
||||
|
||||
approval_update = next(
|
||||
(
|
||||
update
|
||||
for update in updates
|
||||
if any(content.type == "function_approval_request" for content in update.contents)
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert approval_update is not None, "Should have received a request_info approval request"
|
||||
|
||||
function_call = next(content for content in approval_update.contents if content.type == "function_call")
|
||||
approval_request = next(
|
||||
content for content in approval_update.contents if content.type == "function_approval_request"
|
||||
)
|
||||
request_id = approval_request.id
|
||||
assert request_id is not None
|
||||
assert function_call.call_id == request_id
|
||||
|
||||
checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name)
|
||||
checkpoint_with_request = next(
|
||||
(checkpoint for checkpoint in checkpoints if request_id in checkpoint.pending_request_info_events),
|
||||
None,
|
||||
)
|
||||
assert checkpoint_with_request is not None
|
||||
|
||||
serialized_session = session.to_dict()
|
||||
workflow_agent_state = serialized_session["state"].get("workflow_agent", {})
|
||||
pending_state = workflow_agent_state.get("pending_request_info_events", {})
|
||||
assert request_id in pending_state
|
||||
|
||||
restored_session = AgentSession.from_dict(serialized_session)
|
||||
|
||||
restored_simple_executor = SimpleExecutor(id="simple", response_text="SimpleResponse", streaming=False)
|
||||
restored_requesting_executor = RequestingExecutor(id="requester", streaming=False)
|
||||
restored_workflow = (
|
||||
WorkflowBuilder(start_executor=restored_simple_executor, checkpoint_storage=checkpoint_storage)
|
||||
.add_edge(restored_simple_executor, restored_requesting_executor)
|
||||
.build()
|
||||
)
|
||||
restored_agent = WorkflowAgent(workflow=restored_workflow, name="Request Restore Test Agent")
|
||||
|
||||
response_args = WorkflowAgent.RequestInfoFunctionArgs(
|
||||
request_id=request_id,
|
||||
data="User provided answer",
|
||||
).to_dict()
|
||||
approval_response = Content.from_function_approval_response(
|
||||
approved=True,
|
||||
id=request_id,
|
||||
function_call=Content.from_function_call(
|
||||
call_id=request_id,
|
||||
name=WorkflowAgent.REQUEST_INFO_FUNCTION_NAME,
|
||||
arguments=response_args,
|
||||
),
|
||||
)
|
||||
response_message = Message(role="user", contents=[approval_response])
|
||||
|
||||
continuation_result = await restored_agent.run(
|
||||
response_message,
|
||||
session=restored_session,
|
||||
checkpoint_id=checkpoint_with_request.checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
)
|
||||
|
||||
assert isinstance(continuation_result, AgentResponse)
|
||||
response_texts = [message.text for message in continuation_result.messages if message.text]
|
||||
assert any("Request completed with response: User provided answer" in text for text in response_texts)
|
||||
assert len(restored_agent.pending_requests) == 0
|
||||
|
||||
def test_workflow_as_agent_method(self) -> None:
|
||||
"""Test that Workflow.as_agent() creates a properly configured WorkflowAgent."""
|
||||
# Create a simple workflow
|
||||
|
||||
Reference in New Issue
Block a user