Fix failing tests

This commit is contained in:
Tao Chen
2026-06-04 11:26:53 -07:00
Unverified
parent c229a873d4
commit 2f385003fc
4 changed files with 81 additions and 57 deletions
@@ -6,11 +6,11 @@ import logging
import sys
import uuid
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
from .._agents import BaseAgent
from .._serialization import make_json_safe
from .._sessions import (
AgentSession,
ContextProvider,
@@ -55,6 +55,28 @@ class WorkflowAgent(BaseAgent):
# Class variable for the request info function name
REQUEST_INFO_FUNCTION_NAME: ClassVar[str] = "request_info"
@dataclass
class RequestInfoFunctionArgs:
request_id: str
request_event: WorkflowEvent
def to_dict(self) -> dict[str, Any]:
return {"request_id": self.request_id, "request_event": self.request_event.to_dict()}
@classmethod
def from_dict(cls, payload: dict[str, Any]) -> WorkflowAgent.RequestInfoFunctionArgs:
if "request_id" not in payload or "request_event" not in payload:
raise ValueError(
"Invalid payload for RequestInfoFunctionArgs. 'request_id' and 'request_event' are required."
)
if not payload["request_id"]:
raise ValueError("request_id cannot be empty.")
return cls(
request_id=payload.get("request_id", ""),
request_event=WorkflowEvent.from_dict(payload.get("request_event", {})),
)
def __init__(
self,
workflow: Workflow,
@@ -683,7 +705,7 @@ class WorkflowAgent(BaseAgent):
return event.data
request_id = event.request_id
args = {"request_id": request_id, "data": make_json_safe(event)}
args = self.RequestInfoFunctionArgs(request_id=request_id, request_event=event).to_dict()
return Content.from_function_call(
call_id=request_id,
@@ -32,6 +32,19 @@ from agent_framework import (
from agent_framework._workflows._typing_utils import deserialize_type
@dataclass
class HandoffRequest:
"""Module-level dataclass used by request_info tests.
Defined at module scope (not nested inside a test method) so
``serialize_type``/``deserialize_type`` can round-trip the request_type via
the importable qualified name ``tests.workflow.test_workflow_agent.HandoffRequest``.
"""
target_agent: str
reason: str
class SimpleExecutor(Executor):
"""Simple executor that emits a response based on input."""
@@ -255,10 +268,17 @@ class TestWorkflowAgent:
assert request_function_call.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME
assert isinstance(request_function_call.arguments, dict)
assert request_function_call.arguments.get("request_id") is not None
assert request_function_call.arguments.get("data") is not None
request_data = request_function_call.arguments["data"]
assert request_data.get("type") == "request_info"
assert deserialize_type(request_data.get("response_type")) is str
assert request_function_call.arguments.get("request_event") is not None
request_event = request_function_call.arguments["request_event"]
assert request_event.get("type") == "request_info"
assert deserialize_type(request_event.get("response_type")) is str
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments)
assert deserialized_args.request_id == request_function_call.call_id
assert isinstance(deserialized_args.request_event, WorkflowEvent)
assert deserialized_args.request_event.type == "request_info"
assert deserialized_args.request_event.data == "Mock request data"
assert deserialized_args.request_event.response_type is str
# Verify the request is tracked in pending_requests
pending_requests = await workflow._runner_context.get_pending_request_info_events()
@@ -285,12 +305,6 @@ class TestWorkflowAgent:
def test_request_info_dataclass_arguments_are_serialized_when_content_is_created(self) -> None:
"""Test WorkflowAgent prepares request_info arguments before observability captures messages."""
@dataclass
class HandoffRequest:
target_agent: str
reason: str
executor = SimpleExecutor(id="executor1", response_text="Response")
workflow = WorkflowBuilder(start_executor=executor).build()
agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
@@ -305,13 +319,20 @@ class TestWorkflowAgent:
assert request_function_call.call_id == "request_123"
assert isinstance(request_function_call.arguments, dict)
assert request_function_call.arguments.get("data") is not None
data = request_function_call.arguments["data"]
assert data.get("type") == "request_info"
assert data.get("request_id") == "request_123"
assert data.get("source_executor_id") == "executor1"
assert deserialize_type(data.get("response_type")) is str
assert data.get("data") == {"target_agent": "helper", "reason": "overflow"}
assert request_function_call.arguments.get("request_event") is not None
request_event = request_function_call.arguments["request_event"]
assert request_event.get("type") == "request_info"
assert request_event.get("request_id") == "request_123"
assert request_event.get("source_executor_id") == "executor1"
assert deserialize_type(request_event.get("response_type")) is str
assert request_event.get("data") == HandoffRequest(target_agent="helper", reason="overflow")
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments)
assert deserialized_args.request_id == "request_123"
assert isinstance(deserialized_args.request_event, WorkflowEvent)
assert deserialized_args.request_event.type == "request_info"
assert deserialized_args.request_event.data == HandoffRequest(target_agent="helper", reason="overflow")
assert deserialized_args.request_event.response_type is str
def test_process_request_info_event_passes_through_function_approval_request(self) -> None:
"""If the event data is already a function approval request, it is forwarded unchanged.
@@ -572,12 +593,6 @@ class TestWorkflowAgent:
agent surfaces a synthesized ``function_call`` (name=REQUEST_INFO_FUNCTION_NAME)
and routes a matching ``function_result`` back to the executor.
"""
@dataclass
class HandoffRequest:
target_agent: str
reason: str
captured: dict[str, Any] = {}
class HandoffRequestingExecutor(Executor):
@@ -632,9 +647,16 @@ class TestWorkflowAgent:
assert isinstance(function_call.arguments, dict)
request_id = function_call.arguments["request_id"]
assert function_call.call_id == request_id
request_payload = function_call.arguments["data"]
request_payload = function_call.arguments["request_event"]
assert request_payload.get("type") == "request_info"
assert request_payload.get("data") == {"target_agent": "helper", "reason": "overflow"}
assert request_payload.get("data") == HandoffRequest(target_agent="helper", reason="overflow")
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(function_call.arguments)
assert deserialized_args.request_id == request_id
assert isinstance(deserialized_args.request_event, WorkflowEvent)
assert deserialized_args.request_event.type == "request_info"
assert deserialized_args.request_event.data == HandoffRequest(target_agent="helper", reason="overflow")
assert deserialized_args.request_event.response_type is str
pending = await workflow._runner_context.get_pending_request_info_events()
assert request_id in pending