diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index f64638b4cb..f1d1507702 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -6,11 +6,11 @@ import logging import sys import uuid from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence +from dataclasses import dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload from .._agents import BaseAgent -from .._serialization import make_json_safe from .._sessions import ( AgentSession, ContextProvider, @@ -55,6 +55,28 @@ class WorkflowAgent(BaseAgent): # Class variable for the request info function name REQUEST_INFO_FUNCTION_NAME: ClassVar[str] = "request_info" + @dataclass + class RequestInfoFunctionArgs: + request_id: str + request_event: WorkflowEvent + + def to_dict(self) -> dict[str, Any]: + return {"request_id": self.request_id, "request_event": self.request_event.to_dict()} + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> WorkflowAgent.RequestInfoFunctionArgs: + if "request_id" not in payload or "request_event" not in payload: + raise ValueError( + "Invalid payload for RequestInfoFunctionArgs. 'request_id' and 'request_event' are required." + ) + if not payload["request_id"]: + raise ValueError("request_id cannot be empty.") + + return cls( + request_id=payload.get("request_id", ""), + request_event=WorkflowEvent.from_dict(payload.get("request_event", {})), + ) + def __init__( self, workflow: Workflow, @@ -683,7 +705,7 @@ class WorkflowAgent(BaseAgent): return event.data request_id = event.request_id - args = {"request_id": request_id, "data": make_json_safe(event)} + args = self.RequestInfoFunctionArgs(request_id=request_id, request_event=event).to_dict() return Content.from_function_call( call_id=request_id, diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index fe0c854f73..67cea0ed38 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -32,6 +32,19 @@ from agent_framework import ( from agent_framework._workflows._typing_utils import deserialize_type +@dataclass +class HandoffRequest: + """Module-level dataclass used by request_info tests. + + Defined at module scope (not nested inside a test method) so + ``serialize_type``/``deserialize_type`` can round-trip the request_type via + the importable qualified name ``tests.workflow.test_workflow_agent.HandoffRequest``. + """ + + target_agent: str + reason: str + + class SimpleExecutor(Executor): """Simple executor that emits a response based on input.""" @@ -255,10 +268,17 @@ class TestWorkflowAgent: assert request_function_call.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME assert isinstance(request_function_call.arguments, dict) assert request_function_call.arguments.get("request_id") is not None - assert request_function_call.arguments.get("data") is not None - request_data = request_function_call.arguments["data"] - assert request_data.get("type") == "request_info" - assert deserialize_type(request_data.get("response_type")) is str + assert request_function_call.arguments.get("request_event") is not None + request_event = request_function_call.arguments["request_event"] + assert request_event.get("type") == "request_info" + assert deserialize_type(request_event.get("response_type")) is str + + deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments) + assert deserialized_args.request_id == request_function_call.call_id + assert isinstance(deserialized_args.request_event, WorkflowEvent) + assert deserialized_args.request_event.type == "request_info" + assert deserialized_args.request_event.data == "Mock request data" + assert deserialized_args.request_event.response_type is str # Verify the request is tracked in pending_requests pending_requests = await workflow._runner_context.get_pending_request_info_events() @@ -285,12 +305,6 @@ class TestWorkflowAgent: def test_request_info_dataclass_arguments_are_serialized_when_content_is_created(self) -> None: """Test WorkflowAgent prepares request_info arguments before observability captures messages.""" - - @dataclass - class HandoffRequest: - target_agent: str - reason: str - executor = SimpleExecutor(id="executor1", response_text="Response") workflow = WorkflowBuilder(start_executor=executor).build() agent = WorkflowAgent(workflow=workflow, name="Request Test Agent") @@ -305,13 +319,20 @@ class TestWorkflowAgent: assert request_function_call.call_id == "request_123" assert isinstance(request_function_call.arguments, dict) - assert request_function_call.arguments.get("data") is not None - data = request_function_call.arguments["data"] - assert data.get("type") == "request_info" - assert data.get("request_id") == "request_123" - assert data.get("source_executor_id") == "executor1" - assert deserialize_type(data.get("response_type")) is str - assert data.get("data") == {"target_agent": "helper", "reason": "overflow"} + assert request_function_call.arguments.get("request_event") is not None + request_event = request_function_call.arguments["request_event"] + assert request_event.get("type") == "request_info" + assert request_event.get("request_id") == "request_123" + assert request_event.get("source_executor_id") == "executor1" + assert deserialize_type(request_event.get("response_type")) is str + assert request_event.get("data") == HandoffRequest(target_agent="helper", reason="overflow") + + deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments) + assert deserialized_args.request_id == "request_123" + assert isinstance(deserialized_args.request_event, WorkflowEvent) + assert deserialized_args.request_event.type == "request_info" + assert deserialized_args.request_event.data == HandoffRequest(target_agent="helper", reason="overflow") + assert deserialized_args.request_event.response_type is str def test_process_request_info_event_passes_through_function_approval_request(self) -> None: """If the event data is already a function approval request, it is forwarded unchanged. @@ -572,12 +593,6 @@ class TestWorkflowAgent: agent surfaces a synthesized ``function_call`` (name=REQUEST_INFO_FUNCTION_NAME) and routes a matching ``function_result`` back to the executor. """ - - @dataclass - class HandoffRequest: - target_agent: str - reason: str - captured: dict[str, Any] = {} class HandoffRequestingExecutor(Executor): @@ -632,9 +647,16 @@ class TestWorkflowAgent: assert isinstance(function_call.arguments, dict) request_id = function_call.arguments["request_id"] assert function_call.call_id == request_id - request_payload = function_call.arguments["data"] + request_payload = function_call.arguments["request_event"] assert request_payload.get("type") == "request_info" - assert request_payload.get("data") == {"target_agent": "helper", "reason": "overflow"} + assert request_payload.get("data") == HandoffRequest(target_agent="helper", reason="overflow") + + deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(function_call.arguments) + assert deserialized_args.request_id == request_id + assert isinstance(deserialized_args.request_event, WorkflowEvent) + assert deserialized_args.request_event.type == "request_info" + assert deserialized_args.request_event.data == HandoffRequest(target_agent="helper", reason="overflow") + assert deserialized_args.request_event.response_type is str pending = await workflow._runner_context.get_pending_request_info_events() assert request_id in pending diff --git a/python/samples/03-workflows/agents/handoff_workflow_as_agent.py b/python/samples/03-workflows/agents/handoff_workflow_as_agent.py index f392692af9..464f585f09 100644 --- a/python/samples/03-workflows/agents/handoff_workflow_as_agent.py +++ b/python/samples/03-workflows/agents/handoff_workflow_as_agent.py @@ -134,15 +134,11 @@ def handle_response_and_requests(response: AgentResponse) -> dict[str, HandoffAg if message.text: print(f"- {message.author_name or message.role}: {message.text}") for content in message.contents: - if content.type == "function_call": - if isinstance(content.arguments, dict): - request = WorkflowAgent.RequestInfoFunctionArgs.from_dict(content.arguments) - elif isinstance(content.arguments, str): - request = WorkflowAgent.RequestInfoFunctionArgs.from_json(content.arguments) - else: - raise ValueError("Invalid arguments type. Expecting a request info structure for this sample.") - if isinstance(request.data, HandoffAgentUserRequest): - pending_requests[request.request_id] = request.data + if content.type == "function_call" and content.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME: + request_function_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(content.arguments) # type: ignore + request_id = request_function_args.request_id + request_event = request_function_args.request_event + pending_requests[request_id] = request_event.data return pending_requests diff --git a/python/samples/03-workflows/agents/workflow_as_agent_human_in_the_loop.py b/python/samples/03-workflows/agents/workflow_as_agent_human_in_the_loop.py index 3b8ccc0faa..1622491d12 100644 --- a/python/samples/03-workflows/agents/workflow_as_agent_human_in_the_loop.py +++ b/python/samples/03-workflows/agents/workflow_as_agent_human_in_the_loop.py @@ -3,10 +3,8 @@ import asyncio import os import sys -from collections.abc import Mapping from dataclasses import dataclass from pathlib import Path -from typing import Any from agent_framework.foundry import FoundryChatClient from azure.identity import AzureCliCredential @@ -141,28 +139,14 @@ async def main() -> None: # Handle the human review if required. if human_review_function_call: # Parse the human review request arguments. - human_request_args = human_review_function_call.arguments - if isinstance(human_request_args, str): - request: WorkflowAgent.RequestInfoFunctionArgs = WorkflowAgent.RequestInfoFunctionArgs.from_json( - human_request_args - ) - elif isinstance(human_request_args, Mapping): - request = WorkflowAgent.RequestInfoFunctionArgs.from_dict(dict(human_request_args)) - else: - raise TypeError("Unexpected argument type for human review function call.") - - request_payload: Any = request.data + human_request_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(human_review_function_call.arguments) # type: ignore + request_payload = human_request_args.request_event.data if not isinstance(request_payload, HumanReviewRequest): raise ValueError("Human review request payload must be a HumanReviewRequest.") - - agent_request = request_payload.agent_request - if agent_request is None: - raise ValueError("Human review request must include agent_request.") - - request_id = agent_request.request_id + if not request_payload.agent_request: + raise ValueError("Human review request must contain an agent_request.") # Mock a human response approval for demonstration purposes. - human_response = ReviewResponse(request_id=request_id, feedback="", approved=True) - + human_response = ReviewResponse(request_id=request_payload.agent_request.request_id, feedback="", approved=True) # Create the function call result object to send back to the agent. human_review_function_result = Content( "function_result",