Python: Refactor workflow as agent pending request handling (#6259)

* WIP: Refactor Workflow as agent pending request handling

* WIP: debugging empty message bug

* Working: Workflow as agent with function approval

* Address Copilot comments

* Fix mypy

* Address comments and fix pipeline

* Request info non function approval now becomes function call

* Revert uv.lock

* Fix mypy

* Bump min version of azure-ai-project

* Remove RequestInfoFunctionArgs

* fix tests

* Fix failing tests

* Fix sample
This commit is contained in:
Tao Chen
2026-06-05 10:23:19 -07:00
committed by GitHub
Unverified
parent d5335fbeae
commit 9cafd7e58b
13 changed files with 1841 additions and 296 deletions
@@ -2,7 +2,6 @@
from __future__ import annotations
import json
import logging
import sys
import uuid
@@ -12,7 +11,6 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
from .._agents import BaseAgent
from .._serialization import make_json_safe
from .._sessions import (
AgentSession,
ContextProvider,
@@ -30,11 +28,12 @@ from .._types import (
UsageDetails,
add_usage_details,
)
from ..exceptions import AgentInvalidRequestException, AgentInvalidResponseException
from ..exceptions import AgentException, AgentInvalidRequestException, AgentInvalidResponseException
from ._checkpoint import CheckpointStorage
from ._events import (
AGENT_FORWARDED_EVENT_TYPES,
WorkflowEvent,
WorkflowRunState,
)
from ._message_utils import normalize_messages_input
from ._typing_utils import is_instance_of, is_type_compatible
@@ -59,27 +58,24 @@ class WorkflowAgent(BaseAgent):
@dataclass
class RequestInfoFunctionArgs:
request_id: str
data: Any
request_event: WorkflowEvent
def to_dict(self) -> dict[str, Any]:
return {"request_id": self.request_id, "data": make_json_safe(self.data)}
def to_json(self) -> str:
return json.dumps(self.to_dict())
return {"request_id": self.request_id, "request_event": self.request_event.to_dict()}
@classmethod
def from_dict(cls, payload: dict[str, Any]) -> WorkflowAgent.RequestInfoFunctionArgs:
return cls(request_id=payload.get("request_id", ""), data=payload.get("data"))
if "request_id" not in payload or "request_event" not in payload:
raise ValueError(
"Invalid payload for RequestInfoFunctionArgs. 'request_id' and 'request_event' are required."
)
if not payload["request_id"]:
raise ValueError("request_id cannot be empty.")
@classmethod
def from_json(cls, raw: str) -> WorkflowAgent.RequestInfoFunctionArgs:
try:
parsed: Any = json.loads(raw)
except json.JSONDecodeError as exc:
raise ValueError(f"RequestInfoFunctionArgs JSON payload is malformed: {exc}") from exc
if not isinstance(parsed, dict):
raise ValueError("RequestInfoFunctionArgs JSON payload must decode to a mapping")
return cls.from_dict(cast(dict[str, Any], parsed))
return cls(
request_id=payload.get("request_id", ""),
request_event=WorkflowEvent.from_dict(payload.get("request_event", {})),
)
def __init__(
self,
@@ -129,16 +125,11 @@ class WorkflowAgent(BaseAgent):
**kwargs,
)
self._workflow: Workflow = workflow
self._pending_requests: dict[str, WorkflowEvent[Any]] = {}
@property
def workflow(self) -> Workflow:
return self._workflow
@property
def pending_requests(self) -> dict[str, WorkflowEvent[Any]]:
return self._pending_requests
# region Run Methods
@overload
@@ -182,7 +173,7 @@ class WorkflowAgent(BaseAgent):
Args:
messages: The message(s) to send to the workflow. Required for new runs,
should be None when resuming from checkpoint.
could be None if only restoring the underlying workflow from a checkpoint.
Keyword Args:
stream: If True, returns an async iterable of updates. If False (default),
@@ -416,101 +407,79 @@ class WorkflowAgent(BaseAgent):
Yields:
WorkflowEvent objects from the workflow execution.
"""
# Determine the execution mode based on state.
# The streaming flag controls the workflow's internal streaming mode,
# which affects executor behavior (e.g. AgentExecutor emits different event
# types in streaming vs non-streaming mode).
if bool(self.pending_requests):
function_responses = self._process_pending_requests(input_messages)
# Restore the workflow state if a checkpoint is provided
if checkpoint_id is not None:
if checkpoint_storage is None:
raise AgentInvalidRequestException("checkpoint_storage must be provided when checkpoint_id is provided")
logger.debug(f"Restoring workflow from checkpoint {checkpoint_id}")
# Restore the workflow from checkpoint
if streaming:
async for event in self.workflow.run(
responses=function_responses,
stream=True,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
for event in await self.workflow.run(
responses=function_responses,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
elif checkpoint_id is not None:
# Restore the prior workflow state from the checkpoint. Shared
# state (e.g. accumulated conversation history maintained by the
# workflow's executors) survives across turns because Workflow.run
# no longer wipes state per call. Callers who want to deliver a
# new user message after restore should make a second
# `workflow.run(message=...)` call - they are NOT mutually
# exclusive on the same instance, but each must be its own call.
if streaming:
async for event in self.workflow.run(
async for _ in self.workflow.run(
stream=True,
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
):
pass
else:
_ = await self.workflow.run(
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
)
if not input_messages:
logger.info("No input messages provided; the workflow has been restored to the checkpoint state.")
return
final_state = self._workflow.status
logger.debug(f"Workflow state: {final_state}")
if final_state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS:
# Extract function responses from input messages, and ensure that
# only function responses are present in messages if there is any
# pending request.
# NOTE: It is possible that some pending requests are not fulfilled,
# and we will let the workflow to handle this -- the agent does not
# have an opinion on this.
function_responses = self._extract_function_responses(input_messages)
if streaming:
async for event in self.workflow.run(
responses=function_responses,
stream=True,
checkpoint_storage=checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
for event in await self.workflow.run(
checkpoint_id=checkpoint_id,
responses=function_responses,
checkpoint_storage=checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
elif final_state == WorkflowRunState.IDLE:
if streaming:
async for event in self.workflow.run(
message=input_messages,
stream=True,
checkpoint_storage=checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
for event in await self.workflow.run(
message=input_messages,
checkpoint_storage=checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
if streaming:
async for event in self.workflow.run(
message=input_messages,
stream=True,
checkpoint_storage=checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
for event in await self.workflow.run(
message=input_messages,
checkpoint_storage=checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
raise AgentException(f"The underlying workflow is in an invalid state to restart: {final_state}.")
# endregion Run Methods
def _process_pending_requests(self, input_messages: Sequence[Message]) -> dict[str, Any]:
"""Process pending requests by extracting function responses and updating state.
Args:
input_messages: Input messages that may contain function responses.
Returns:
A dictionary mapping request IDs to their response data.
"""
logger.info(f"Continuing workflow to address {len(self.pending_requests)} requests")
# Extract function responses from input messages, and ensure that
# only function responses are present in messages if there is any
# pending request.
function_responses = self._extract_function_responses(input_messages)
# Pop pending requests if fulfilled.
for request_id in list(self.pending_requests.keys()):
if request_id in function_responses:
self.pending_requests.pop(request_id)
# NOTE: It is possible that some pending requests are not fulfilled,
# and we will let the workflow to handle this -- the agent does not
# have an opinion on this.
return function_responses
def _convert_workflow_events_to_agent_response(
self,
response_id: str,
@@ -528,10 +497,10 @@ class WorkflowAgent(BaseAgent):
for output_event in output_events:
if output_event.type == "request_info":
function_call, approval_request = self._process_request_info_event(output_event)
request_content = self._process_request_info_event(output_event)
messages.append(
Message(
contents=[function_call, approval_request],
contents=[request_content],
role="assistant",
author_name=output_event.source_executor_id,
message_id=str(uuid.uuid4()),
@@ -598,38 +567,6 @@ class WorkflowAgent(BaseAgent):
raw_representation=raw_representations,
)
def _process_request_info_event(
self,
event: WorkflowEvent[Any],
) -> tuple[Content, Content]:
"""Convert a request_info event to FunctionCallContent and FunctionApprovalRequestContent.
Args:
event: A WorkflowEvent with type='request_info'.
Returns:
A tuple of (FunctionCallContent, FunctionApprovalRequestContent).
"""
request_id = event.request_id
if not request_id:
raise ValueError("request_info event must have a request_id")
self.pending_requests[request_id] = event
args = self.RequestInfoFunctionArgs(request_id=request_id, data=event.data).to_dict()
function_call = Content.from_function_call(
call_id=request_id,
name=self.REQUEST_INFO_FUNCTION_NAME,
arguments=args,
)
approval_request = Content.from_function_approval_request(
id=request_id,
function_call=function_call,
additional_properties={"request_id": request_id},
)
return function_call, approval_request
def _convert_workflow_event_to_agent_response_updates(
self,
response_id: str,
@@ -731,85 +668,72 @@ class WorkflowAgent(BaseAgent):
]
if event.type == "request_info":
# Store the pending request for later correlation
request_id = event.request_id
if not request_id:
raise ValueError("request_info event must have a request_id")
self.pending_requests[request_id] = event
args = self.RequestInfoFunctionArgs(request_id=request_id, data=event.data).to_dict()
function_call = Content.from_function_call(
call_id=request_id,
name=self.REQUEST_INFO_FUNCTION_NAME,
arguments=args,
)
approval_request = Content.from_function_approval_request(
id=request_id,
function_call=function_call,
additional_properties={"request_id": request_id},
)
request_content = self._process_request_info_event(event)
return [
AgentResponseUpdate(
contents=[function_call, approval_request],
contents=[request_content],
role="assistant",
author_name=self.name,
response_id=response_id,
message_id=str(uuid.uuid4()),
created_at=datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
raw_representation=event,
)
]
# Ignore workflow-internal events
return []
def _process_request_info_event(
self,
event: WorkflowEvent[Any],
) -> Content:
"""Convert a request_info event to FunctionApprovalRequestContent.
Args:
event: A WorkflowEvent with type='request_info'.
Returns:
A content object representing the request info. The content can be a `function_approval_request`
or a `function_call` depending on the structure of the event data.
Note:
If the event data is already a FunctionApprovalRequestContent, it will be returned as-is.
"""
if isinstance(event.data, Content) and event.data.user_input_request:
# Return the event data as-is if it's already a properly formed FunctionApprovalRequestContent
return event.data
request_id = event.request_id
args = self.RequestInfoFunctionArgs(request_id=request_id, request_event=event).to_dict()
return Content.from_function_call(
call_id=request_id,
name=self.REQUEST_INFO_FUNCTION_NAME,
arguments=args,
)
def _extract_function_responses(self, input_messages: Sequence[Message]) -> dict[str, Any]:
"""Extract function responses from input messages."""
"""Extract function responses from input messages.
The responses are for pending requests that the workflow is waiting on, and
will be passed to the workflow. The pending requests are processed to either
`function_approval_request` or `function_call` content by `_process_request_info_event`.
"""
function_responses: dict[str, Any] = {}
for message in input_messages:
for content in message.contents:
if content.type == "function_approval_response":
# Parse the function arguments to recover request payload
arguments_payload = content.function_call.arguments # type: ignore[attr-defined, union-attr]
if isinstance(arguments_payload, str):
try:
parsed_args = self.RequestInfoFunctionArgs.from_json(arguments_payload)
except ValueError as exc:
raise AgentInvalidResponseException(
"FunctionApprovalResponseContent arguments must decode to a mapping."
) from exc
elif isinstance(arguments_payload, dict):
parsed_args = self.RequestInfoFunctionArgs.from_dict(arguments_payload)
else:
raise AgentInvalidResponseException(
"FunctionApprovalResponseContent arguments must be a mapping or JSON string."
)
request_id = parsed_args.request_id or content.id # type: ignore[attr-defined]
if not content.approved: # type: ignore[attr-defined]
raise AgentInvalidResponseException(f"Request '{request_id}' was not approved by the caller.")
if request_id in self.pending_requests:
function_responses[request_id] = parsed_args.data
elif bool(self.pending_requests):
raise AgentInvalidRequestException(
"Only responses for pending requests are allowed when there are outstanding approvals."
)
request_id: str = content.id # type: ignore[assignment]
function_responses[request_id] = content
elif content.type == "function_result":
request_id = content.call_id # type: ignore[attr-defined]
if request_id in self.pending_requests:
response_data = content.result if hasattr(content, "result") else str(content) # type: ignore[attr-defined]
function_responses[request_id] = response_data
elif bool(self.pending_requests):
raise AgentInvalidRequestException(
"Only function responses for pending requests are allowed while requests are outstanding."
)
response_data = content.result if hasattr(content, "result") else str(content) # type: ignore[attr-defined]
function_responses[content.call_id] = response_data # type: ignore
else:
if bool(self.pending_requests):
raise AgentInvalidResponseException(
"Unexpected content type while awaiting request info responses."
)
raise AgentInvalidResponseException(
"Unexpected content type while awaiting request info responses."
)
return function_responses
def _extract_contents(self, data: Any) -> list[Content]:
@@ -429,15 +429,30 @@ class AgentExecutor(Executor):
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
await ctx.yield_output(response)
# Handle any user input requests
if response.user_input_requests:
user_input_request_count = len(response.user_input_requests)
total_message_content_count = sum(len(msg.contents) for msg in response.messages)
if user_input_request_count != total_message_content_count:
logger.warning(
"Response %s contains %d user input requests but total message contents are %d. "
"This indicates the response contains both user input requests and message contents. "
"Double check if this is the intended behavior, as non user input request contents in "
"this response will not be emitted.",
response.response_id,
user_input_request_count,
total_message_content_count,
)
for user_input_request in response.user_input_requests:
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
await ctx.request_info(user_input_request, Content)
await ctx.request_info(user_input_request, Content, request_id=user_input_request.id)
return None
# Only yield output if the response is complete and not waiting for user input.
# This is to avoid emitting two events of different types ('output' and 'request_info')
# that carry the same payload.
await ctx.yield_output(response)
return response
async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUpdate]) -> AgentResponse | None:
@@ -472,9 +487,25 @@ class AgentExecutor(Executor):
)
async for update in stream:
updates.append(update)
await ctx.yield_output(update)
if update.user_input_requests:
user_input_request_count = len(update.user_input_requests)
total_message_content_count = len(update.contents)
if user_input_request_count != total_message_content_count:
logger.warning(
"Response update %s contains %d user input requests but total message contents are %d. "
"This indicates the response update contains both user input requests and message contents. "
"Double check if this is the intended behavior, as non user input request contents will "
"not be emitted.",
update.response_id,
user_input_request_count,
total_message_content_count,
)
streamed_user_input_requests.extend(update.user_input_requests)
else:
# Only yield output events for updates that do not contain user input requests.
# This is to avoid emitting two events of different types ('output' and 'request_info')
# that carry the same payload.
await ctx.yield_output(update)
# Prefer stream finalization when available so result hooks run
# (e.g., thread conversation updates). Fall back to reconstructing from updates
@@ -509,7 +540,7 @@ class AgentExecutor(Executor):
if user_input_requests:
for user_input_request in user_input_requests:
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
await ctx.request_info(user_input_request, Content)
await ctx.request_info(user_input_request, Content, request_id=user_input_request.id)
return None
return response
@@ -360,6 +360,22 @@ class Workflow(DictConvertible):
# Flag to prevent concurrent workflow executions
self._is_running = False
# Current run-level status of this workflow instance. Updated in lockstep with
# the status events emitted from `_run_workflow_with_tracing`. Defaults to IDLE
# for a freshly built workflow that has not yet been run.
self._status: WorkflowRunState = WorkflowRunState.IDLE
@property
def status(self) -> WorkflowRunState:
"""Return the current run-level status of this workflow instance.
Mirrors the most recent status event emitted by the workflow. Safe to read at
any time: workflows run on a single asyncio event loop, and the underlying
attribute is a single enum reference whose assignment is atomic under the
CPython GIL, so no locking is required.
"""
return self._status
def _ensure_not_running(self) -> None:
"""Ensure the workflow is not already running."""
if self._is_running:
@@ -513,8 +529,9 @@ class Workflow(DictConvertible):
with _framework_event_origin():
started = WorkflowEvent.started()
yield started # noqa: RUF070
self._status = WorkflowRunState.IN_PROGRESS
with _framework_event_origin():
in_progress = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS)
in_progress = WorkflowEvent.status(self._status)
yield in_progress # noqa: RUF070
# Per-run reset for fresh-message runs only. We deliberately
@@ -569,17 +586,20 @@ class Workflow(DictConvertible):
if event.type == "request_info" and not emitted_in_progress_pending:
emitted_in_progress_pending = True
self._status = WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS
with _framework_event_origin():
pending_status = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS)
pending_status = WorkflowEvent.status(self._status)
yield pending_status # noqa: RUF070
# Workflow runs until idle - emit final status based on whether requests are pending
if saw_request:
self._status = WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
with _framework_event_origin():
terminal_status = WorkflowEvent.status(WorkflowRunState.IDLE_WITH_PENDING_REQUESTS)
terminal_status = WorkflowEvent.status(self._status)
yield terminal_status
else:
self._status = WorkflowRunState.IDLE
with _framework_event_origin():
terminal_status = WorkflowEvent.status(WorkflowRunState.IDLE)
terminal_status = WorkflowEvent.status(self._status)
yield terminal_status
span.add_event(OtelAttr.WORKFLOW_COMPLETED)
@@ -593,6 +613,7 @@ class Workflow(DictConvertible):
with _framework_event_origin():
failed_event = WorkflowEvent.failed(details)
yield failed_event # noqa: RUF070
self._status = WorkflowRunState.FAILED
with _framework_event_origin():
failed_status = WorkflowEvent.status(WorkflowRunState.FAILED)
yield failed_status # noqa: RUF070
@@ -25,6 +25,7 @@ from agent_framework import (
prepend_agent_framework_to_user_agent,
tool,
)
from agent_framework._serialization import make_json_safe
from agent_framework.observability import (
ROLE_EVENT_MAP,
AgentTelemetryLayer,
@@ -3195,17 +3196,15 @@ def test_capture_messages_with_prepared_request_info_function_call_arguments(spa
from opentelemetry import trace
from agent_framework import WorkflowAgent
@dataclasses.dataclass
class HandoffRequest:
target_agent: str
reason: str
arguments = WorkflowAgent.RequestInfoFunctionArgs(
request_id="call_dc",
data=HandoffRequest(target_agent="helper", reason="overflow"),
).to_dict()
arguments = {
"request_id": "call_dc",
"data": make_json_safe(HandoffRequest(target_agent="helper", reason="overflow")),
}
msg = Message(
role="assistant",
contents=[
@@ -699,3 +699,171 @@ async def test_resolve_executor_kwargs_empty_per_executor_does_not_fallback_to_g
resolved = {"exec_a": {}, GLOBAL_KWARGS_KEY: {"global_key": "global_val"}}
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
assert result == {}
# region Tool approval emission
class _ApprovalEmittingAgent(BaseAgent):
"""Agent that returns a single ``function_approval_request`` Content.
Used to verify that ``AgentExecutor`` does *not* surface the approval
payload via both an ``output`` event and a ``request_info`` event in the
same superstep — only the ``request_info`` event must carry it.
"""
def __init__(
self,
*,
approval_request_id: str = "apr_1",
tool_name: str = "delete_file",
tool_arguments: dict[str, Any] | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
self._approval_request_id = approval_request_id
self._tool_name = tool_name
self._tool_arguments: dict[str, Any] = tool_arguments or {"path": "/tmp/secret.txt"}
self.run_count = 0
def _build_approval_content(self) -> Content:
function_call = Content.from_function_call(
call_id=self._approval_request_id,
name=self._tool_name,
arguments=self._tool_arguments,
)
return Content.from_function_approval_request(id=self._approval_request_id, function_call=function_call)
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
self.run_count += 1
approval = self._build_approval_content()
if stream:
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
yield AgentResponseUpdate(contents=[approval], role="assistant")
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", [approval])])
return _run()
def _has_approval_payload(event: WorkflowEvent[Any]) -> bool:
"""Return True if the event's data carries a ``function_approval_request`` content."""
data: Any = event.data
def _contents_of(value: Any) -> list[Content]:
if isinstance(value, AgentResponseUpdate):
return list(value.contents)
if isinstance(value, AgentResponse):
return [c for m in value.messages for c in m.contents]
if isinstance(value, AgentExecutorResponse):
return [c for m in value.agent_response.messages for c in m.contents]
if isinstance(value, Message):
return list(value.contents)
if isinstance(value, Content):
return [value]
return []
return any(c.type == "function_approval_request" for c in _contents_of(data))
async def test_agent_executor_does_not_double_emit_approval_non_streaming() -> None:
"""Non-streaming: approval payload must only appear in the ``request_info`` event.
Regression test for the bug where ``AgentExecutor._run_agent`` first
``yield_output``-ed the response (carrying the approval Content) and then
additionally emitted a ``request_info`` event for the same payload.
"""
agent = _ApprovalEmittingAgent(id="approve_agent", name="ApproveAgent", approval_request_id="apr_ns_1")
executor = AgentExecutor(agent, id="approve_exec")
workflow = WorkflowBuilder(start_executor=executor).build()
request_info_events: list[WorkflowEvent[Any]] = []
output_events: list[WorkflowEvent[Any]] = []
for event in await workflow.run("please delete it"):
if event.type == "request_info":
request_info_events.append(event)
elif event.type == "output":
output_events.append(event)
assert len(request_info_events) == 1
assert _has_approval_payload(request_info_events[0])
# The approval payload must not also be surfaced as a workflow output.
assert not any(_has_approval_payload(e) for e in output_events)
assert agent.run_count == 1
async def test_agent_executor_does_not_double_emit_approval_streaming() -> None:
"""Streaming: per-update approval payload must not be ``yield_output``-ed."""
agent = _ApprovalEmittingAgent(id="approve_agent_s", name="ApproveAgentS", approval_request_id="apr_st_1")
executor = AgentExecutor(agent, id="approve_exec_s")
workflow = WorkflowBuilder(start_executor=executor).build()
request_info_events: list[WorkflowEvent[Any]] = []
output_events: list[WorkflowEvent[Any]] = []
async for event in workflow.run("please delete it", stream=True):
if event.type == "request_info":
request_info_events.append(event)
elif event.type == "output":
output_events.append(event)
assert len(request_info_events) == 1
assert _has_approval_payload(request_info_events[0])
assert not any(_has_approval_payload(e) for e in output_events)
assert agent.run_count == 1
async def test_agent_executor_request_info_uses_user_input_request_id() -> None:
"""``ctx.request_info`` must register the request under the agent's approval id.
This makes the workflow's pending-request id round-trip with the
``function_approval_response.id`` the caller echoes back, so
``Workflow._send_responses_internal`` can look it up directly.
"""
agent = _ApprovalEmittingAgent(id="approve_agent_id", name="ApproveAgentId", approval_request_id="apr_match")
executor = AgentExecutor(agent, id="approve_exec_id")
workflow = WorkflowBuilder(start_executor=executor).build()
request_info_events: list[WorkflowEvent[Any]] = []
async for event in workflow.run("please delete it", stream=True):
if event.type == "request_info":
request_info_events.append(event)
assert len(request_info_events) == 1
assert request_info_events[0].request_id == "apr_match"
# endregion Tool approval emission
@@ -1,6 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.
import json
import uuid
from collections.abc import Awaitable, Sequence
from dataclasses import dataclass
@@ -30,6 +29,20 @@ from agent_framework import (
handler,
response_handler,
)
from agent_framework._workflows._typing_utils import deserialize_type
@dataclass
class HandoffRequest:
"""Module-level dataclass used by request_info tests.
Defined at module scope (not nested inside a test method) so
``serialize_type``/``deserialize_type`` can round-trip the request_type via
the importable qualified name ``tests.workflow.test_workflow_agent.HandoffRequest``.
"""
target_agent: str
reason: str
class SimpleExecutor(Executor):
@@ -240,52 +253,45 @@ class TestWorkflowAgent:
# Should have received an approval request for the request info
assert len(updates) > 0
approval_update: AgentResponseUpdate | None = None
request_update: AgentResponseUpdate | None = None
for update in updates:
if any(content.type == "function_approval_request" for content in update.contents):
approval_update = update
if any(content.type == "function_call" for content in update.contents):
request_update = update
break
assert approval_update is not None, "Should have received a request_info approval request"
assert request_update is not None, "Should have received a request_info wrapped in a function_call content"
function_call = next(content for content in approval_update.contents if content.type == "function_call")
approval_request = next(
content for content in approval_update.contents if content.type == "function_approval_request"
)
request_function_call = next(content for content in request_update.contents if content.type == "function_call")
assert request_function_call.call_id is not None
# Verify the function call has expected structure
assert function_call.call_id is not None
assert function_call.name == "request_info"
assert isinstance(function_call.arguments, dict)
assert function_call.arguments.get("request_id") == approval_request.id
assert request_function_call.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME
assert isinstance(request_function_call.arguments, dict)
assert request_function_call.arguments.get("request_id") is not None
assert request_function_call.arguments.get("request_event") is not None
request_event = request_function_call.arguments["request_event"]
assert request_event.get("type") == "request_info"
assert deserialize_type(request_event.get("response_type")) is str
# Approval request should reference the same function call
assert approval_request.id is not None
assert approval_request.function_call is not None
assert approval_request.function_call.call_id == function_call.call_id
assert approval_request.function_call.name == function_call.name
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments)
assert deserialized_args.request_id == request_function_call.call_id
assert isinstance(deserialized_args.request_event, WorkflowEvent)
assert deserialized_args.request_event.type == "request_info"
assert deserialized_args.request_event.data == "Mock request data"
assert deserialized_args.request_event.response_type is str
# Verify the request is tracked in pending_requests
assert len(agent.pending_requests) == 1
assert function_call.call_id in agent.pending_requests
pending_requests = await workflow._runner_context.get_pending_request_info_events()
assert len(pending_requests) == 1
assert request_function_call.call_id in pending_requests
# Now provide an approval response with updated arguments to test continuation
response_args = WorkflowAgent.RequestInfoFunctionArgs(
request_id=approval_request.id,
data="User provided answer",
).to_dict()
approval_response = Content.from_function_approval_response(
approved=True,
id=approval_request.id,
function_call=Content.from_function_call(
call_id=function_call.call_id,
name=function_call.name,
arguments=response_args,
),
# Now provide a function result response with updated arguments to test continuation
function_result = Content.from_function_result(
call_id=request_function_call.call_id,
result="Mock response to request info",
)
response_message = Message(role="user", contents=[approval_response])
response_message = Message(role="user", contents=[function_result])
# Continue the workflow with the response
continuation_result = await agent.run(response_message)
@@ -294,16 +300,11 @@ class TestWorkflowAgent:
assert isinstance(continuation_result, AgentResponse)
# Verify cleanup - pending requests should be cleared after function response handling
assert len(agent.pending_requests) == 0
pending_requests = await workflow._runner_context.get_pending_request_info_events()
assert len(pending_requests) == 0
def test_request_info_dataclass_arguments_are_serialized_when_content_is_created(self) -> None:
"""Test WorkflowAgent prepares request_info arguments before observability captures messages."""
@dataclass
class HandoffRequest:
target_agent: str
reason: str
executor = SimpleExecutor(id="executor1", response_text="Response")
workflow = WorkflowBuilder(start_executor=executor).build()
agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
@@ -314,14 +315,367 @@ class TestWorkflowAgent:
response_type=str,
)
function_call, approval_request = agent._process_request_info_event(event) # pyright: ignore[reportPrivateUsage]
request_function_call = agent._process_request_info_event(event) # pyright: ignore[reportPrivateUsage]
assert function_call.arguments == {
"request_id": "request_123",
"data": {"target_agent": "helper", "reason": "overflow"},
}
assert approval_request.function_call is function_call
assert json.loads(json.dumps(function_call.arguments)) == function_call.arguments
assert request_function_call.call_id == "request_123"
assert isinstance(request_function_call.arguments, dict)
assert request_function_call.arguments.get("request_event") is not None
request_event = request_function_call.arguments["request_event"]
assert request_event.get("type") == "request_info"
assert request_event.get("request_id") == "request_123"
assert request_event.get("source_executor_id") == "executor1"
assert deserialize_type(request_event.get("response_type")) is str
assert request_event.get("data") == HandoffRequest(target_agent="helper", reason="overflow")
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments)
assert deserialized_args.request_id == "request_123"
assert isinstance(deserialized_args.request_event, WorkflowEvent)
assert deserialized_args.request_event.type == "request_info"
assert deserialized_args.request_event.data == HandoffRequest(target_agent="helper", reason="overflow")
assert deserialized_args.request_event.response_type is str
def test_process_request_info_event_passes_through_function_approval_request(self) -> None:
"""If the event data is already a function approval request, it is forwarded unchanged.
Tool-approval requests emitted by an inner agent surface as ``Content``
objects with ``user_input_request=True``. ``WorkflowAgent`` must not
re-wrap these inside a synthesized ``request_info`` function call;
instead it should return the original content as-is so callers can
respond with a matching ``function_approval_response``.
"""
executor = SimpleExecutor(id="executor1", response_text="Response")
workflow = WorkflowBuilder(start_executor=executor).build()
agent = WorkflowAgent(workflow=workflow, name="Approval Passthrough Agent")
approval_id = "approval-passthrough-1"
inner_function_call = Content.from_function_call(
call_id="tool-call-1",
name="delete_file",
arguments={"path": "/tmp/x"},
)
approval_request = Content.from_function_approval_request(
id=approval_id,
function_call=inner_function_call,
)
event = WorkflowEvent.request_info(
request_id=approval_id,
source_executor_id="executor1",
request_data=approval_request,
response_type=Content,
)
result = agent._process_request_info_event(event) # pyright: ignore[reportPrivateUsage]
# The original FunctionApprovalRequestContent is returned as-is — same
# instance, with the original tool name preserved (NOT replaced by the
# synthesized REQUEST_INFO_FUNCTION_NAME).
assert result is approval_request
assert result.type == "function_approval_request"
assert result.id == approval_id
assert result.user_input_request is True
assert result.function_call is inner_function_call # type: ignore[attr-defined]
assert result.function_call.name == "delete_file" # type: ignore[attr-defined]
assert result.function_call.name != WorkflowAgent.REQUEST_INFO_FUNCTION_NAME # type: ignore[attr-defined]
def test_extract_function_responses_passes_through_approval_response_approved(self) -> None:
"""A function_approval_response with approved=True is keyed by content.id and forwarded as-is.
After the refactor, ``WorkflowAgent`` no longer unwraps a synthesized
``request_info`` function call from approval responses — the response
content is routed straight back to the workflow under its own ``id``,
which matches the pending request id surfaced by
``_process_request_info_event``.
"""
executor = SimpleExecutor(id="executor1", response_text="Response")
workflow = WorkflowBuilder(start_executor=executor).build()
agent = WorkflowAgent(workflow=workflow, name="Approval Response Agent")
approval_id = "approval-response-approved-1"
inner_function_call = Content.from_function_call(
call_id="tool-call-1",
name="delete_file",
arguments={"path": "/tmp/x"},
)
approval_request = Content.from_function_approval_request(
id=approval_id,
function_call=inner_function_call,
)
approval_response = approval_request.to_function_approval_response(approved=True) # type: ignore[attr-defined]
message = Message(role="user", contents=[approval_response])
responses = agent._extract_function_responses([message]) # pyright: ignore[reportPrivateUsage]
assert set(responses.keys()) == {approval_id}
assert responses[approval_id] is approval_response
assert responses[approval_id].approved is True # type: ignore[attr-defined]
def test_extract_function_responses_passes_through_approval_response_denied(self) -> None:
"""A function_approval_response with approved=False is forwarded the same way as an approval.
Only the ``approved`` flag changes — routing back to the workflow is
identical for accept and reject paths.
"""
executor = SimpleExecutor(id="executor1", response_text="Response")
workflow = WorkflowBuilder(start_executor=executor).build()
agent = WorkflowAgent(workflow=workflow, name="Approval Response Agent")
approval_id = "approval-response-denied-1"
inner_function_call = Content.from_function_call(
call_id="tool-call-2",
name="send_email",
arguments={"to": "alice@example.com"},
)
approval_request = Content.from_function_approval_request(
id=approval_id,
function_call=inner_function_call,
)
approval_response = approval_request.to_function_approval_response(approved=False) # type: ignore[attr-defined]
message = Message(role="user", contents=[approval_response])
responses = agent._extract_function_responses([message]) # pyright: ignore[reportPrivateUsage]
assert set(responses.keys()) == {approval_id}
assert responses[approval_id] is approval_response
assert responses[approval_id].approved is False # type: ignore[attr-defined]
async def test_function_approval_request_flows_end_to_end_approved(self) -> None:
"""End-to-end: an executor emits a function_approval_request, the agent
forwards it unchanged, and an ``approved=True`` response resumes the workflow.
This exercises the full pass-through path:
``ctx.request_info(approval_content, ...)`` -> ``WorkflowAgent`` surfaces
the original ``FunctionApprovalRequestContent`` -> caller responds with a
``FunctionApprovalResponseContent`` -> ``WorkflowAgent`` routes it back
to the workflow which delivers it to the executor's ``@response_handler``.
"""
approval_id = "e2e-approval-1"
inner_function_call = Content.from_function_call(
call_id="tool-call-e2e-1",
name="delete_file",
arguments={"path": "/tmp/x"},
)
approval_request = Content.from_function_approval_request(
id=approval_id,
function_call=inner_function_call,
)
class ApprovalRequestingExecutor(Executor):
@handler
async def handle_message(self, _: list[Message], ctx: WorkflowContext) -> None:
await ctx.request_info(approval_request, Content, request_id=approval_id)
@response_handler
async def handle_response(
self,
original_request: Content,
response: Content,
ctx: WorkflowContext[Never, AgentResponse],
) -> None:
assert response.type == "function_approval_response"
assert response.id == approval_id # type: ignore[attr-defined]
approved = bool(response.approved) # type: ignore[attr-defined]
tool_name = original_request.function_call.name # type: ignore[attr-defined]
await ctx.yield_output(
AgentResponse(
messages=[
Message(
role="assistant",
contents=[Content.from_text(text=f"{tool_name} approved={approved}")],
)
]
)
)
executor = ApprovalRequestingExecutor(id="approval_requester")
workflow = WorkflowBuilder(start_executor=executor).build()
agent = WorkflowAgent(workflow=workflow, name="E2E Approval Agent")
# First run: workflow pauses with the approval request.
first = await agent.run("please delete it")
assert isinstance(first, AgentResponse)
forwarded = next(
(
c
for m in first.messages
for c in m.contents
if c.type == "function_approval_request" and c.id == approval_id
),
None,
)
assert forwarded is approval_request, "Approval request must surface unchanged"
pending = await workflow._runner_context.get_pending_request_info_events()
assert approval_id in pending
# Respond with approved=True.
approval_response = approval_request.to_function_approval_response(approved=True) # type: ignore[attr-defined]
final = await agent.run(Message(role="user", contents=[approval_response]))
assert isinstance(final, AgentResponse)
final_text = " ".join(m.text or "" for m in final.messages)
assert "delete_file approved=True" in final_text
pending = await workflow._runner_context.get_pending_request_info_events()
assert approval_id not in pending
async def test_function_approval_request_flows_end_to_end_denied(self) -> None:
"""End-to-end denied path: ``approved=False`` is delivered to the executor's
response handler so the workflow can branch on the rejection."""
approval_id = "e2e-approval-deny-1"
inner_function_call = Content.from_function_call(
call_id="tool-call-e2e-deny-1",
name="send_email",
arguments={"to": "alice@example.com"},
)
approval_request = Content.from_function_approval_request(
id=approval_id,
function_call=inner_function_call,
)
class ApprovalRequestingExecutor(Executor):
@handler
async def handle_message(self, _: list[Message], ctx: WorkflowContext) -> None:
await ctx.request_info(approval_request, Content, request_id=approval_id)
@response_handler
async def handle_response(
self,
original_request: Content,
response: Content,
ctx: WorkflowContext[Never, AgentResponse],
) -> None:
assert response.type == "function_approval_response"
assert response.id == approval_id # type: ignore[attr-defined]
approved = bool(response.approved) # type: ignore[attr-defined]
tool_name = original_request.function_call.name # type: ignore[attr-defined]
await ctx.yield_output(
AgentResponse(
messages=[
Message(
role="assistant",
contents=[Content.from_text(text=f"{tool_name} approved={approved}")],
)
]
)
)
executor = ApprovalRequestingExecutor(id="approval_requester_deny")
workflow = WorkflowBuilder(start_executor=executor).build()
agent = WorkflowAgent(workflow=workflow, name="E2E Approval Deny Agent")
first = await agent.run("please send")
assert isinstance(first, AgentResponse)
forwarded = next(
(
c
for m in first.messages
for c in m.contents
if c.type == "function_approval_request" and c.id == approval_id
),
None,
)
assert forwarded is approval_request
# Respond with approved=False.
approval_response = approval_request.to_function_approval_response(approved=False) # type: ignore[attr-defined]
final = await agent.run(Message(role="user", contents=[approval_response]))
assert isinstance(final, AgentResponse)
final_text = " ".join(m.text or "" for m in final.messages)
assert "send_email approved=False" in final_text
pending = await workflow._runner_context.get_pending_request_info_events()
assert approval_id not in pending
async def test_request_info_non_approval_flows_end_to_end(self) -> None:
"""End-to-end: when request data is not a function approval content, the
agent surfaces a synthesized ``function_call`` (name=REQUEST_INFO_FUNCTION_NAME)
and routes a matching ``function_result`` back to the executor.
"""
captured: dict[str, Any] = {}
class HandoffRequestingExecutor(Executor):
@handler
async def handle_message(self, _: list[Message], ctx: WorkflowContext) -> None:
await ctx.request_info(
HandoffRequest(target_agent="helper", reason="overflow"),
str,
)
@response_handler
async def handle_response(
self,
original_request: HandoffRequest,
response: str,
ctx: WorkflowContext[Never, AgentResponse],
) -> None:
captured["original"] = original_request
captured["response"] = response
await ctx.yield_output(
AgentResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_text(text=f"handoff to {original_request.target_agent}: {response}")
],
)
]
)
)
executor = HandoffRequestingExecutor(id="handoff_requester")
workflow = WorkflowBuilder(start_executor=executor).build()
agent = WorkflowAgent(workflow=workflow, name="E2E Handoff Agent")
# First run: workflow pauses with a synthesized request_info function_call.
first = await agent.run("start handoff")
assert isinstance(first, AgentResponse)
function_call = next(
(
c
for m in first.messages
for c in m.contents
if c.type == "function_call" and c.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME
),
None,
)
assert function_call is not None, "Expected a synthesized request_info function_call"
assert function_call.call_id is not None
assert isinstance(function_call.arguments, dict)
request_id = function_call.arguments["request_id"]
assert function_call.call_id == request_id
request_payload = function_call.arguments["request_event"]
assert request_payload.get("type") == "request_info"
assert request_payload.get("data") == HandoffRequest(target_agent="helper", reason="overflow")
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(function_call.arguments)
assert deserialized_args.request_id == request_id
assert isinstance(deserialized_args.request_event, WorkflowEvent)
assert deserialized_args.request_event.type == "request_info"
assert deserialized_args.request_event.data == HandoffRequest(target_agent="helper", reason="overflow")
assert deserialized_args.request_event.response_type is str
pending = await workflow._runner_context.get_pending_request_info_events()
assert request_id in pending
# Respond with a function_result keyed by the call_id.
function_result = Content.from_function_result(call_id=request_id, result="ok-do-it")
final = await agent.run(Message(role="user", contents=[function_result]))
assert isinstance(final, AgentResponse)
final_text = " ".join(m.text or "" for m in final.messages)
assert "handoff to helper: ok-do-it" in final_text
# The executor's response handler received the original request and the response.
assert isinstance(captured.get("original"), HandoffRequest)
assert captured["original"].target_agent == "helper"
assert captured["response"] == "ok-do-it"
pending = await workflow._runner_context.get_pending_request_info_events()
assert request_id not in pending
def test_workflow_as_agent_method(self) -> None:
"""Test that Workflow.as_agent() creates a properly configured WorkflowAgent."""
@@ -1592,3 +1946,406 @@ class TestWorkflowAgentMergeUpdates:
# Order: text (user), text (assistant), function_result (orphan at end)
assert content_types == ["text", "text", "function_result"]
class _ToolApprovalMockAgent(SupportsAgentRun):
"""Mock agent whose first run returns a FunctionApprovalRequestContent.
Subsequent runs (after receiving an approval response in the input messages)
return a final assistant text response that echoes the approved arguments.
This mirrors a real agent whose tool invocation requires user approval.
"""
def __init__(
self,
name: str,
*,
tool_name: str = "delete_file",
tool_arguments: dict[str, Any] | None = None,
approval_request_ids: Sequence[str] | None = None,
) -> None:
self.id = str(uuid.uuid4())
self.name = name
self.description: str | None = None
self._tool_name = tool_name
self._tool_arguments = tool_arguments or {"path": "/tmp/example"}
# Pre-allocated request ids so the test can verify what the WorkflowAgent forwards.
self._approval_request_ids: list[str] = list(approval_request_ids) if approval_request_ids else []
self.run_count = 0
# Inputs received on the most recent (continuation) run, for assertions.
self.last_run_messages: list[Message] = []
def create_session(self, **kwargs: Any) -> AgentSession:
return AgentSession()
def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
return AgentSession()
def _next_request_id(self) -> str:
if self._approval_request_ids:
return self._approval_request_ids.pop(0)
return str(uuid.uuid4())
def _build_approval_request(self) -> Content:
request_id = self._next_request_id()
function_call = Content.from_function_call(
call_id=request_id,
name=self._tool_name,
arguments=self._tool_arguments,
)
return Content.from_function_approval_request(id=request_id, function_call=function_call)
@overload
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
if stream:
return self._run_stream(messages=messages, session=session, **kwargs)
return self._run(messages=messages, session=session, **kwargs)
def _normalize(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None,
) -> list[Message]:
if messages is None:
return []
if isinstance(messages, str):
return [Message(role="user", contents=[Content.from_text(text=messages)])]
if isinstance(messages, Message):
return [messages]
if isinstance(messages, Content):
return [Message(role="user", contents=[messages])]
result: list[Message] = []
for item in messages:
if isinstance(item, Message):
result.append(item)
elif isinstance(item, Content):
result.append(Message(role="user", contents=[item]))
else:
result.append(Message(role="user", contents=[Content.from_text(text=item)]))
return result
def _approval_responses_in(self, messages: list[Message]) -> list[Content]:
approvals: list[Content] = []
for msg in messages:
for content in msg.contents:
if content.type == "function_approval_response":
approvals.append(content)
return approvals
async def _run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
session: AgentSession | None = None,
**kwargs: Any,
) -> AgentResponse:
normalized = self._normalize(messages)
self.last_run_messages = normalized
self.run_count += 1
approvals = self._approval_responses_in(normalized)
if approvals:
# Continuation: reflect approved arguments in the final response text.
approved_text = "; ".join(
f"approved={a.approved} id={a.id}" # type: ignore[attr-defined]
for a in approvals
)
return AgentResponse(messages=[Message("assistant", [Content.from_text(text=f"done ({approved_text})")])])
# First run: ask for tool approval.
approval = self._build_approval_request()
return AgentResponse(messages=[Message("assistant", [approval])])
def _run_stream(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
session: AgentSession | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
normalized = self._normalize(messages)
self.last_run_messages = normalized
self.run_count += 1
approvals = self._approval_responses_in(normalized)
async def _iter():
if approvals:
approved_text = "; ".join(
f"approved={a.approved} id={a.id}" # type: ignore[attr-defined]
for a in approvals
)
yield AgentResponseUpdate(
contents=[Content.from_text(text=f"done ({approved_text})")],
role="assistant",
author_name=self.name,
)
return
approval = self._build_approval_request()
yield AgentResponseUpdate(
contents=[approval],
role="assistant",
author_name=self.name,
)
return ResponseStream(_iter(), finalizer=AgentResponse.from_updates)
class TestWorkflowAgentToolApproval:
"""Tests for tool-approval requests bubbling through WorkflowAgent.
Covers the case where a workflow contains an AgentExecutor whose underlying
agent emits a FunctionApprovalRequestContent (tool needing user approval).
The WorkflowAgent must:
* forward the original FunctionApprovalRequestContent unchanged (no
wrapping inside a synthesized 'request_info' function call), and
* route a subsequent FunctionApprovalResponseContent back to the
AgentExecutor so the agent can resume.
"""
def _find_approval_request(
self,
contents: Sequence[Content],
tool_name: str,
) -> Content | None:
for content in contents:
if (
content.type == "function_approval_request"
and getattr(content.function_call, "name", None) == tool_name # type: ignore[attr-defined]
):
return content
return None
async def test_tool_approval_request_forwarded_unchanged(self) -> None:
"""The agent's FunctionApprovalRequestContent surfaces verbatim (not re-wrapped)."""
approval_id = "approval-abc-123"
mock_agent = _ToolApprovalMockAgent(
name="approval-agent",
tool_name="delete_file",
tool_arguments={"path": "/tmp/secret.txt"},
approval_request_ids=[approval_id],
)
@executor
async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
workflow = WorkflowBuilder(start_executor=start).add_edge(start, mock_agent).build()
agent = WorkflowAgent(workflow=workflow, name="Approval Test Agent")
result = await agent.run("please delete the file")
assert isinstance(result, AgentResponse)
# Locate the approval request emitted by the WorkflowAgent.
all_contents: list[Content] = [c for m in result.messages for c in m.contents]
approval = self._find_approval_request(all_contents, tool_name="delete_file")
assert approval is not None, "WorkflowAgent did not forward the tool approval request"
# The id and inner function_call must match what the underlying agent produced
# — i.e. the WorkflowAgent must NOT have re-wrapped it inside a synthesized
# 'request_info' approval request.
assert approval.id == approval_id
function_call = approval.function_call # type: ignore[attr-defined]
assert function_call is not None
assert function_call.name == "delete_file"
assert function_call.name != WorkflowAgent.REQUEST_INFO_FUNCTION_NAME
assert function_call.arguments == {"path": "/tmp/secret.txt"}
# The agent must be paused awaiting the approval response.
pending = await workflow._runner_context.get_pending_request_info_events()
assert approval_id in pending
async def test_tool_approval_request_forwarded_unchanged_streaming(self) -> None:
"""Streaming variant: the approval request is forwarded as-is in updates."""
approval_id = "approval-stream-1"
mock_agent = _ToolApprovalMockAgent(
name="approval-agent-stream",
tool_name="send_email",
tool_arguments={"to": "alice@example.com"},
approval_request_ids=[approval_id],
)
@executor
async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
workflow = WorkflowBuilder(start_executor=start).add_edge(start, mock_agent).build()
agent = WorkflowAgent(workflow=workflow, name="Approval Stream Agent")
updates: list[AgentResponseUpdate] = []
async for update in agent.run("hi", stream=True):
updates.append(update)
approval_updates = [u for u in updates if any(c.type == "function_approval_request" for c in u.contents)]
assert approval_updates, "Streaming did not surface a tool approval request"
approval = self._find_approval_request(approval_updates[-1].contents, tool_name="send_email")
assert approval is not None
assert approval.id == approval_id
function_call = approval.function_call # type: ignore[attr-defined]
assert function_call is not None
assert function_call.name == "send_email"
assert function_call.name != WorkflowAgent.REQUEST_INFO_FUNCTION_NAME
assert function_call.arguments == {"to": "alice@example.com"}
async def test_tool_approval_response_resumes_agent(self) -> None:
"""Sending the approval response back resumes the agent and clears pending requests."""
approval_id = "approval-resume-1"
mock_agent = _ToolApprovalMockAgent(
name="approval-resume-agent",
tool_name="delete_file",
tool_arguments={"path": "/tmp/x"},
approval_request_ids=[approval_id],
)
@executor
async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
workflow = WorkflowBuilder(start_executor=start).add_edge(start, mock_agent).build()
agent = WorkflowAgent(workflow=workflow, name="Approval Resume Agent")
first_result = await agent.run("delete it")
approval = self._find_approval_request(
[c for m in first_result.messages for c in m.contents],
tool_name="delete_file",
)
assert approval is not None
assert mock_agent.run_count == 1
# Build the approval response. NOTE: the inner function_call's name is the
# original tool name ('delete_file'), NOT 'request_info'. This exercises the
# branch in WorkflowAgent._extract_function_responses that routes raw
# tool-approval responses straight through using content.id.
approval_response = approval.to_function_approval_response(approved=True) # type: ignore[attr-defined]
response_message = Message(role="user", contents=[approval_response])
final_result = await agent.run(response_message)
assert isinstance(final_result, AgentResponse)
# The mock agent should have been invoked a second time and seen the
# approval response in its inputs.
assert mock_agent.run_count == 2
approvals_seen = [
c for m in mock_agent.last_run_messages for c in m.contents if c.type == "function_approval_response"
]
assert len(approvals_seen) == 1
assert approvals_seen[0].id == approval_id # type: ignore[attr-defined]
assert approvals_seen[0].approved is True # type: ignore[attr-defined]
# The pending approval should now be cleared.
pending = await workflow._runner_context.get_pending_request_info_events()
assert approval_id not in pending
# The final assistant message reflects the resumption.
final_text = " ".join(m.text or "" for m in final_result.messages)
assert "done" in final_text
assert approval_id in final_text
async def test_tool_approval_response_rejected_resumes_agent(self) -> None:
"""Rejection path: ``approved=False`` is forwarded to the inner agent and clears the pending request.
The WorkflowAgent must route a rejection response back to the paused
``AgentExecutor`` exactly the same way as an approval — only the
``approved`` flag differs. The inner agent decides what to do with it.
"""
approval_id = "approval-reject-1"
mock_agent = _ToolApprovalMockAgent(
name="approval-reject-agent",
tool_name="delete_file",
tool_arguments={"path": "/tmp/x"},
approval_request_ids=[approval_id],
)
@executor
async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
workflow = WorkflowBuilder(start_executor=start).add_edge(start, mock_agent).build()
agent = WorkflowAgent(workflow=workflow, name="Approval Reject Agent")
first_result = await agent.run("delete it")
approval = self._find_approval_request(
[c for m in first_result.messages for c in m.contents],
tool_name="delete_file",
)
assert approval is not None
assert mock_agent.run_count == 1
# Reject the tool invocation.
approval_response = approval.to_function_approval_response(approved=False) # type: ignore[attr-defined]
response_message = Message(role="user", contents=[approval_response])
final_result = await agent.run(response_message)
assert isinstance(final_result, AgentResponse)
# The inner agent must have been resumed and seen ``approved=False``.
assert mock_agent.run_count == 2
approvals_seen = [
c for m in mock_agent.last_run_messages for c in m.contents if c.type == "function_approval_response"
]
assert len(approvals_seen) == 1
assert approvals_seen[0].id == approval_id # type: ignore[attr-defined]
assert approvals_seen[0].approved is False # type: ignore[attr-defined]
# Pending approval cleared regardless of approve/reject.
pending = await workflow._runner_context.get_pending_request_info_events()
assert approval_id not in pending
# The final assistant message reflects the rejection.
final_text = " ".join(m.text or "" for m in final_result.messages)
assert "approved=False" in final_text
assert approval_id in final_text
async def test_tool_approval_request_id_matches_pending_request(self) -> None:
"""The approval request id surfaced by WorkflowAgent matches the workflow's pending request id.
This guards the AgentExecutor change that forwards
request_id=user_input_request.id to ctx.request_info(...), which is what
allows the response routed back via WorkflowAgent to resolve the pending
request without an id-mismatch error.
"""
approval_id = "approval-id-match-1"
mock_agent = _ToolApprovalMockAgent(
name="approval-id-match-agent",
approval_request_ids=[approval_id],
)
@executor
async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
workflow = WorkflowBuilder(start_executor=start).add_edge(start, mock_agent).build()
agent = WorkflowAgent(workflow=workflow, name="Approval Id Agent")
await agent.run("go")
pending = await workflow._runner_context.get_pending_request_info_events()
# The agent's approval id is used as the workflow's pending request id.
assert list(pending.keys()) == [approval_id]
@@ -0,0 +1,149 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for the ``Workflow.status`` property."""
from dataclasses import dataclass
import pytest
from agent_framework import (
Executor,
Workflow,
WorkflowBuilder,
WorkflowContext,
WorkflowEvent,
WorkflowRunState,
handler,
response_handler,
)
from agent_framework._workflows._executor import Executor as _Executor
from agent_framework._workflows._request_info_mixin import RequestInfoMixin
class PassThroughExecutor(Executor):
"""Executor that yields its input as a workflow output and stops."""
@handler
async def passthrough(self, msg: str, ctx: WorkflowContext[str, str]) -> None:
await ctx.yield_output(msg)
class FailingExecutor(Executor):
"""Executor that raises at runtime to drive the FAILED status."""
@handler
async def fail(self, msg: int, ctx: WorkflowContext) -> None: # pragma: no cover - invoked via workflow
raise RuntimeError("boom")
@dataclass
class _ApprovalRequest:
prompt: str
request_id: str = ""
def __post_init__(self) -> None:
if not self.request_id:
import uuid
self.request_id = str(uuid.uuid4())
class ApprovalExecutor(_Executor, RequestInfoMixin):
"""Executor that issues a single request_info call and finalizes on response."""
def __init__(self, id: str = "approval"):
super().__init__(id=id)
@handler
async def start(self, message: str, ctx: WorkflowContext[str, str]) -> None:
await ctx.request_info(_ApprovalRequest(prompt=message), bool)
@response_handler
async def on_response(
self, original_request: _ApprovalRequest, approved: bool, ctx: WorkflowContext[str, str]
) -> None:
await ctx.yield_output(f"approved={approved}")
def _build_passthrough_workflow() -> Workflow:
executor = PassThroughExecutor(id="p")
return WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
def _build_failing_workflow() -> Workflow:
# FailingExecutor has no workflow_output_types, so we leave designation
# implicit; the deprecation warning is filtered at call sites that need it.
return WorkflowBuilder(start_executor=FailingExecutor(id="f")).build()
def _build_approval_workflow() -> Workflow:
executor = ApprovalExecutor(id="approval")
return WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
async def test_status_default_is_idle_before_first_run():
wf = _build_passthrough_workflow()
assert wf.status is WorkflowRunState.IDLE
async def test_status_is_idle_after_successful_run():
wf = _build_passthrough_workflow()
await wf.run("hello")
assert wf.status is WorkflowRunState.IDLE
async def test_status_is_failed_after_failure():
wf = _build_failing_workflow()
with pytest.raises(RuntimeError, match="boom"):
await wf.run(0)
assert wf.status is WorkflowRunState.FAILED
async def test_status_transitions_during_streaming_run():
"""Workflow.status mirrors the most recent emitted status event."""
wf = _build_passthrough_workflow()
observed: list[WorkflowRunState] = []
async for event in wf.run("hi", stream=True):
if isinstance(event, WorkflowEvent) and event.type == "status":
# By the time a status event surfaces to the consumer, the property
# must already reflect that state (updated in lockstep with emission).
assert wf.status == event.state
observed.append(event.state) # type: ignore
# IN_PROGRESS must precede IDLE; both must appear.
assert WorkflowRunState.IN_PROGRESS in observed
assert observed[-1] is WorkflowRunState.IDLE
assert wf.status is WorkflowRunState.IDLE
async def test_status_idle_with_pending_requests_then_resolves_to_idle():
wf = _build_approval_workflow()
request_event: WorkflowEvent | None = None
async for event in wf.run("please approve", stream=True):
if isinstance(event, WorkflowEvent) and event.type == "request_info":
request_event = event
assert request_event is not None
assert wf.status is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
async for _ in wf.run(stream=True, responses={request_event.request_id: True}):
pass
assert wf.status is WorkflowRunState.IDLE
async def test_status_in_progress_pending_requests_observed_mid_run():
"""While streaming, status reaches IN_PROGRESS_PENDING_REQUESTS after a request_info event."""
wf = _build_approval_workflow()
seen_states: list[WorkflowRunState] = []
async for event in wf.run("please approve", stream=True):
if isinstance(event, WorkflowEvent) and event.type == "status":
seen_states.append(event.state) # type: ignore
assert WorkflowRunState.IN_PROGRESS in seen_states
assert WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS in seen_states
assert seen_states[-1] is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
assert wf.status is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
+1 -1
View File
@@ -26,7 +26,7 @@ dependencies = [
"agent-framework-core>=1.8.0,<2",
"agent-framework-openai>=1.8.0,<2",
"azure-ai-inference>=1.0.0b9,<1.0.0b10",
"azure-ai-projects>=2.1.0,<3.0",
"azure-ai-projects>=2.2.0,<3.0",
]
[tool.uv]
@@ -567,7 +567,7 @@ class ResponsesHostServer(ResponsesAgentServerHost):
by the hosting infrastructure or files will be preserved upon deactivation.
"""
input_items = await context.get_input_items()
input_messages = await _items_to_messages(input_items)
input_messages = await _items_to_messages(input_items, approval_storage=self._approval_storage)
is_streaming_request = request.stream is not None and request.stream is True
_, are_options_set = _to_chat_options(request)
@@ -664,7 +664,11 @@ class ResponsesHostServer(ResponsesAgentServerHost):
checkpoint_storage=write_storage,
)
async for item in _to_outputs_for_messages(response_event_stream, response.messages):
async for item in _to_outputs_for_messages(
response_event_stream,
response.messages,
approval_storage=self._approval_storage,
):
yield item
await self._delete_not_latest_checkpoints(write_storage, self._agent.workflow.name)
@@ -685,7 +689,9 @@ class ResponsesHostServer(ResponsesAgentServerHost):
for event in tracker.handle(content):
yield event
if tracker.needs_async:
async for item in _to_outputs(response_event_stream, content):
async for item in _to_outputs(
response_event_stream, content, approval_storage=self._approval_storage
):
yield item
tracker.needs_async = False
@@ -11,24 +11,33 @@ the registered _handle_create handler.
from __future__ import annotations
import json
from collections.abc import AsyncIterator, Callable
import uuid
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
from dataclasses import dataclass
from typing import Literal, overload
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from agent_framework import (
AgentExecutorRequest,
AgentResponse,
AgentResponseUpdate,
AgentSession,
Content,
FileCheckpointStorage,
HistoryProvider,
Message,
RawAgent,
ResponseStream,
SupportsAgentRun,
WorkflowAgent,
WorkflowBuilder,
WorkflowCheckpoint,
WorkflowCheckpointException,
WorkflowContext,
WorkflowMessage,
executor,
)
from azure.ai.agentserver.responses import InMemoryResponseProvider
from mcp import McpError
@@ -102,7 +111,7 @@ def _make_agent(
return agent
def _make_server(agent: MagicMock, **kwargs: Any) -> ResponsesHostServer:
def _make_server(agent: Any, **kwargs: Any) -> ResponsesHostServer:
"""Create a ResponsesHostServer with an in-memory store."""
return ResponsesHostServer(agent, store=InMemoryResponseProvider(), **kwargs)
@@ -3469,3 +3478,498 @@ class TestOAuthConsentSurfacing:
# endregion
# region Workflow agent hosting (end-to-end)
class _ToolApprovalWorkflowAgentMock(SupportsAgentRun):
"""Inner agent for a hosted ``WorkflowAgent`` whose first run emits a
``FunctionApprovalRequestContent`` and whose follow-up run (after
receiving a ``FunctionApprovalResponseContent`` in its inputs) returns a
final assistant text response.
Mirrors a real agent whose tool invocation requires user approval. Used
here to exercise the full HTTP pipeline through ``ResponsesHostServer``
when the hosted agent is a ``WorkflowAgent`` containing a tool-approval
flow.
"""
def __init__(
self,
name: str,
*,
tool_name: str = "delete_file",
tool_arguments: dict[str, Any] | None = None,
approval_request_ids: Sequence[str] | None = None,
final_text: str = "done",
) -> None:
self.id = str(uuid.uuid4())
self.name = name
self.description: str | None = None
self._tool_name = tool_name
self._tool_arguments = tool_arguments or {"path": "/tmp/example"}
self._approval_request_ids: list[str] = list(approval_request_ids) if approval_request_ids else []
self._final_text = final_text
self.run_count = 0
self.last_run_messages: list[Message] = []
def create_session(self, **kwargs: Any) -> AgentSession:
return AgentSession()
def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
return AgentSession()
def _next_request_id(self) -> str:
# Stable across calls: when the workflow checkpoint round-trips through
# restore, ``AgentExecutor`` re-invokes the inner agent during replay.
# We must surface the *same* approval request id on each invocation so
# the workflow's pending-request id matches the id the test echoes
# back as ``mcp_approval_response``.
if self._approval_request_ids:
return self._approval_request_ids[0]
return str(uuid.uuid4())
def _build_approval_request(self) -> Content:
request_id = self._next_request_id()
function_call = Content.from_function_call(
call_id=request_id,
name=self._tool_name,
arguments=self._tool_arguments,
additional_properties={"server_label": "test_server"},
)
return Content.from_function_approval_request(id=request_id, function_call=function_call)
@overload
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
if stream:
return self._run_stream(messages=messages, **kwargs)
return self._run(messages=messages, **kwargs)
@staticmethod
def _normalize(
messages: str | Content | Message | Sequence[str | Content | Message] | None,
) -> list[Message]:
if messages is None:
return []
if isinstance(messages, str):
return [Message(role="user", contents=[Content.from_text(text=messages)])]
if isinstance(messages, Message):
return [messages]
if isinstance(messages, Content):
return [Message(role="user", contents=[messages])]
result: list[Message] = []
for item in messages:
if isinstance(item, Message):
result.append(item)
elif isinstance(item, Content):
result.append(Message(role="user", contents=[item]))
else:
result.append(Message(role="user", contents=[Content.from_text(text=item)]))
return result
@staticmethod
def _approval_responses_in(messages: list[Message]) -> list[Content]:
return [c for m in messages for c in m.contents if c.type == "function_approval_response"]
async def _run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
**kwargs: Any,
) -> AgentResponse:
normalized = self._normalize(messages)
self.last_run_messages = normalized
self.run_count += 1
if self._approval_responses_in(normalized):
return AgentResponse(messages=[Message("assistant", [Content.from_text(text=self._final_text)])])
approval = self._build_approval_request()
return AgentResponse(messages=[Message("assistant", [approval])])
def _run_stream(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
normalized = self._normalize(messages)
self.last_run_messages = normalized
self.run_count += 1
approvals = self._approval_responses_in(normalized)
async def _iter() -> AsyncIterator[AgentResponseUpdate]:
if approvals:
yield AgentResponseUpdate(
contents=[Content.from_text(text=self._final_text)],
role="assistant",
author_name=self.name,
)
return
yield AgentResponseUpdate(
contents=[self._build_approval_request()],
role="assistant",
author_name=self.name,
)
return ResponseStream(_iter(), finalizer=AgentResponse.from_updates)
def _build_text_workflow_agent(text: str) -> WorkflowAgent:
"""Build a minimal ``WorkflowAgent`` whose inner agent emits a fixed text."""
class _TextAgent(SupportsAgentRun):
def __init__(self, name: str, text: str) -> None:
self.id = str(uuid.uuid4())
self.name = name
self.description: str | None = None
self._text = text
def create_session(self, **kwargs: Any) -> AgentSession:
return AgentSession()
def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
return AgentSession()
@overload
def run(
self,
messages: Any = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: Any = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: Any = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
text = self._text
name = self.name
async def _aresult() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", [Content.from_text(text=text)])])
async def _aiter() -> AsyncIterator[AgentResponseUpdate]:
yield AgentResponseUpdate(
contents=[Content.from_text(text=text)],
role="assistant",
author_name=name,
)
if stream:
return ResponseStream(_aiter(), finalizer=AgentResponse.from_updates)
return _aresult()
inner = _TextAgent("text-agent", text)
@executor
async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
workflow = WorkflowBuilder(start_executor=start).add_edge(start, inner).build()
return WorkflowAgent(workflow=workflow, name="Text Workflow Agent")
def _build_approval_workflow_agent(
*,
approval_request_id: str,
tool_name: str = "delete_file",
tool_arguments: dict[str, Any] | None = None,
final_text: str = "done",
) -> tuple[WorkflowAgent, _ToolApprovalWorkflowAgentMock]:
"""Build a ``WorkflowAgent`` whose inner agent emits a tool approval request."""
mock_agent = _ToolApprovalWorkflowAgentMock(
name="approval-agent",
tool_name=tool_name,
tool_arguments=tool_arguments or {"path": "/tmp/secret.txt"},
approval_request_ids=[approval_request_id],
final_text=final_text,
)
@executor
async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
workflow = WorkflowBuilder(start_executor=start).add_edge(start, mock_agent).build()
workflow_agent = WorkflowAgent(workflow=workflow, name="Approval Workflow Agent")
return workflow_agent, mock_agent
class TestWorkflowAgentHosting:
"""End-to-end HTTP tests for ``ResponsesHostServer`` hosting a ``WorkflowAgent``.
These tests drive ``_handle_inner_workflow`` through the ASGI stack:
they exercise checkpoint write/restore (multi-turn) and the
tool-approval round-trip path, which is the primary differentiator
relative to the regular agent path.
"""
async def test_basic_text_response(self) -> None:
workflow_agent = _build_text_workflow_agent("hello from workflow")
server = _make_server(workflow_agent)
resp = await _post(server, input_text="hi", stream=False)
assert resp.status_code == 200
body = resp.json()
assert body["status"] == "completed"
text_found = any(
part.get("type") == "output_text" and part.get("text") == "hello from workflow"
for item in body["output"]
if item["type"] == "message"
for part in item.get("content", [])
)
assert text_found, f"Expected workflow output text in {body['output']}"
async def test_basic_text_response_streaming(self) -> None:
workflow_agent = _build_text_workflow_agent("hello stream")
server = _make_server(workflow_agent)
resp = await _post(server, input_text="hi", stream=True)
assert resp.status_code == 200
events = _parse_sse_events(resp.text)
types = _sse_event_types(events)
assert types[0] == "response.created"
assert types[-1] == "response.completed"
assert "response.output_text.delta" in types
text_done = [e for e in events if e["event"] == "response.output_text.done"]
assert any(e["data"]["text"] == "hello stream" for e in text_done)
async def test_non_streaming_emits_mcp_approval_request_and_persists_to_storage(self) -> None:
workflow_agent, mock_agent = _build_approval_workflow_agent(approval_request_id="apr_wf_ns")
server = _make_server(workflow_agent)
resp = await _post(server, stream=False)
assert resp.status_code == 200
body = resp.json()
assert body["status"] == "completed"
approval_items = [it for it in body["output"] if it["type"] == "mcp_approval_request"]
assert len(approval_items) == 1
assert approval_items[0]["name"] == "delete_file"
assert approval_items[0]["server_label"] == "test_server"
approval_request_id = approval_items[0]["id"]
# The id surfaced over the wire is generated by the response stream
# builder; the original approval ``Content`` (carrying the inner
# ``function_call``) must be persisted under that id so the next
# turn can reconstruct it.
loaded = await server._approval_storage.load_approval_request( # pyright: ignore[reportPrivateUsage]
approval_request_id
)
assert loaded.type == "function_approval_request"
assert loaded.function_call.name == "delete_file" # type: ignore[attr-defined]
assert mock_agent.run_count == 1
async def test_streaming_emits_mcp_approval_request_and_persists_to_storage(self) -> None:
workflow_agent, mock_agent = _build_approval_workflow_agent(approval_request_id="apr_wf_st")
server = _make_server(workflow_agent)
resp = await _post(server, stream=True)
assert resp.status_code == 200
events = _parse_sse_events(resp.text)
types = _sse_event_types(events)
assert types[0] == "response.created"
assert types[-1] == "response.completed"
approval_request_id: str | None = None
for e in events:
if e["event"] != "response.output_item.added":
continue
item = e["data"].get("item") or {}
if item.get("type") == "mcp_approval_request":
approval_request_id = item.get("id")
break
assert approval_request_id is not None
loaded = await server._approval_storage.load_approval_request( # pyright: ignore[reportPrivateUsage]
approval_request_id
)
assert loaded.type == "function_approval_request"
assert mock_agent.run_count == 1
async def test_round_trip_approval_response_resumes_workflow_agent(self) -> None:
"""Two-turn HTTP round-trip:
Turn 1 emits ``mcp_approval_request`` and writes a workflow
checkpoint under the response id. Turn 2 sends the
``mcp_approval_response`` with ``previous_response_id`` set, so the
host restores the checkpoint, the WorkflowAgent routes the
approval response back to the paused inner agent, and the inner
agent emits the final assistant text.
"""
workflow_agent, mock_agent = _build_approval_workflow_agent(
approval_request_id="apr_wf_rt",
final_text="done with approval",
)
server = _make_server(workflow_agent)
first = await _post(server, stream=False)
assert first.status_code == 200
first_body = first.json()
first_response_id = first_body["id"]
approval_items = [it for it in first_body["output"] if it["type"] == "mcp_approval_request"]
assert len(approval_items) == 1
approval_request_id = approval_items[0]["id"]
assert mock_agent.run_count == 1
second_payload: dict[str, Any] = {
"model": "test-model",
"input": [
{
"type": "mcp_approval_response",
"approval_request_id": approval_request_id,
"approve": True,
}
],
"stream": False,
"previous_response_id": first_response_id,
}
second = await _post_json(server, second_payload)
assert second.status_code == 200
second_body = second.json()
assert second_body["status"] == "completed"
# The inner agent must have been resumed (restore replay + new turn).
# Restore call is a no-op for the mock (no input); the new-turn call
# delivers the approval response, so run_count grows by at least 1.
assert mock_agent.run_count >= 2
# The final assistant text from the resumed inner agent surfaces in
# the HTTP output.
text_pieces = [
part.get("text", "")
for item in second_body["output"]
if item["type"] == "message"
for part in item.get("content", [])
if part.get("type") == "output_text"
]
assert any("done with approval" in t for t in text_pieces), (
f"expected resumed workflow output, got {second_body['output']}"
)
# The new-turn invocation of the inner agent must have received the
# approval response routed back through WorkflowAgent.
approval_responses = [
c for m in mock_agent.last_run_messages for c in m.contents if c.type == "function_approval_response"
]
assert len(approval_responses) == 1
assert approval_responses[0].approved is True # type: ignore[attr-defined]
async def test_round_trip_approval_response_streaming(self) -> None:
"""Streaming variant of the round-trip: turn 2 is requested with
``stream=true`` and surfaces the resumed text as SSE events."""
workflow_agent, mock_agent = _build_approval_workflow_agent(
approval_request_id="apr_wf_rt_st",
final_text="streamed-done",
)
server = _make_server(workflow_agent)
first = await _post(server, stream=False)
first_body = first.json()
first_response_id = first_body["id"]
approval_request_id = next(it["id"] for it in first_body["output"] if it["type"] == "mcp_approval_request")
second = await _post_json(
server,
{
"model": "test-model",
"input": [
{
"type": "mcp_approval_response",
"approval_request_id": approval_request_id,
"approve": True,
}
],
"stream": True,
"previous_response_id": first_response_id,
},
)
assert second.status_code == 200
events = _parse_sse_events(second.text)
types = _sse_event_types(events)
assert types[0] == "response.created"
assert types[-1] == "response.completed"
text_done = [e for e in events if e["event"] == "response.output_text.done"]
assert any("streamed-done" in e["data"]["text"] for e in text_done)
assert mock_agent.run_count >= 2
async def test_round_trip_approval_response_rejected(self) -> None:
"""Sending ``approve=False`` must surface as ``approved=False`` to the
inner agent on resume."""
workflow_agent, mock_agent = _build_approval_workflow_agent(
approval_request_id="apr_wf_reject",
final_text="acknowledged",
)
server = _make_server(workflow_agent)
first = await _post(server, stream=False)
first_body = first.json()
first_response_id = first_body["id"]
approval_request_id = next(it["id"] for it in first_body["output"] if it["type"] == "mcp_approval_request")
second = await _post_json(
server,
{
"model": "test-model",
"input": [
{
"type": "mcp_approval_response",
"approval_request_id": approval_request_id,
"approve": False,
}
],
"stream": False,
"previous_response_id": first_response_id,
},
)
assert second.status_code == 200
approval_responses = [
c for m in mock_agent.last_run_messages for c in m.contents if c.type == "function_approval_response"
]
assert len(approval_responses) == 1
assert approval_responses[0].approved is False # type: ignore[attr-defined]
# endregion
@@ -134,15 +134,11 @@ def handle_response_and_requests(response: AgentResponse) -> dict[str, HandoffAg
if message.text:
print(f"- {message.author_name or message.role}: {message.text}")
for content in message.contents:
if content.type == "function_call":
if isinstance(content.arguments, dict):
request = WorkflowAgent.RequestInfoFunctionArgs.from_dict(content.arguments)
elif isinstance(content.arguments, str):
request = WorkflowAgent.RequestInfoFunctionArgs.from_json(content.arguments)
else:
raise ValueError("Invalid arguments type. Expecting a request info structure for this sample.")
if isinstance(request.data, HandoffAgentUserRequest):
pending_requests[request.request_id] = request.data
if content.type == "function_call" and content.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME:
request_function_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(content.arguments) # type: ignore
request_id = request_function_args.request_id
request_event = request_function_args.request_event
pending_requests[request_id] = request_event.data
return pending_requests
@@ -3,10 +3,8 @@
import asyncio
import os
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from agent_framework.foundry import FoundryChatClient
from azure.identity import AzureCliCredential
@@ -141,28 +139,14 @@ async def main() -> None:
# Handle the human review if required.
if human_review_function_call:
# Parse the human review request arguments.
human_request_args = human_review_function_call.arguments
if isinstance(human_request_args, str):
request: WorkflowAgent.RequestInfoFunctionArgs = WorkflowAgent.RequestInfoFunctionArgs.from_json(
human_request_args
)
elif isinstance(human_request_args, Mapping):
request = WorkflowAgent.RequestInfoFunctionArgs.from_dict(dict(human_request_args))
else:
raise TypeError("Unexpected argument type for human review function call.")
request_payload: Any = request.data
human_request_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(human_review_function_call.arguments) # type: ignore
request_payload = human_request_args.request_event.data
if not isinstance(request_payload, HumanReviewRequest):
raise ValueError("Human review request payload must be a HumanReviewRequest.")
agent_request = request_payload.agent_request
if agent_request is None:
raise ValueError("Human review request must include agent_request.")
request_id = agent_request.request_id
if not request_payload.agent_request:
raise ValueError("Human review request must contain an agent_request.")
# Mock a human response approval for demonstration purposes.
human_response = ReviewResponse(request_id=request_id, feedback="", approved=True)
human_response = ReviewResponse(request_id=request_payload.agent_request.request_id, feedback="", approved=True)
# Create the function call result object to send back to the agent.
human_review_function_result = Content(
"function_result",
@@ -28,6 +28,7 @@ import zipfile
from pathlib import Path
from azure.ai.projects.aio import AIProjectClient
from azure.ai.projects.models import CreateSkillVersionFromFilesBody
from azure.core.exceptions import ResourceNotFoundError
from azure.identity.aio import DefaultAzureCredential
from dotenv import load_dotenv
@@ -68,8 +69,13 @@ async def main() -> None:
name = skill_md.parent.name
print(f"Provisioning skill '{name}' from {skill_md.relative_to(SKILLS_DIR.parent)}...")
await _delete_skill_if_exists(project, name)
imported = await project.beta.skills.create_from_package(_zip_skill_md(skill_md))
print(f" Imported skill '{imported.name}' (id={imported.skill_id}, has_blob={imported.has_blob}).")
imported = await project.beta.skills.create_from_files(
name,
content=CreateSkillVersionFromFilesBody(
files=[(f"{name}.zip", _zip_skill_md(skill_md), "application/zip")]
),
)
print(f" Imported skill '{imported.name}' (id={imported.skill_id}, version={imported.version}).")
print("Verifying skills via project.beta.skills.list()...")
listed = {skill.name: skill async for skill in project.beta.skills.list()}
@@ -79,8 +85,8 @@ async def main() -> None:
if skill is None:
raise RuntimeError(f"Skill '{name}' was imported but is not present in the project listing.")
print(
f" OK '{skill.name}': id={skill.skill_id}, "
f"description={skill.description!r}, has_blob={skill.has_blob}"
f" OK '{skill.name}': id={skill.id}, "
f"description={skill.description!r}, default_version={skill.default_version}"
)
print("Done.")