mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
[BREAKING] Python: Remove request_type param from ctx.request_info() (#1824)
* Remove request_type param from ctx.request_info() * Address comments
This commit is contained in:
committed by
GitHub
Unverified
parent
2101d9d36d
commit
68b6a55757
@@ -210,23 +210,21 @@ class RequestInfoEvent(WorkflowEvent):
|
||||
self,
|
||||
request_id: str,
|
||||
source_executor_id: str,
|
||||
request_type: type,
|
||||
request_data: Any,
|
||||
response_type: type,
|
||||
response_type: type[Any],
|
||||
):
|
||||
"""Initialize the request info event.
|
||||
|
||||
Args:
|
||||
request_id: Unique identifier for the request.
|
||||
source_executor_id: ID of the executor that made the request.
|
||||
request_type: Type of the request (e.g., a specific data type).
|
||||
request_data: The data associated with the request.
|
||||
response_type: Expected type of the response.
|
||||
"""
|
||||
super().__init__(request_data)
|
||||
self.request_id = request_id
|
||||
self.source_executor_id = source_executor_id
|
||||
self.request_type = request_type
|
||||
self.request_type: type[Any] = type(request_data)
|
||||
self.response_type = response_type
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -258,14 +256,21 @@ class RequestInfoEvent(WorkflowEvent):
|
||||
if property not in data:
|
||||
raise KeyError(f"Missing '{property}' field in RequestInfoEvent dictionary.")
|
||||
|
||||
return RequestInfoEvent(
|
||||
request_info_event = RequestInfoEvent(
|
||||
request_id=data["request_id"],
|
||||
source_executor_id=data["source_executor_id"],
|
||||
request_type=deserialize_type(data["request_type"]),
|
||||
request_data=decode_checkpoint_value(data["data"]),
|
||||
response_type=deserialize_type(data["response_type"]),
|
||||
)
|
||||
|
||||
# Verify that the deserialized request_data matches the declared request_type
|
||||
if deserialize_type(data["request_type"]) is not type(request_info_event.data):
|
||||
raise TypeError(
|
||||
"Mismatch between deserialized request_data type and request_type field in RequestInfoEvent dictionary."
|
||||
)
|
||||
|
||||
return request_info_event
|
||||
|
||||
|
||||
class WorkflowOutputEvent(WorkflowEvent):
|
||||
"""Event triggered when a workflow executor yields output."""
|
||||
|
||||
@@ -99,7 +99,7 @@ class Executor(RequestInfoMixin, DictConvertible):
|
||||
response = request.create_response(data=True)
|
||||
await ctx.send_message(response, target_id=request.executor_id)
|
||||
else:
|
||||
await ctx.request_info(request.source_event)
|
||||
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
|
||||
|
||||
## Context Types
|
||||
Handler methods receive different WorkflowContext variants based on their type annotations:
|
||||
|
||||
@@ -504,7 +504,7 @@ class _UserInputGateway(Executor):
|
||||
prompt=self._prompt,
|
||||
source_executor_id=self.id,
|
||||
)
|
||||
await ctx.request_info(request, HandoffUserInputRequest, object)
|
||||
await ctx.request_info(request, object)
|
||||
|
||||
@response_handler
|
||||
async def resume_from_user(
|
||||
|
||||
@@ -1612,7 +1612,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
plan_text=plan_text,
|
||||
round_index=self._plan_review_round,
|
||||
)
|
||||
await context.request_info(req, _MagenticPlanReviewRequest, _MagenticPlanReviewReply)
|
||||
await context.request_info(req, _MagenticPlanReviewReply)
|
||||
|
||||
|
||||
# region Magentic Executors
|
||||
|
||||
@@ -129,7 +129,7 @@ def response_handler(
|
||||
# Example of a handler that sends a request
|
||||
...
|
||||
# Send a request with a `CustomRequest` payload and expect a `str` response.
|
||||
await context.request_info(CustomRequest(...), CustomRequest, str)
|
||||
await context.request_info(CustomRequest(...), str)
|
||||
|
||||
|
||||
@response_handler
|
||||
|
||||
@@ -348,7 +348,7 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
|
||||
return
|
||||
await self._runner_context.add_event(event)
|
||||
|
||||
async def request_info(self, request_data: Any, request_type: type, response_type: type) -> None:
|
||||
async def request_info(self, request_data: object, response_type: type) -> None:
|
||||
"""Request information from outside of the workflow.
|
||||
|
||||
Calling this method will cause the workflow to emit a RequestInfoEvent, carrying the
|
||||
@@ -360,9 +360,9 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
|
||||
|
||||
Args:
|
||||
request_data: The data associated with the information request.
|
||||
request_type: The type of the request, used to match with response handlers.
|
||||
response_type: The expected type of the response, used for validation.
|
||||
"""
|
||||
request_type: type = type(request_data)
|
||||
if not self._executor.is_request_supported(request_type, response_type):
|
||||
logger.warning(
|
||||
f"Executor '{self._executor_id}' requested info of type {request_type.__name__} "
|
||||
@@ -374,7 +374,6 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
|
||||
request_info_event = RequestInfoEvent(
|
||||
request_id=str(uuid.uuid4()),
|
||||
source_executor_id=self._executor_id,
|
||||
request_type=request_type,
|
||||
request_data=request_data,
|
||||
response_type=response_type,
|
||||
)
|
||||
|
||||
@@ -261,7 +261,7 @@ class WorkflowExecutor(Executor):
|
||||
await ctx.send_message(response, target_id=request.source_executor_id)
|
||||
else:
|
||||
# Forward to external handler
|
||||
await ctx.request_info(request.source_event)
|
||||
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
|
||||
```
|
||||
|
||||
## Implementation Notes
|
||||
|
||||
@@ -62,7 +62,7 @@ class ApprovalRequiredExecutor(Executor, RequestInfoMixin):
|
||||
prompt=f"Please approve the operation: {message}",
|
||||
context="This is a critical operation that requires human approval.",
|
||||
)
|
||||
await ctx.request_info(approval_request, UserApprovalRequest, bool)
|
||||
await ctx.request_info(approval_request, bool)
|
||||
|
||||
@response_handler
|
||||
async def handle_approval_response(
|
||||
@@ -96,7 +96,7 @@ class CalculationExecutor(Executor, RequestInfoMixin):
|
||||
try:
|
||||
operands = [float(x) for x in parts[1:]]
|
||||
calc_request = CalculationRequest(operation=operation, operands=operands)
|
||||
await ctx.request_info(calc_request, CalculationRequest, float)
|
||||
await ctx.request_info(calc_request, float)
|
||||
except ValueError:
|
||||
await ctx.send_message("Invalid calculation format")
|
||||
else:
|
||||
@@ -126,11 +126,11 @@ class MultiRequestExecutor(Executor, RequestInfoMixin):
|
||||
approval_request = UserApprovalRequest(
|
||||
prompt="Approve batch operation", context="Multiple operations will be performed"
|
||||
)
|
||||
await ctx.request_info(approval_request, UserApprovalRequest, bool)
|
||||
await ctx.request_info(approval_request, bool)
|
||||
|
||||
# Request calculation
|
||||
calc_request = CalculationRequest(operation="multiply", operands=[10.0, 5.0])
|
||||
await ctx.request_info(calc_request, CalculationRequest, float)
|
||||
await ctx.request_info(calc_request, float)
|
||||
|
||||
@response_handler
|
||||
async def handle_approval_response(
|
||||
|
||||
@@ -7,7 +7,7 @@ from datetime import datetime, timezone
|
||||
import pytest
|
||||
|
||||
from agent_framework import InMemoryCheckpointStorage, InProcRunnerContext
|
||||
from agent_framework._workflows._checkpoint_encoding import encode_checkpoint_value
|
||||
from agent_framework._workflows._checkpoint_encoding import DATACLASS_MARKER, encode_checkpoint_value
|
||||
from agent_framework._workflows._checkpoint_summary import get_checkpoint_summary
|
||||
from agent_framework._workflows._events import RequestInfoEvent
|
||||
from agent_framework._workflows._shared_state import SharedState
|
||||
@@ -39,7 +39,6 @@ async def test_rehydrate_request_info_event() -> None:
|
||||
request_info_event = RequestInfoEvent(
|
||||
request_id="request-123",
|
||||
source_executor_id="review_gateway",
|
||||
request_type=MockRequest,
|
||||
request_data=MockRequest(),
|
||||
response_type=bool,
|
||||
)
|
||||
@@ -73,7 +72,6 @@ async def test_rehydrate_fails_when_request_type_missing() -> None:
|
||||
request_info_event = RequestInfoEvent(
|
||||
request_id="request-123",
|
||||
source_executor_id="review_gateway",
|
||||
request_type=MockRequest,
|
||||
request_data=MockRequest(),
|
||||
response_type=bool,
|
||||
)
|
||||
@@ -97,12 +95,41 @@ async def test_rehydrate_fails_when_request_type_missing() -> None:
|
||||
await runner_context.apply_checkpoint(checkpoint)
|
||||
|
||||
|
||||
async def test_rehydrate_fails_when_request_type_mismatch() -> None:
|
||||
"""Rehydration should fail if the request type is mismatched."""
|
||||
request_info_event = RequestInfoEvent(
|
||||
request_id="request-123",
|
||||
source_executor_id="review_gateway",
|
||||
request_data=MockRequest(),
|
||||
response_type=bool,
|
||||
)
|
||||
|
||||
runner_context = InProcRunnerContext(InMemoryCheckpointStorage())
|
||||
await runner_context.add_request_info_event(request_info_event)
|
||||
|
||||
checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1)
|
||||
checkpoint = await runner_context.load_checkpoint(checkpoint_id)
|
||||
|
||||
assert checkpoint is not None
|
||||
assert checkpoint.pending_request_info_events
|
||||
assert "request-123" in checkpoint.pending_request_info_events
|
||||
assert "request_type" in checkpoint.pending_request_info_events["request-123"]
|
||||
|
||||
# Modify the checkpoint to simulate mismatched request type in the serialized data
|
||||
checkpoint.pending_request_info_events["request-123"]["data"][DATACLASS_MARKER] = (
|
||||
"nonexistent.module:MissingRequest"
|
||||
)
|
||||
|
||||
# Rehydrate the context
|
||||
with pytest.raises(TypeError):
|
||||
await runner_context.apply_checkpoint(checkpoint)
|
||||
|
||||
|
||||
async def test_pending_requests_in_summary() -> None:
|
||||
"""Test that pending requests are correctly summarized in the checkpoint summary."""
|
||||
request_info_event = RequestInfoEvent(
|
||||
request_id="request-123",
|
||||
source_executor_id="review_gateway",
|
||||
request_type=MockRequest,
|
||||
request_data=MockRequest(),
|
||||
response_type=bool,
|
||||
)
|
||||
@@ -134,14 +161,12 @@ async def test_request_info_event_serializes_non_json_payloads() -> None:
|
||||
req_1 = RequestInfoEvent(
|
||||
request_id="req-1",
|
||||
source_executor_id="source",
|
||||
request_type=TimedApproval,
|
||||
request_data=TimedApproval(issued_at=datetime(2024, 5, 4, 12, 30, 45)),
|
||||
response_type=bool,
|
||||
)
|
||||
req_2 = RequestInfoEvent(
|
||||
request_id="req-2",
|
||||
source_executor_id="source",
|
||||
request_type=SlottedApproval,
|
||||
request_data=SlottedApproval(note="slot-based"),
|
||||
response_type=bool,
|
||||
)
|
||||
|
||||
@@ -76,7 +76,7 @@ class Coordinator(Executor):
|
||||
else:
|
||||
# Not in cache, forward to external
|
||||
self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request
|
||||
await ctx.request_info(domain_request, DomainCheckRequest, bool)
|
||||
await ctx.request_info(domain_request, bool)
|
||||
|
||||
@response_handler
|
||||
async def handle_domain_response(
|
||||
@@ -138,7 +138,7 @@ class EmailDomainValidator(Executor):
|
||||
return
|
||||
|
||||
# Request domain check from external source
|
||||
await ctx.request_info(request, DomainCheckRequest, bool)
|
||||
await ctx.request_info(request, bool)
|
||||
|
||||
@response_handler
|
||||
async def handle_domain_response(
|
||||
@@ -302,7 +302,7 @@ async def test_workflow_scoped_interception() -> None:
|
||||
|
||||
# Unknown source, forward to external
|
||||
self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request
|
||||
await ctx.request_info(domain_request, DomainCheckRequest, bool)
|
||||
await ctx.request_info(domain_request, bool)
|
||||
|
||||
@response_handler
|
||||
async def handle_domain_response(
|
||||
@@ -386,7 +386,7 @@ async def test_concurrent_sub_workflow_execution() -> None:
|
||||
|
||||
domain_request = sub_workflow_request.source_event.data
|
||||
self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request
|
||||
await ctx.request_info(domain_request, DomainCheckRequest, bool)
|
||||
await ctx.request_info(domain_request, bool)
|
||||
|
||||
@response_handler
|
||||
async def handle_domain_response(
|
||||
|
||||
@@ -179,7 +179,6 @@ def test_serialize_deserialize_roundtrip() -> None:
|
||||
instance = deserialized(
|
||||
request_id="request-123",
|
||||
source_executor_id="executor_1",
|
||||
request_type=str,
|
||||
request_data="test",
|
||||
response_type=str,
|
||||
)
|
||||
|
||||
@@ -89,7 +89,7 @@ class MockExecutorRequestApproval(Executor):
|
||||
async def mock_handler_a(self, message: NumberMessage, ctx: WorkflowContext) -> None:
|
||||
"""A mock handler that requests approval."""
|
||||
await ctx.set_shared_state(self.id, message.data)
|
||||
await ctx.request_info(MockRequest(prompt="Mock approval request"), MockRequest, ApprovalMessage)
|
||||
await ctx.request_info(MockRequest(prompt="Mock approval request"), ApprovalMessage)
|
||||
|
||||
@response_handler
|
||||
async def mock_handler_b(
|
||||
@@ -485,7 +485,6 @@ async def test_workflow_run_stream_from_checkpoint_with_responses(simple_executo
|
||||
"request_123": RequestInfoEvent(
|
||||
request_id="request_123",
|
||||
source_executor_id=simple_executor.id,
|
||||
request_type=str,
|
||||
request_data="Mock",
|
||||
response_type=str,
|
||||
).to_dict(),
|
||||
|
||||
@@ -60,7 +60,7 @@ class RequestingExecutor(Executor):
|
||||
@handler
|
||||
async def handle_message(self, _: list[ChatMessage], ctx: WorkflowContext) -> None:
|
||||
# Send a RequestInfoMessage to trigger the request info process
|
||||
await ctx.request_info("Mock request data", str, str)
|
||||
await ctx.request_info("Mock request data", str)
|
||||
|
||||
@response_handler
|
||||
async def handle_request_response(
|
||||
|
||||
@@ -78,7 +78,7 @@ class Requester(Executor):
|
||||
|
||||
@handler
|
||||
async def ask(self, _: str, ctx: WorkflowContext) -> None: # pragma: no cover
|
||||
await ctx.request_info("Mock request data", str, str)
|
||||
await ctx.request_info("Mock request data", str)
|
||||
|
||||
|
||||
async def test_idle_with_pending_requests_status_streaming():
|
||||
|
||||
+2
-3
@@ -128,9 +128,8 @@ class Coordinator(Executor):
|
||||
"Keep it under 30 words."
|
||||
)
|
||||
await ctx.request_info(
|
||||
DraftFeedbackRequest(prompt=prompt, draft_text=draft_text, conversation=conversation),
|
||||
DraftFeedbackRequest,
|
||||
str,
|
||||
request_data=DraftFeedbackRequest(prompt=prompt, draft_text=draft_text, conversation=conversation),
|
||||
response_type=str,
|
||||
)
|
||||
|
||||
@response_handler
|
||||
|
||||
+1
-1
@@ -76,7 +76,7 @@ class ReviewerWithHumanInTheLoop(Executor):
|
||||
print("Reviewer: Escalating to human manager...")
|
||||
|
||||
# Forward the request to a human manager by sending a HumanReviewRequest.
|
||||
await ctx.request_info(HumanReviewRequest(agent_request=request), HumanReviewRequest, ReviewResponse)
|
||||
await ctx.request_info(request_data=HumanReviewRequest(agent_request=request), response_type=ReviewResponse)
|
||||
|
||||
@response_handler
|
||||
async def accept_human_review(
|
||||
|
||||
+2
-3
@@ -125,13 +125,12 @@ class ReviewGateway(Executor):
|
||||
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
|
||||
# Emit a human approval request.
|
||||
await ctx.request_info(
|
||||
HumanApprovalRequest(
|
||||
request_data=HumanApprovalRequest(
|
||||
prompt="Review the draft. Reply 'approve' or provide edit instructions.",
|
||||
draft=draft,
|
||||
iteration=iteration,
|
||||
),
|
||||
HumanApprovalRequest,
|
||||
str,
|
||||
response_type=str,
|
||||
)
|
||||
|
||||
@response_handler
|
||||
|
||||
@@ -145,7 +145,7 @@ class DraftReviewRouter(Executor):
|
||||
"Confirm CTA is action-oriented",
|
||||
],
|
||||
)
|
||||
await ctx.request_info(request, ReviewRequest, str)
|
||||
await ctx.request_info(request_data=request, response_type=str)
|
||||
|
||||
@response_handler
|
||||
async def forward_decision(
|
||||
@@ -251,7 +251,7 @@ class LaunchCoordinator(Executor):
|
||||
await ctx.set_executor_state(executor_state)
|
||||
|
||||
# Send the request without modification
|
||||
await ctx.request_info(review_request, ReviewRequest, str)
|
||||
await ctx.request_info(request_data=review_request, response_type=str)
|
||||
|
||||
@response_handler
|
||||
async def handle_request_response(
|
||||
|
||||
+4
-4
@@ -122,7 +122,7 @@ def build_resource_request_distribution_workflow() -> Workflow:
|
||||
|
||||
@handler
|
||||
async def run(self, request: ResourceRequest, ctx: WorkflowContext) -> None:
|
||||
await ctx.request_info(request, ResourceRequest, ResourceResponse)
|
||||
await ctx.request_info(request_data=request, response_type=ResourceResponse)
|
||||
|
||||
@response_handler
|
||||
async def handle_response(
|
||||
@@ -136,7 +136,7 @@ def build_resource_request_distribution_workflow() -> Workflow:
|
||||
|
||||
@handler
|
||||
async def run(self, request: PolicyRequest, ctx: WorkflowContext) -> None:
|
||||
await ctx.request_info(request, PolicyRequest, PolicyResponse)
|
||||
await ctx.request_info(request_data=request, response_type=PolicyResponse)
|
||||
|
||||
@response_handler
|
||||
async def handle_response(
|
||||
@@ -219,7 +219,7 @@ class ResourceAllocator(Executor):
|
||||
else:
|
||||
# Request cannot be fulfilled via cache, forward the request to external
|
||||
self._pending_requests[request_payload.id] = source_event
|
||||
await ctx.request_info(request_payload, ResourceRequest, ResourceResponse)
|
||||
await ctx.request_info(request_data=request_payload, response_type=ResourceResponse)
|
||||
|
||||
@response_handler
|
||||
async def handle_external_response(
|
||||
@@ -270,7 +270,7 @@ class PolicyEngine(Executor):
|
||||
else:
|
||||
# For other policy types, forward to external system
|
||||
self._pending_requests[request_payload.id] = source_event
|
||||
await ctx.request_info(request_payload, PolicyRequest, PolicyResponse)
|
||||
await ctx.request_info(request_data=request_payload, response_type=PolicyResponse)
|
||||
|
||||
@response_handler
|
||||
async def handle_external_response(
|
||||
|
||||
+1
-1
@@ -124,7 +124,7 @@ def build_email_address_validation_workflow() -> Workflow:
|
||||
print(f"🔍 Validating domain: '{domain}'")
|
||||
self._pending_domains[domain] = partial_result
|
||||
# Send a request to the external system via the request_info mechanism
|
||||
await ctx.request_info(domain, str, bool)
|
||||
await ctx.request_info(request_data=domain, response_type=bool)
|
||||
|
||||
@response_handler
|
||||
async def handle_domain_validation_response(
|
||||
|
||||
-1
@@ -115,7 +115,6 @@ class TurnManager(Executor):
|
||||
# Send a request with a prompt as the payload and expect a string reply.
|
||||
await ctx.request_info(
|
||||
request_data=HumanFeedbackRequest(prompt=prompt),
|
||||
request_type=HumanFeedbackRequest,
|
||||
response_type=str,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user