[BREAKING] Python: Remove request_type param from ctx.request_info() (#1824)

* Remove request_type param from ctx.request_info()

* Address comments
This commit is contained in:
Tao Chen
2025-10-31 07:31:15 -07:00
committed by GitHub
Unverified
parent 2101d9d36d
commit 68b6a55757
21 changed files with 72 additions and 48 deletions
@@ -210,23 +210,21 @@ class RequestInfoEvent(WorkflowEvent):
self,
request_id: str,
source_executor_id: str,
request_type: type,
request_data: Any,
response_type: type,
response_type: type[Any],
):
"""Initialize the request info event.
Args:
request_id: Unique identifier for the request.
source_executor_id: ID of the executor that made the request.
request_type: Type of the request (e.g., a specific data type).
request_data: The data associated with the request.
response_type: Expected type of the response.
"""
super().__init__(request_data)
self.request_id = request_id
self.source_executor_id = source_executor_id
self.request_type = request_type
self.request_type: type[Any] = type(request_data)
self.response_type = response_type
def __repr__(self) -> str:
@@ -258,14 +256,21 @@ class RequestInfoEvent(WorkflowEvent):
if property not in data:
raise KeyError(f"Missing '{property}' field in RequestInfoEvent dictionary.")
return RequestInfoEvent(
request_info_event = RequestInfoEvent(
request_id=data["request_id"],
source_executor_id=data["source_executor_id"],
request_type=deserialize_type(data["request_type"]),
request_data=decode_checkpoint_value(data["data"]),
response_type=deserialize_type(data["response_type"]),
)
# Verify that the deserialized request_data matches the declared request_type
if deserialize_type(data["request_type"]) is not type(request_info_event.data):
raise TypeError(
"Mismatch between deserialized request_data type and request_type field in RequestInfoEvent dictionary."
)
return request_info_event
class WorkflowOutputEvent(WorkflowEvent):
"""Event triggered when a workflow executor yields output."""
@@ -99,7 +99,7 @@ class Executor(RequestInfoMixin, DictConvertible):
response = request.create_response(data=True)
await ctx.send_message(response, target_id=request.executor_id)
else:
await ctx.request_info(request.source_event)
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
## Context Types
Handler methods receive different WorkflowContext variants based on their type annotations:
@@ -504,7 +504,7 @@ class _UserInputGateway(Executor):
prompt=self._prompt,
source_executor_id=self.id,
)
await ctx.request_info(request, HandoffUserInputRequest, object)
await ctx.request_info(request, object)
@response_handler
async def resume_from_user(
@@ -1612,7 +1612,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
plan_text=plan_text,
round_index=self._plan_review_round,
)
await context.request_info(req, _MagenticPlanReviewRequest, _MagenticPlanReviewReply)
await context.request_info(req, _MagenticPlanReviewReply)
# region Magentic Executors
@@ -129,7 +129,7 @@ def response_handler(
# Example of a handler that sends a request
...
# Send a request with a `CustomRequest` payload and expect a `str` response.
await context.request_info(CustomRequest(...), CustomRequest, str)
await context.request_info(CustomRequest(...), str)
@response_handler
@@ -348,7 +348,7 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
return
await self._runner_context.add_event(event)
async def request_info(self, request_data: Any, request_type: type, response_type: type) -> None:
async def request_info(self, request_data: object, response_type: type) -> None:
"""Request information from outside of the workflow.
Calling this method will cause the workflow to emit a RequestInfoEvent, carrying the
@@ -360,9 +360,9 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
Args:
request_data: The data associated with the information request.
request_type: The type of the request, used to match with response handlers.
response_type: The expected type of the response, used for validation.
"""
request_type: type = type(request_data)
if not self._executor.is_request_supported(request_type, response_type):
logger.warning(
f"Executor '{self._executor_id}' requested info of type {request_type.__name__} "
@@ -374,7 +374,6 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
request_info_event = RequestInfoEvent(
request_id=str(uuid.uuid4()),
source_executor_id=self._executor_id,
request_type=request_type,
request_data=request_data,
response_type=response_type,
)
@@ -261,7 +261,7 @@ class WorkflowExecutor(Executor):
await ctx.send_message(response, target_id=request.source_executor_id)
else:
# Forward to external handler
await ctx.request_info(request.source_event)
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
```
## Implementation Notes
@@ -62,7 +62,7 @@ class ApprovalRequiredExecutor(Executor, RequestInfoMixin):
prompt=f"Please approve the operation: {message}",
context="This is a critical operation that requires human approval.",
)
await ctx.request_info(approval_request, UserApprovalRequest, bool)
await ctx.request_info(approval_request, bool)
@response_handler
async def handle_approval_response(
@@ -96,7 +96,7 @@ class CalculationExecutor(Executor, RequestInfoMixin):
try:
operands = [float(x) for x in parts[1:]]
calc_request = CalculationRequest(operation=operation, operands=operands)
await ctx.request_info(calc_request, CalculationRequest, float)
await ctx.request_info(calc_request, float)
except ValueError:
await ctx.send_message("Invalid calculation format")
else:
@@ -126,11 +126,11 @@ class MultiRequestExecutor(Executor, RequestInfoMixin):
approval_request = UserApprovalRequest(
prompt="Approve batch operation", context="Multiple operations will be performed"
)
await ctx.request_info(approval_request, UserApprovalRequest, bool)
await ctx.request_info(approval_request, bool)
# Request calculation
calc_request = CalculationRequest(operation="multiply", operands=[10.0, 5.0])
await ctx.request_info(calc_request, CalculationRequest, float)
await ctx.request_info(calc_request, float)
@response_handler
async def handle_approval_response(
@@ -7,7 +7,7 @@ from datetime import datetime, timezone
import pytest
from agent_framework import InMemoryCheckpointStorage, InProcRunnerContext
from agent_framework._workflows._checkpoint_encoding import encode_checkpoint_value
from agent_framework._workflows._checkpoint_encoding import DATACLASS_MARKER, encode_checkpoint_value
from agent_framework._workflows._checkpoint_summary import get_checkpoint_summary
from agent_framework._workflows._events import RequestInfoEvent
from agent_framework._workflows._shared_state import SharedState
@@ -39,7 +39,6 @@ async def test_rehydrate_request_info_event() -> None:
request_info_event = RequestInfoEvent(
request_id="request-123",
source_executor_id="review_gateway",
request_type=MockRequest,
request_data=MockRequest(),
response_type=bool,
)
@@ -73,7 +72,6 @@ async def test_rehydrate_fails_when_request_type_missing() -> None:
request_info_event = RequestInfoEvent(
request_id="request-123",
source_executor_id="review_gateway",
request_type=MockRequest,
request_data=MockRequest(),
response_type=bool,
)
@@ -97,12 +95,41 @@ async def test_rehydrate_fails_when_request_type_missing() -> None:
await runner_context.apply_checkpoint(checkpoint)
async def test_rehydrate_fails_when_request_type_mismatch() -> None:
"""Rehydration should fail if the request type is mismatched."""
request_info_event = RequestInfoEvent(
request_id="request-123",
source_executor_id="review_gateway",
request_data=MockRequest(),
response_type=bool,
)
runner_context = InProcRunnerContext(InMemoryCheckpointStorage())
await runner_context.add_request_info_event(request_info_event)
checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1)
checkpoint = await runner_context.load_checkpoint(checkpoint_id)
assert checkpoint is not None
assert checkpoint.pending_request_info_events
assert "request-123" in checkpoint.pending_request_info_events
assert "request_type" in checkpoint.pending_request_info_events["request-123"]
# Modify the checkpoint to simulate mismatched request type in the serialized data
checkpoint.pending_request_info_events["request-123"]["data"][DATACLASS_MARKER] = (
"nonexistent.module:MissingRequest"
)
# Rehydrate the context
with pytest.raises(TypeError):
await runner_context.apply_checkpoint(checkpoint)
async def test_pending_requests_in_summary() -> None:
"""Test that pending requests are correctly summarized in the checkpoint summary."""
request_info_event = RequestInfoEvent(
request_id="request-123",
source_executor_id="review_gateway",
request_type=MockRequest,
request_data=MockRequest(),
response_type=bool,
)
@@ -134,14 +161,12 @@ async def test_request_info_event_serializes_non_json_payloads() -> None:
req_1 = RequestInfoEvent(
request_id="req-1",
source_executor_id="source",
request_type=TimedApproval,
request_data=TimedApproval(issued_at=datetime(2024, 5, 4, 12, 30, 45)),
response_type=bool,
)
req_2 = RequestInfoEvent(
request_id="req-2",
source_executor_id="source",
request_type=SlottedApproval,
request_data=SlottedApproval(note="slot-based"),
response_type=bool,
)
@@ -76,7 +76,7 @@ class Coordinator(Executor):
else:
# Not in cache, forward to external
self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request
await ctx.request_info(domain_request, DomainCheckRequest, bool)
await ctx.request_info(domain_request, bool)
@response_handler
async def handle_domain_response(
@@ -138,7 +138,7 @@ class EmailDomainValidator(Executor):
return
# Request domain check from external source
await ctx.request_info(request, DomainCheckRequest, bool)
await ctx.request_info(request, bool)
@response_handler
async def handle_domain_response(
@@ -302,7 +302,7 @@ async def test_workflow_scoped_interception() -> None:
# Unknown source, forward to external
self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request
await ctx.request_info(domain_request, DomainCheckRequest, bool)
await ctx.request_info(domain_request, bool)
@response_handler
async def handle_domain_response(
@@ -386,7 +386,7 @@ async def test_concurrent_sub_workflow_execution() -> None:
domain_request = sub_workflow_request.source_event.data
self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request
await ctx.request_info(domain_request, DomainCheckRequest, bool)
await ctx.request_info(domain_request, bool)
@response_handler
async def handle_domain_response(
@@ -179,7 +179,6 @@ def test_serialize_deserialize_roundtrip() -> None:
instance = deserialized(
request_id="request-123",
source_executor_id="executor_1",
request_type=str,
request_data="test",
response_type=str,
)
@@ -89,7 +89,7 @@ class MockExecutorRequestApproval(Executor):
async def mock_handler_a(self, message: NumberMessage, ctx: WorkflowContext) -> None:
"""A mock handler that requests approval."""
await ctx.set_shared_state(self.id, message.data)
await ctx.request_info(MockRequest(prompt="Mock approval request"), MockRequest, ApprovalMessage)
await ctx.request_info(MockRequest(prompt="Mock approval request"), ApprovalMessage)
@response_handler
async def mock_handler_b(
@@ -485,7 +485,6 @@ async def test_workflow_run_stream_from_checkpoint_with_responses(simple_executo
"request_123": RequestInfoEvent(
request_id="request_123",
source_executor_id=simple_executor.id,
request_type=str,
request_data="Mock",
response_type=str,
).to_dict(),
@@ -60,7 +60,7 @@ class RequestingExecutor(Executor):
@handler
async def handle_message(self, _: list[ChatMessage], ctx: WorkflowContext) -> None:
# Send a RequestInfoMessage to trigger the request info process
await ctx.request_info("Mock request data", str, str)
await ctx.request_info("Mock request data", str)
@response_handler
async def handle_request_response(
@@ -78,7 +78,7 @@ class Requester(Executor):
@handler
async def ask(self, _: str, ctx: WorkflowContext) -> None: # pragma: no cover
await ctx.request_info("Mock request data", str, str)
await ctx.request_info("Mock request data", str)
async def test_idle_with_pending_requests_status_streaming():
@@ -128,9 +128,8 @@ class Coordinator(Executor):
"Keep it under 30 words."
)
await ctx.request_info(
DraftFeedbackRequest(prompt=prompt, draft_text=draft_text, conversation=conversation),
DraftFeedbackRequest,
str,
request_data=DraftFeedbackRequest(prompt=prompt, draft_text=draft_text, conversation=conversation),
response_type=str,
)
@response_handler
@@ -76,7 +76,7 @@ class ReviewerWithHumanInTheLoop(Executor):
print("Reviewer: Escalating to human manager...")
# Forward the request to a human manager by sending a HumanReviewRequest.
await ctx.request_info(HumanReviewRequest(agent_request=request), HumanReviewRequest, ReviewResponse)
await ctx.request_info(request_data=HumanReviewRequest(agent_request=request), response_type=ReviewResponse)
@response_handler
async def accept_human_review(
@@ -125,13 +125,12 @@ class ReviewGateway(Executor):
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
# Emit a human approval request.
await ctx.request_info(
HumanApprovalRequest(
request_data=HumanApprovalRequest(
prompt="Review the draft. Reply 'approve' or provide edit instructions.",
draft=draft,
iteration=iteration,
),
HumanApprovalRequest,
str,
response_type=str,
)
@response_handler
@@ -145,7 +145,7 @@ class DraftReviewRouter(Executor):
"Confirm CTA is action-oriented",
],
)
await ctx.request_info(request, ReviewRequest, str)
await ctx.request_info(request_data=request, response_type=str)
@response_handler
async def forward_decision(
@@ -251,7 +251,7 @@ class LaunchCoordinator(Executor):
await ctx.set_executor_state(executor_state)
# Send the request without modification
await ctx.request_info(review_request, ReviewRequest, str)
await ctx.request_info(request_data=review_request, response_type=str)
@response_handler
async def handle_request_response(
@@ -122,7 +122,7 @@ def build_resource_request_distribution_workflow() -> Workflow:
@handler
async def run(self, request: ResourceRequest, ctx: WorkflowContext) -> None:
await ctx.request_info(request, ResourceRequest, ResourceResponse)
await ctx.request_info(request_data=request, response_type=ResourceResponse)
@response_handler
async def handle_response(
@@ -136,7 +136,7 @@ def build_resource_request_distribution_workflow() -> Workflow:
@handler
async def run(self, request: PolicyRequest, ctx: WorkflowContext) -> None:
await ctx.request_info(request, PolicyRequest, PolicyResponse)
await ctx.request_info(request_data=request, response_type=PolicyResponse)
@response_handler
async def handle_response(
@@ -219,7 +219,7 @@ class ResourceAllocator(Executor):
else:
# Request cannot be fulfilled via cache, forward the request to external
self._pending_requests[request_payload.id] = source_event
await ctx.request_info(request_payload, ResourceRequest, ResourceResponse)
await ctx.request_info(request_data=request_payload, response_type=ResourceResponse)
@response_handler
async def handle_external_response(
@@ -270,7 +270,7 @@ class PolicyEngine(Executor):
else:
# For other policy types, forward to external system
self._pending_requests[request_payload.id] = source_event
await ctx.request_info(request_payload, PolicyRequest, PolicyResponse)
await ctx.request_info(request_data=request_payload, response_type=PolicyResponse)
@response_handler
async def handle_external_response(
@@ -124,7 +124,7 @@ def build_email_address_validation_workflow() -> Workflow:
print(f"🔍 Validating domain: '{domain}'")
self._pending_domains[domain] = partial_result
# Send a request to the external system via the request_info mechanism
await ctx.request_info(domain, str, bool)
await ctx.request_info(request_data=domain, response_type=bool)
@response_handler
async def handle_domain_validation_response(
@@ -115,7 +115,6 @@ class TurnManager(Executor):
# Send a request with a prompt as the payload and expect a string reply.
await ctx.request_info(
request_data=HumanFeedbackRequest(prompt=prompt),
request_type=HumanFeedbackRequest,
response_type=str,
)