diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index e5a8c3e3c4..76ae7f8a4f 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -210,23 +210,21 @@ class RequestInfoEvent(WorkflowEvent): self, request_id: str, source_executor_id: str, - request_type: type, request_data: Any, - response_type: type, + response_type: type[Any], ): """Initialize the request info event. Args: request_id: Unique identifier for the request. source_executor_id: ID of the executor that made the request. - request_type: Type of the request (e.g., a specific data type). request_data: The data associated with the request. response_type: Expected type of the response. """ super().__init__(request_data) self.request_id = request_id self.source_executor_id = source_executor_id - self.request_type = request_type + self.request_type: type[Any] = type(request_data) self.response_type = response_type def __repr__(self) -> str: @@ -258,14 +256,21 @@ class RequestInfoEvent(WorkflowEvent): if property not in data: raise KeyError(f"Missing '{property}' field in RequestInfoEvent dictionary.") - return RequestInfoEvent( + request_info_event = RequestInfoEvent( request_id=data["request_id"], source_executor_id=data["source_executor_id"], - request_type=deserialize_type(data["request_type"]), request_data=decode_checkpoint_value(data["data"]), response_type=deserialize_type(data["response_type"]), ) + # Verify that the deserialized request_data matches the declared request_type + if deserialize_type(data["request_type"]) is not type(request_info_event.data): + raise TypeError( + "Mismatch between deserialized request_data type and request_type field in RequestInfoEvent dictionary." + ) + + return request_info_event + class WorkflowOutputEvent(WorkflowEvent): """Event triggered when a workflow executor yields output.""" diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index aa63d2576a..1563dd7c53 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -99,7 +99,7 @@ class Executor(RequestInfoMixin, DictConvertible): response = request.create_response(data=True) await ctx.send_message(response, target_id=request.executor_id) else: - await ctx.request_info(request.source_event) + await ctx.request_info(request.source_event, response_type=request.source_event.response_type) ## Context Types Handler methods receive different WorkflowContext variants based on their type annotations: diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index d48eb9590f..9d8f6c8467 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -504,7 +504,7 @@ class _UserInputGateway(Executor): prompt=self._prompt, source_executor_id=self.id, ) - await ctx.request_info(request, HandoffUserInputRequest, object) + await ctx.request_info(request, object) @response_handler async def resume_from_user( diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index 13569ba7b9..b9cf4258a1 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -1612,7 +1612,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator): plan_text=plan_text, round_index=self._plan_review_round, ) - await context.request_info(req, _MagenticPlanReviewRequest, _MagenticPlanReviewReply) + await context.request_info(req, _MagenticPlanReviewReply) # region Magentic Executors diff --git a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py index 67e08d51ec..d4c5f6d2bc 100644 --- a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py +++ b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py @@ -129,7 +129,7 @@ def response_handler( # Example of a handler that sends a request ... # Send a request with a `CustomRequest` payload and expect a `str` response. - await context.request_info(CustomRequest(...), CustomRequest, str) + await context.request_info(CustomRequest(...), str) @response_handler diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index 177e85563f..d2a3648298 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -348,7 +348,7 @@ class WorkflowContext(Generic[T_Out, T_W_Out]): return await self._runner_context.add_event(event) - async def request_info(self, request_data: Any, request_type: type, response_type: type) -> None: + async def request_info(self, request_data: object, response_type: type) -> None: """Request information from outside of the workflow. Calling this method will cause the workflow to emit a RequestInfoEvent, carrying the @@ -360,9 +360,9 @@ class WorkflowContext(Generic[T_Out, T_W_Out]): Args: request_data: The data associated with the information request. - request_type: The type of the request, used to match with response handlers. response_type: The expected type of the response, used for validation. """ + request_type: type = type(request_data) if not self._executor.is_request_supported(request_type, response_type): logger.warning( f"Executor '{self._executor_id}' requested info of type {request_type.__name__} " @@ -374,7 +374,6 @@ class WorkflowContext(Generic[T_Out, T_W_Out]): request_info_event = RequestInfoEvent( request_id=str(uuid.uuid4()), source_executor_id=self._executor_id, - request_type=request_type, request_data=request_data, response_type=response_type, ) diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index e7614b5cdb..c656e36b72 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -261,7 +261,7 @@ class WorkflowExecutor(Executor): await ctx.send_message(response, target_id=request.source_executor_id) else: # Forward to external handler - await ctx.request_info(request.source_event) + await ctx.request_info(request.source_event, response_type=request.source_event.response_type) ``` ## Implementation Notes diff --git a/python/packages/core/tests/workflow/test_request_info_and_response.py b/python/packages/core/tests/workflow/test_request_info_and_response.py index 198a058e85..b83028d98a 100644 --- a/python/packages/core/tests/workflow/test_request_info_and_response.py +++ b/python/packages/core/tests/workflow/test_request_info_and_response.py @@ -62,7 +62,7 @@ class ApprovalRequiredExecutor(Executor, RequestInfoMixin): prompt=f"Please approve the operation: {message}", context="This is a critical operation that requires human approval.", ) - await ctx.request_info(approval_request, UserApprovalRequest, bool) + await ctx.request_info(approval_request, bool) @response_handler async def handle_approval_response( @@ -96,7 +96,7 @@ class CalculationExecutor(Executor, RequestInfoMixin): try: operands = [float(x) for x in parts[1:]] calc_request = CalculationRequest(operation=operation, operands=operands) - await ctx.request_info(calc_request, CalculationRequest, float) + await ctx.request_info(calc_request, float) except ValueError: await ctx.send_message("Invalid calculation format") else: @@ -126,11 +126,11 @@ class MultiRequestExecutor(Executor, RequestInfoMixin): approval_request = UserApprovalRequest( prompt="Approve batch operation", context="Multiple operations will be performed" ) - await ctx.request_info(approval_request, UserApprovalRequest, bool) + await ctx.request_info(approval_request, bool) # Request calculation calc_request = CalculationRequest(operation="multiply", operands=[10.0, 5.0]) - await ctx.request_info(calc_request, CalculationRequest, float) + await ctx.request_info(calc_request, float) @response_handler async def handle_approval_response( diff --git a/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py b/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py index 80dc7e3004..c0fd8e198f 100644 --- a/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py +++ b/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py @@ -7,7 +7,7 @@ from datetime import datetime, timezone import pytest from agent_framework import InMemoryCheckpointStorage, InProcRunnerContext -from agent_framework._workflows._checkpoint_encoding import encode_checkpoint_value +from agent_framework._workflows._checkpoint_encoding import DATACLASS_MARKER, encode_checkpoint_value from agent_framework._workflows._checkpoint_summary import get_checkpoint_summary from agent_framework._workflows._events import RequestInfoEvent from agent_framework._workflows._shared_state import SharedState @@ -39,7 +39,6 @@ async def test_rehydrate_request_info_event() -> None: request_info_event = RequestInfoEvent( request_id="request-123", source_executor_id="review_gateway", - request_type=MockRequest, request_data=MockRequest(), response_type=bool, ) @@ -73,7 +72,6 @@ async def test_rehydrate_fails_when_request_type_missing() -> None: request_info_event = RequestInfoEvent( request_id="request-123", source_executor_id="review_gateway", - request_type=MockRequest, request_data=MockRequest(), response_type=bool, ) @@ -97,12 +95,41 @@ async def test_rehydrate_fails_when_request_type_missing() -> None: await runner_context.apply_checkpoint(checkpoint) +async def test_rehydrate_fails_when_request_type_mismatch() -> None: + """Rehydration should fail if the request type is mismatched.""" + request_info_event = RequestInfoEvent( + request_id="request-123", + source_executor_id="review_gateway", + request_data=MockRequest(), + response_type=bool, + ) + + runner_context = InProcRunnerContext(InMemoryCheckpointStorage()) + await runner_context.add_request_info_event(request_info_event) + + checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1) + checkpoint = await runner_context.load_checkpoint(checkpoint_id) + + assert checkpoint is not None + assert checkpoint.pending_request_info_events + assert "request-123" in checkpoint.pending_request_info_events + assert "request_type" in checkpoint.pending_request_info_events["request-123"] + + # Modify the checkpoint to simulate mismatched request type in the serialized data + checkpoint.pending_request_info_events["request-123"]["data"][DATACLASS_MARKER] = ( + "nonexistent.module:MissingRequest" + ) + + # Rehydrate the context + with pytest.raises(TypeError): + await runner_context.apply_checkpoint(checkpoint) + + async def test_pending_requests_in_summary() -> None: """Test that pending requests are correctly summarized in the checkpoint summary.""" request_info_event = RequestInfoEvent( request_id="request-123", source_executor_id="review_gateway", - request_type=MockRequest, request_data=MockRequest(), response_type=bool, ) @@ -134,14 +161,12 @@ async def test_request_info_event_serializes_non_json_payloads() -> None: req_1 = RequestInfoEvent( request_id="req-1", source_executor_id="source", - request_type=TimedApproval, request_data=TimedApproval(issued_at=datetime(2024, 5, 4, 12, 30, 45)), response_type=bool, ) req_2 = RequestInfoEvent( request_id="req-2", source_executor_id="source", - request_type=SlottedApproval, request_data=SlottedApproval(note="slot-based"), response_type=bool, ) diff --git a/python/packages/core/tests/workflow/test_sub_workflow.py b/python/packages/core/tests/workflow/test_sub_workflow.py index 8fdde7f62e..cb2733b653 100644 --- a/python/packages/core/tests/workflow/test_sub_workflow.py +++ b/python/packages/core/tests/workflow/test_sub_workflow.py @@ -76,7 +76,7 @@ class Coordinator(Executor): else: # Not in cache, forward to external self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request - await ctx.request_info(domain_request, DomainCheckRequest, bool) + await ctx.request_info(domain_request, bool) @response_handler async def handle_domain_response( @@ -138,7 +138,7 @@ class EmailDomainValidator(Executor): return # Request domain check from external source - await ctx.request_info(request, DomainCheckRequest, bool) + await ctx.request_info(request, bool) @response_handler async def handle_domain_response( @@ -302,7 +302,7 @@ async def test_workflow_scoped_interception() -> None: # Unknown source, forward to external self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request - await ctx.request_info(domain_request, DomainCheckRequest, bool) + await ctx.request_info(domain_request, bool) @response_handler async def handle_domain_response( @@ -386,7 +386,7 @@ async def test_concurrent_sub_workflow_execution() -> None: domain_request = sub_workflow_request.source_event.data self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request - await ctx.request_info(domain_request, DomainCheckRequest, bool) + await ctx.request_info(domain_request, bool) @response_handler async def handle_domain_response( diff --git a/python/packages/core/tests/workflow/test_typing_utils.py b/python/packages/core/tests/workflow/test_typing_utils.py index 26f8e8ed27..4294f35f4b 100644 --- a/python/packages/core/tests/workflow/test_typing_utils.py +++ b/python/packages/core/tests/workflow/test_typing_utils.py @@ -179,7 +179,6 @@ def test_serialize_deserialize_roundtrip() -> None: instance = deserialized( request_id="request-123", source_executor_id="executor_1", - request_type=str, request_data="test", response_type=str, ) diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index b713c5e4a4..cbf75c5a65 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -89,7 +89,7 @@ class MockExecutorRequestApproval(Executor): async def mock_handler_a(self, message: NumberMessage, ctx: WorkflowContext) -> None: """A mock handler that requests approval.""" await ctx.set_shared_state(self.id, message.data) - await ctx.request_info(MockRequest(prompt="Mock approval request"), MockRequest, ApprovalMessage) + await ctx.request_info(MockRequest(prompt="Mock approval request"), ApprovalMessage) @response_handler async def mock_handler_b( @@ -485,7 +485,6 @@ async def test_workflow_run_stream_from_checkpoint_with_responses(simple_executo "request_123": RequestInfoEvent( request_id="request_123", source_executor_id=simple_executor.id, - request_type=str, request_data="Mock", response_type=str, ).to_dict(), diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index 44a17d02bf..009ead6edd 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -60,7 +60,7 @@ class RequestingExecutor(Executor): @handler async def handle_message(self, _: list[ChatMessage], ctx: WorkflowContext) -> None: # Send a RequestInfoMessage to trigger the request info process - await ctx.request_info("Mock request data", str, str) + await ctx.request_info("Mock request data", str) @response_handler async def handle_request_response( diff --git a/python/packages/core/tests/workflow/test_workflow_states.py b/python/packages/core/tests/workflow/test_workflow_states.py index c21da08d52..4e88ed26cb 100644 --- a/python/packages/core/tests/workflow/test_workflow_states.py +++ b/python/packages/core/tests/workflow/test_workflow_states.py @@ -78,7 +78,7 @@ class Requester(Executor): @handler async def ask(self, _: str, ctx: WorkflowContext) -> None: # pragma: no cover - await ctx.request_info("Mock request data", str, str) + await ctx.request_info("Mock request data", str) async def test_idle_with_pending_requests_status_streaming(): diff --git a/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py b/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py index b5d6262a8b..81ee4f8c4d 100644 --- a/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py +++ b/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py @@ -128,9 +128,8 @@ class Coordinator(Executor): "Keep it under 30 words." ) await ctx.request_info( - DraftFeedbackRequest(prompt=prompt, draft_text=draft_text, conversation=conversation), - DraftFeedbackRequest, - str, + request_data=DraftFeedbackRequest(prompt=prompt, draft_text=draft_text, conversation=conversation), + response_type=str, ) @response_handler diff --git a/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py b/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py index 79028cc325..9935339709 100644 --- a/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py +++ b/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py @@ -76,7 +76,7 @@ class ReviewerWithHumanInTheLoop(Executor): print("Reviewer: Escalating to human manager...") # Forward the request to a human manager by sending a HumanReviewRequest. - await ctx.request_info(HumanReviewRequest(agent_request=request), HumanReviewRequest, ReviewResponse) + await ctx.request_info(request_data=HumanReviewRequest(agent_request=request), response_type=ReviewResponse) @response_handler async def accept_human_review( diff --git a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py index a70bebfba2..2a24327952 100644 --- a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py +++ b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py @@ -125,13 +125,12 @@ class ReviewGateway(Executor): await ctx.set_executor_state({"iteration": iteration, "last_draft": draft}) # Emit a human approval request. await ctx.request_info( - HumanApprovalRequest( + request_data=HumanApprovalRequest( prompt="Review the draft. Reply 'approve' or provide edit instructions.", draft=draft, iteration=iteration, ), - HumanApprovalRequest, - str, + response_type=str, ) @response_handler diff --git a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py index 005158bb3a..3311d18b36 100644 --- a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py +++ b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py @@ -145,7 +145,7 @@ class DraftReviewRouter(Executor): "Confirm CTA is action-oriented", ], ) - await ctx.request_info(request, ReviewRequest, str) + await ctx.request_info(request_data=request, response_type=str) @response_handler async def forward_decision( @@ -251,7 +251,7 @@ class LaunchCoordinator(Executor): await ctx.set_executor_state(executor_state) # Send the request without modification - await ctx.request_info(review_request, ReviewRequest, str) + await ctx.request_info(request_data=review_request, response_type=str) @response_handler async def handle_request_response( diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py b/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py index ca30d87f6f..b33a24b8b5 100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py @@ -122,7 +122,7 @@ def build_resource_request_distribution_workflow() -> Workflow: @handler async def run(self, request: ResourceRequest, ctx: WorkflowContext) -> None: - await ctx.request_info(request, ResourceRequest, ResourceResponse) + await ctx.request_info(request_data=request, response_type=ResourceResponse) @response_handler async def handle_response( @@ -136,7 +136,7 @@ def build_resource_request_distribution_workflow() -> Workflow: @handler async def run(self, request: PolicyRequest, ctx: WorkflowContext) -> None: - await ctx.request_info(request, PolicyRequest, PolicyResponse) + await ctx.request_info(request_data=request, response_type=PolicyResponse) @response_handler async def handle_response( @@ -219,7 +219,7 @@ class ResourceAllocator(Executor): else: # Request cannot be fulfilled via cache, forward the request to external self._pending_requests[request_payload.id] = source_event - await ctx.request_info(request_payload, ResourceRequest, ResourceResponse) + await ctx.request_info(request_data=request_payload, response_type=ResourceResponse) @response_handler async def handle_external_response( @@ -270,7 +270,7 @@ class PolicyEngine(Executor): else: # For other policy types, forward to external system self._pending_requests[request_payload.id] = source_event - await ctx.request_info(request_payload, PolicyRequest, PolicyResponse) + await ctx.request_info(request_data=request_payload, response_type=PolicyResponse) @response_handler async def handle_external_response( diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py b/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py index 769397b972..241b87e18c 100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py @@ -124,7 +124,7 @@ def build_email_address_validation_workflow() -> Workflow: print(f"🔍 Validating domain: '{domain}'") self._pending_domains[domain] = partial_result # Send a request to the external system via the request_info mechanism - await ctx.request_info(domain, str, bool) + await ctx.request_info(request_data=domain, response_type=bool) @response_handler async def handle_domain_validation_response( diff --git a/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py b/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py index d492ff8d60..a922baf16c 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py @@ -115,7 +115,6 @@ class TurnManager(Executor): # Send a request with a prompt as the payload and expect a string reply. await ctx.request_info( request_data=HumanFeedbackRequest(prompt=prompt), - request_type=HumanFeedbackRequest, response_type=str, )