Fix failing tests

This commit is contained in:
Tao Chen
2026-06-04 11:26:53 -07:00
Unverified
parent c229a873d4
commit 2f385003fc
4 changed files with 81 additions and 57 deletions
@@ -6,11 +6,11 @@ import logging
import sys
import uuid
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
from .._agents import BaseAgent
from .._serialization import make_json_safe
from .._sessions import (
AgentSession,
ContextProvider,
@@ -55,6 +55,28 @@ class WorkflowAgent(BaseAgent):
# Class variable for the request info function name
REQUEST_INFO_FUNCTION_NAME: ClassVar[str] = "request_info"
@dataclass
class RequestInfoFunctionArgs:
request_id: str
request_event: WorkflowEvent
def to_dict(self) -> dict[str, Any]:
return {"request_id": self.request_id, "request_event": self.request_event.to_dict()}
@classmethod
def from_dict(cls, payload: dict[str, Any]) -> WorkflowAgent.RequestInfoFunctionArgs:
if "request_id" not in payload or "request_event" not in payload:
raise ValueError(
"Invalid payload for RequestInfoFunctionArgs. 'request_id' and 'request_event' are required."
)
if not payload["request_id"]:
raise ValueError("request_id cannot be empty.")
return cls(
request_id=payload.get("request_id", ""),
request_event=WorkflowEvent.from_dict(payload.get("request_event", {})),
)
def __init__(
self,
workflow: Workflow,
@@ -683,7 +705,7 @@ class WorkflowAgent(BaseAgent):
return event.data
request_id = event.request_id
args = {"request_id": request_id, "data": make_json_safe(event)}
args = self.RequestInfoFunctionArgs(request_id=request_id, request_event=event).to_dict()
return Content.from_function_call(
call_id=request_id,
@@ -32,6 +32,19 @@ from agent_framework import (
from agent_framework._workflows._typing_utils import deserialize_type
@dataclass
class HandoffRequest:
"""Module-level dataclass used by request_info tests.
Defined at module scope (not nested inside a test method) so
``serialize_type``/``deserialize_type`` can round-trip the request_type via
the importable qualified name ``tests.workflow.test_workflow_agent.HandoffRequest``.
"""
target_agent: str
reason: str
class SimpleExecutor(Executor):
"""Simple executor that emits a response based on input."""
@@ -255,10 +268,17 @@ class TestWorkflowAgent:
assert request_function_call.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME
assert isinstance(request_function_call.arguments, dict)
assert request_function_call.arguments.get("request_id") is not None
assert request_function_call.arguments.get("data") is not None
request_data = request_function_call.arguments["data"]
assert request_data.get("type") == "request_info"
assert deserialize_type(request_data.get("response_type")) is str
assert request_function_call.arguments.get("request_event") is not None
request_event = request_function_call.arguments["request_event"]
assert request_event.get("type") == "request_info"
assert deserialize_type(request_event.get("response_type")) is str
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments)
assert deserialized_args.request_id == request_function_call.call_id
assert isinstance(deserialized_args.request_event, WorkflowEvent)
assert deserialized_args.request_event.type == "request_info"
assert deserialized_args.request_event.data == "Mock request data"
assert deserialized_args.request_event.response_type is str
# Verify the request is tracked in pending_requests
pending_requests = await workflow._runner_context.get_pending_request_info_events()
@@ -285,12 +305,6 @@ class TestWorkflowAgent:
def test_request_info_dataclass_arguments_are_serialized_when_content_is_created(self) -> None:
"""Test WorkflowAgent prepares request_info arguments before observability captures messages."""
@dataclass
class HandoffRequest:
target_agent: str
reason: str
executor = SimpleExecutor(id="executor1", response_text="Response")
workflow = WorkflowBuilder(start_executor=executor).build()
agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
@@ -305,13 +319,20 @@ class TestWorkflowAgent:
assert request_function_call.call_id == "request_123"
assert isinstance(request_function_call.arguments, dict)
assert request_function_call.arguments.get("data") is not None
data = request_function_call.arguments["data"]
assert data.get("type") == "request_info"
assert data.get("request_id") == "request_123"
assert data.get("source_executor_id") == "executor1"
assert deserialize_type(data.get("response_type")) is str
assert data.get("data") == {"target_agent": "helper", "reason": "overflow"}
assert request_function_call.arguments.get("request_event") is not None
request_event = request_function_call.arguments["request_event"]
assert request_event.get("type") == "request_info"
assert request_event.get("request_id") == "request_123"
assert request_event.get("source_executor_id") == "executor1"
assert deserialize_type(request_event.get("response_type")) is str
assert request_event.get("data") == HandoffRequest(target_agent="helper", reason="overflow")
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments)
assert deserialized_args.request_id == "request_123"
assert isinstance(deserialized_args.request_event, WorkflowEvent)
assert deserialized_args.request_event.type == "request_info"
assert deserialized_args.request_event.data == HandoffRequest(target_agent="helper", reason="overflow")
assert deserialized_args.request_event.response_type is str
def test_process_request_info_event_passes_through_function_approval_request(self) -> None:
"""If the event data is already a function approval request, it is forwarded unchanged.
@@ -572,12 +593,6 @@ class TestWorkflowAgent:
agent surfaces a synthesized ``function_call`` (name=REQUEST_INFO_FUNCTION_NAME)
and routes a matching ``function_result`` back to the executor.
"""
@dataclass
class HandoffRequest:
target_agent: str
reason: str
captured: dict[str, Any] = {}
class HandoffRequestingExecutor(Executor):
@@ -632,9 +647,16 @@ class TestWorkflowAgent:
assert isinstance(function_call.arguments, dict)
request_id = function_call.arguments["request_id"]
assert function_call.call_id == request_id
request_payload = function_call.arguments["data"]
request_payload = function_call.arguments["request_event"]
assert request_payload.get("type") == "request_info"
assert request_payload.get("data") == {"target_agent": "helper", "reason": "overflow"}
assert request_payload.get("data") == HandoffRequest(target_agent="helper", reason="overflow")
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(function_call.arguments)
assert deserialized_args.request_id == request_id
assert isinstance(deserialized_args.request_event, WorkflowEvent)
assert deserialized_args.request_event.type == "request_info"
assert deserialized_args.request_event.data == HandoffRequest(target_agent="helper", reason="overflow")
assert deserialized_args.request_event.response_type is str
pending = await workflow._runner_context.get_pending_request_info_events()
assert request_id in pending
@@ -134,15 +134,11 @@ def handle_response_and_requests(response: AgentResponse) -> dict[str, HandoffAg
if message.text:
print(f"- {message.author_name or message.role}: {message.text}")
for content in message.contents:
if content.type == "function_call":
if isinstance(content.arguments, dict):
request = WorkflowAgent.RequestInfoFunctionArgs.from_dict(content.arguments)
elif isinstance(content.arguments, str):
request = WorkflowAgent.RequestInfoFunctionArgs.from_json(content.arguments)
else:
raise ValueError("Invalid arguments type. Expecting a request info structure for this sample.")
if isinstance(request.data, HandoffAgentUserRequest):
pending_requests[request.request_id] = request.data
if content.type == "function_call" and content.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME:
request_function_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(content.arguments) # type: ignore
request_id = request_function_args.request_id
request_event = request_function_args.request_event
pending_requests[request_id] = request_event.data
return pending_requests
@@ -3,10 +3,8 @@
import asyncio
import os
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from agent_framework.foundry import FoundryChatClient
from azure.identity import AzureCliCredential
@@ -141,28 +139,14 @@ async def main() -> None:
# Handle the human review if required.
if human_review_function_call:
# Parse the human review request arguments.
human_request_args = human_review_function_call.arguments
if isinstance(human_request_args, str):
request: WorkflowAgent.RequestInfoFunctionArgs = WorkflowAgent.RequestInfoFunctionArgs.from_json(
human_request_args
)
elif isinstance(human_request_args, Mapping):
request = WorkflowAgent.RequestInfoFunctionArgs.from_dict(dict(human_request_args))
else:
raise TypeError("Unexpected argument type for human review function call.")
request_payload: Any = request.data
human_request_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(human_review_function_call.arguments) # type: ignore
request_payload = human_request_args.request_event.data
if not isinstance(request_payload, HumanReviewRequest):
raise ValueError("Human review request payload must be a HumanReviewRequest.")
agent_request = request_payload.agent_request
if agent_request is None:
raise ValueError("Human review request must include agent_request.")
request_id = agent_request.request_id
if not request_payload.agent_request:
raise ValueError("Human review request must contain an agent_request.")
# Mock a human response approval for demonstration purposes.
human_response = ReviewResponse(request_id=request_id, feedback="", approved=True)
human_response = ReviewResponse(request_id=request_payload.agent_request.request_id, feedback="", approved=True)
# Create the function call result object to send back to the agent.
human_review_function_result = Content(
"function_result",