mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Fix failing tests
This commit is contained in:
@@ -6,11 +6,11 @@ import logging
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
|
||||
|
||||
from .._agents import BaseAgent
|
||||
from .._serialization import make_json_safe
|
||||
from .._sessions import (
|
||||
AgentSession,
|
||||
ContextProvider,
|
||||
@@ -55,6 +55,28 @@ class WorkflowAgent(BaseAgent):
|
||||
# Class variable for the request info function name
|
||||
REQUEST_INFO_FUNCTION_NAME: ClassVar[str] = "request_info"
|
||||
|
||||
@dataclass
|
||||
class RequestInfoFunctionArgs:
|
||||
request_id: str
|
||||
request_event: WorkflowEvent
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {"request_id": self.request_id, "request_event": self.request_event.to_dict()}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: dict[str, Any]) -> WorkflowAgent.RequestInfoFunctionArgs:
|
||||
if "request_id" not in payload or "request_event" not in payload:
|
||||
raise ValueError(
|
||||
"Invalid payload for RequestInfoFunctionArgs. 'request_id' and 'request_event' are required."
|
||||
)
|
||||
if not payload["request_id"]:
|
||||
raise ValueError("request_id cannot be empty.")
|
||||
|
||||
return cls(
|
||||
request_id=payload.get("request_id", ""),
|
||||
request_event=WorkflowEvent.from_dict(payload.get("request_event", {})),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
@@ -683,7 +705,7 @@ class WorkflowAgent(BaseAgent):
|
||||
return event.data
|
||||
|
||||
request_id = event.request_id
|
||||
args = {"request_id": request_id, "data": make_json_safe(event)}
|
||||
args = self.RequestInfoFunctionArgs(request_id=request_id, request_event=event).to_dict()
|
||||
|
||||
return Content.from_function_call(
|
||||
call_id=request_id,
|
||||
|
||||
@@ -32,6 +32,19 @@ from agent_framework import (
|
||||
from agent_framework._workflows._typing_utils import deserialize_type
|
||||
|
||||
|
||||
@dataclass
|
||||
class HandoffRequest:
|
||||
"""Module-level dataclass used by request_info tests.
|
||||
|
||||
Defined at module scope (not nested inside a test method) so
|
||||
``serialize_type``/``deserialize_type`` can round-trip the request_type via
|
||||
the importable qualified name ``tests.workflow.test_workflow_agent.HandoffRequest``.
|
||||
"""
|
||||
|
||||
target_agent: str
|
||||
reason: str
|
||||
|
||||
|
||||
class SimpleExecutor(Executor):
|
||||
"""Simple executor that emits a response based on input."""
|
||||
|
||||
@@ -255,10 +268,17 @@ class TestWorkflowAgent:
|
||||
assert request_function_call.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME
|
||||
assert isinstance(request_function_call.arguments, dict)
|
||||
assert request_function_call.arguments.get("request_id") is not None
|
||||
assert request_function_call.arguments.get("data") is not None
|
||||
request_data = request_function_call.arguments["data"]
|
||||
assert request_data.get("type") == "request_info"
|
||||
assert deserialize_type(request_data.get("response_type")) is str
|
||||
assert request_function_call.arguments.get("request_event") is not None
|
||||
request_event = request_function_call.arguments["request_event"]
|
||||
assert request_event.get("type") == "request_info"
|
||||
assert deserialize_type(request_event.get("response_type")) is str
|
||||
|
||||
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments)
|
||||
assert deserialized_args.request_id == request_function_call.call_id
|
||||
assert isinstance(deserialized_args.request_event, WorkflowEvent)
|
||||
assert deserialized_args.request_event.type == "request_info"
|
||||
assert deserialized_args.request_event.data == "Mock request data"
|
||||
assert deserialized_args.request_event.response_type is str
|
||||
|
||||
# Verify the request is tracked in pending_requests
|
||||
pending_requests = await workflow._runner_context.get_pending_request_info_events()
|
||||
@@ -285,12 +305,6 @@ class TestWorkflowAgent:
|
||||
|
||||
def test_request_info_dataclass_arguments_are_serialized_when_content_is_created(self) -> None:
|
||||
"""Test WorkflowAgent prepares request_info arguments before observability captures messages."""
|
||||
|
||||
@dataclass
|
||||
class HandoffRequest:
|
||||
target_agent: str
|
||||
reason: str
|
||||
|
||||
executor = SimpleExecutor(id="executor1", response_text="Response")
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
|
||||
@@ -305,13 +319,20 @@ class TestWorkflowAgent:
|
||||
|
||||
assert request_function_call.call_id == "request_123"
|
||||
assert isinstance(request_function_call.arguments, dict)
|
||||
assert request_function_call.arguments.get("data") is not None
|
||||
data = request_function_call.arguments["data"]
|
||||
assert data.get("type") == "request_info"
|
||||
assert data.get("request_id") == "request_123"
|
||||
assert data.get("source_executor_id") == "executor1"
|
||||
assert deserialize_type(data.get("response_type")) is str
|
||||
assert data.get("data") == {"target_agent": "helper", "reason": "overflow"}
|
||||
assert request_function_call.arguments.get("request_event") is not None
|
||||
request_event = request_function_call.arguments["request_event"]
|
||||
assert request_event.get("type") == "request_info"
|
||||
assert request_event.get("request_id") == "request_123"
|
||||
assert request_event.get("source_executor_id") == "executor1"
|
||||
assert deserialize_type(request_event.get("response_type")) is str
|
||||
assert request_event.get("data") == HandoffRequest(target_agent="helper", reason="overflow")
|
||||
|
||||
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments)
|
||||
assert deserialized_args.request_id == "request_123"
|
||||
assert isinstance(deserialized_args.request_event, WorkflowEvent)
|
||||
assert deserialized_args.request_event.type == "request_info"
|
||||
assert deserialized_args.request_event.data == HandoffRequest(target_agent="helper", reason="overflow")
|
||||
assert deserialized_args.request_event.response_type is str
|
||||
|
||||
def test_process_request_info_event_passes_through_function_approval_request(self) -> None:
|
||||
"""If the event data is already a function approval request, it is forwarded unchanged.
|
||||
@@ -572,12 +593,6 @@ class TestWorkflowAgent:
|
||||
agent surfaces a synthesized ``function_call`` (name=REQUEST_INFO_FUNCTION_NAME)
|
||||
and routes a matching ``function_result`` back to the executor.
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class HandoffRequest:
|
||||
target_agent: str
|
||||
reason: str
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
class HandoffRequestingExecutor(Executor):
|
||||
@@ -632,9 +647,16 @@ class TestWorkflowAgent:
|
||||
assert isinstance(function_call.arguments, dict)
|
||||
request_id = function_call.arguments["request_id"]
|
||||
assert function_call.call_id == request_id
|
||||
request_payload = function_call.arguments["data"]
|
||||
request_payload = function_call.arguments["request_event"]
|
||||
assert request_payload.get("type") == "request_info"
|
||||
assert request_payload.get("data") == {"target_agent": "helper", "reason": "overflow"}
|
||||
assert request_payload.get("data") == HandoffRequest(target_agent="helper", reason="overflow")
|
||||
|
||||
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(function_call.arguments)
|
||||
assert deserialized_args.request_id == request_id
|
||||
assert isinstance(deserialized_args.request_event, WorkflowEvent)
|
||||
assert deserialized_args.request_event.type == "request_info"
|
||||
assert deserialized_args.request_event.data == HandoffRequest(target_agent="helper", reason="overflow")
|
||||
assert deserialized_args.request_event.response_type is str
|
||||
|
||||
pending = await workflow._runner_context.get_pending_request_info_events()
|
||||
assert request_id in pending
|
||||
|
||||
Reference in New Issue
Block a user