mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Refactor workflow as agent pending request handling (#6259)
* WIP: Refactor Workflow as agent pending request handling * WIP: debugging empty message bug * Working: Workflow as agent with function approval * Address Copilot comments * Fix mypy * Address comments and fix pipeline * Request info non function approval now becomes function call * Revert uv.lock * Fix mypy * Bump min version of azure-ai-project * Remove RequestInfoFunctionArgs * fix tests * Fix failing tests * Fix sample
This commit is contained in:
committed by
GitHub
Unverified
parent
d5335fbeae
commit
9cafd7e58b
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
@@ -12,7 +11,6 @@ from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
|
||||
|
||||
from .._agents import BaseAgent
|
||||
from .._serialization import make_json_safe
|
||||
from .._sessions import (
|
||||
AgentSession,
|
||||
ContextProvider,
|
||||
@@ -30,11 +28,12 @@ from .._types import (
|
||||
UsageDetails,
|
||||
add_usage_details,
|
||||
)
|
||||
from ..exceptions import AgentInvalidRequestException, AgentInvalidResponseException
|
||||
from ..exceptions import AgentException, AgentInvalidRequestException, AgentInvalidResponseException
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._events import (
|
||||
AGENT_FORWARDED_EVENT_TYPES,
|
||||
WorkflowEvent,
|
||||
WorkflowRunState,
|
||||
)
|
||||
from ._message_utils import normalize_messages_input
|
||||
from ._typing_utils import is_instance_of, is_type_compatible
|
||||
@@ -59,27 +58,24 @@ class WorkflowAgent(BaseAgent):
|
||||
@dataclass
|
||||
class RequestInfoFunctionArgs:
|
||||
request_id: str
|
||||
data: Any
|
||||
request_event: WorkflowEvent
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {"request_id": self.request_id, "data": make_json_safe(self.data)}
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(self.to_dict())
|
||||
return {"request_id": self.request_id, "request_event": self.request_event.to_dict()}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: dict[str, Any]) -> WorkflowAgent.RequestInfoFunctionArgs:
|
||||
return cls(request_id=payload.get("request_id", ""), data=payload.get("data"))
|
||||
if "request_id" not in payload or "request_event" not in payload:
|
||||
raise ValueError(
|
||||
"Invalid payload for RequestInfoFunctionArgs. 'request_id' and 'request_event' are required."
|
||||
)
|
||||
if not payload["request_id"]:
|
||||
raise ValueError("request_id cannot be empty.")
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> WorkflowAgent.RequestInfoFunctionArgs:
|
||||
try:
|
||||
parsed: Any = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(f"RequestInfoFunctionArgs JSON payload is malformed: {exc}") from exc
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("RequestInfoFunctionArgs JSON payload must decode to a mapping")
|
||||
return cls.from_dict(cast(dict[str, Any], parsed))
|
||||
return cls(
|
||||
request_id=payload.get("request_id", ""),
|
||||
request_event=WorkflowEvent.from_dict(payload.get("request_event", {})),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -129,16 +125,11 @@ class WorkflowAgent(BaseAgent):
|
||||
**kwargs,
|
||||
)
|
||||
self._workflow: Workflow = workflow
|
||||
self._pending_requests: dict[str, WorkflowEvent[Any]] = {}
|
||||
|
||||
@property
|
||||
def workflow(self) -> Workflow:
|
||||
return self._workflow
|
||||
|
||||
@property
|
||||
def pending_requests(self) -> dict[str, WorkflowEvent[Any]]:
|
||||
return self._pending_requests
|
||||
|
||||
# region Run Methods
|
||||
|
||||
@overload
|
||||
@@ -182,7 +173,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the workflow. Required for new runs,
|
||||
should be None when resuming from checkpoint.
|
||||
but may be None when only restoring the underlying workflow from a checkpoint.
|
||||
|
||||
Keyword Args:
|
||||
stream: If True, returns an async iterable of updates. If False (default),
|
||||
@@ -416,101 +407,79 @@ class WorkflowAgent(BaseAgent):
|
||||
Yields:
|
||||
WorkflowEvent objects from the workflow execution.
|
||||
"""
|
||||
# Determine the execution mode based on state.
|
||||
# The streaming flag controls the workflow's internal streaming mode,
|
||||
# which affects executor behavior (e.g. AgentExecutor emits different event
|
||||
# types in streaming vs non-streaming mode).
|
||||
if bool(self.pending_requests):
|
||||
function_responses = self._process_pending_requests(input_messages)
|
||||
# Restore the workflow state if a checkpoint is provided
|
||||
if checkpoint_id is not None:
|
||||
if checkpoint_storage is None:
|
||||
raise AgentInvalidRequestException("checkpoint_storage must be provided when checkpoint_id is provided")
|
||||
logger.debug(f"Restoring workflow from checkpoint {checkpoint_id}")
|
||||
# Restore the workflow from checkpoint
|
||||
if streaming:
|
||||
async for event in self.workflow.run(
|
||||
responses=function_responses,
|
||||
stream=True,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
else:
|
||||
for event in await self.workflow.run(
|
||||
responses=function_responses,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
|
||||
elif checkpoint_id is not None:
|
||||
# Restore the prior workflow state from the checkpoint. Shared
|
||||
# state (e.g. accumulated conversation history maintained by the
|
||||
# workflow's executors) survives across turns because Workflow.run
|
||||
# no longer wipes state per call. Callers who want to deliver a
|
||||
# new user message after restore should make a second
|
||||
# `workflow.run(message=...)` call - they are NOT mutually
|
||||
# exclusive on the same instance, but each must be its own call.
|
||||
if streaming:
|
||||
async for event in self.workflow.run(
|
||||
async for _ in self.workflow.run(
|
||||
stream=True,
|
||||
checkpoint_id=checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
):
|
||||
pass
|
||||
else:
|
||||
_ = await self.workflow.run(
|
||||
checkpoint_id=checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
)
|
||||
if not input_messages:
|
||||
logger.info("No input messages provided; the workflow has been restored to the checkpoint state.")
|
||||
return
|
||||
|
||||
final_state = self._workflow.status
|
||||
logger.debug(f"Workflow state: {final_state}")
|
||||
|
||||
if final_state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS:
|
||||
# Extract function responses from input messages, and ensure that
|
||||
# only function responses are present in messages if there is any
|
||||
# pending request.
|
||||
# NOTE: It is possible that some pending requests are not fulfilled,
|
||||
# and we will let the workflow handle this -- the agent does not
|
||||
# have an opinion on this.
|
||||
function_responses = self._extract_function_responses(input_messages)
|
||||
if streaming:
|
||||
async for event in self.workflow.run(
|
||||
responses=function_responses,
|
||||
stream=True,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
else:
|
||||
for event in await self.workflow.run(
|
||||
checkpoint_id=checkpoint_id,
|
||||
responses=function_responses,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
elif final_state == WorkflowRunState.IDLE:
|
||||
if streaming:
|
||||
async for event in self.workflow.run(
|
||||
message=input_messages,
|
||||
stream=True,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
else:
|
||||
for event in await self.workflow.run(
|
||||
message=input_messages,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
|
||||
else:
|
||||
if streaming:
|
||||
async for event in self.workflow.run(
|
||||
message=input_messages,
|
||||
stream=True,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
else:
|
||||
for event in await self.workflow.run(
|
||||
message=input_messages,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
raise AgentException(f"The underlying workflow is in an invalid state to restart: {final_state}.")
|
||||
|
||||
# endregion Run Methods
|
||||
|
||||
def _process_pending_requests(self, input_messages: Sequence[Message]) -> dict[str, Any]:
|
||||
"""Process pending requests by extracting function responses and updating state.
|
||||
|
||||
Args:
|
||||
input_messages: Input messages that may contain function responses.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping request IDs to their response data.
|
||||
"""
|
||||
logger.info(f"Continuing workflow to address {len(self.pending_requests)} requests")
|
||||
|
||||
# Extract function responses from input messages, and ensure that
|
||||
# only function responses are present in messages if there is any
|
||||
# pending request.
|
||||
function_responses = self._extract_function_responses(input_messages)
|
||||
|
||||
# Pop pending requests if fulfilled.
|
||||
for request_id in list(self.pending_requests.keys()):
|
||||
if request_id in function_responses:
|
||||
self.pending_requests.pop(request_id)
|
||||
|
||||
# NOTE: It is possible that some pending requests are not fulfilled,
|
||||
# and we will let the workflow handle this -- the agent does not
|
||||
# have an opinion on this.
|
||||
return function_responses
|
||||
|
||||
def _convert_workflow_events_to_agent_response(
|
||||
self,
|
||||
response_id: str,
|
||||
@@ -528,10 +497,10 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
for output_event in output_events:
|
||||
if output_event.type == "request_info":
|
||||
function_call, approval_request = self._process_request_info_event(output_event)
|
||||
request_content = self._process_request_info_event(output_event)
|
||||
messages.append(
|
||||
Message(
|
||||
contents=[function_call, approval_request],
|
||||
contents=[request_content],
|
||||
role="assistant",
|
||||
author_name=output_event.source_executor_id,
|
||||
message_id=str(uuid.uuid4()),
|
||||
@@ -598,38 +567,6 @@ class WorkflowAgent(BaseAgent):
|
||||
raw_representation=raw_representations,
|
||||
)
|
||||
|
||||
def _process_request_info_event(
|
||||
self,
|
||||
event: WorkflowEvent[Any],
|
||||
) -> tuple[Content, Content]:
|
||||
"""Convert a request_info event to FunctionCallContent and FunctionApprovalRequestContent.
|
||||
|
||||
Args:
|
||||
event: A WorkflowEvent with type='request_info'.
|
||||
|
||||
Returns:
|
||||
A tuple of (FunctionCallContent, FunctionApprovalRequestContent).
|
||||
"""
|
||||
request_id = event.request_id
|
||||
if not request_id:
|
||||
raise ValueError("request_info event must have a request_id")
|
||||
|
||||
self.pending_requests[request_id] = event
|
||||
|
||||
args = self.RequestInfoFunctionArgs(request_id=request_id, data=event.data).to_dict()
|
||||
|
||||
function_call = Content.from_function_call(
|
||||
call_id=request_id,
|
||||
name=self.REQUEST_INFO_FUNCTION_NAME,
|
||||
arguments=args,
|
||||
)
|
||||
approval_request = Content.from_function_approval_request(
|
||||
id=request_id,
|
||||
function_call=function_call,
|
||||
additional_properties={"request_id": request_id},
|
||||
)
|
||||
return function_call, approval_request
|
||||
|
||||
def _convert_workflow_event_to_agent_response_updates(
|
||||
self,
|
||||
response_id: str,
|
||||
@@ -731,85 +668,72 @@ class WorkflowAgent(BaseAgent):
|
||||
]
|
||||
|
||||
if event.type == "request_info":
|
||||
# Store the pending request for later correlation
|
||||
request_id = event.request_id
|
||||
if not request_id:
|
||||
raise ValueError("request_info event must have a request_id")
|
||||
|
||||
self.pending_requests[request_id] = event
|
||||
|
||||
args = self.RequestInfoFunctionArgs(request_id=request_id, data=event.data).to_dict()
|
||||
|
||||
function_call = Content.from_function_call(
|
||||
call_id=request_id,
|
||||
name=self.REQUEST_INFO_FUNCTION_NAME,
|
||||
arguments=args,
|
||||
)
|
||||
approval_request = Content.from_function_approval_request(
|
||||
id=request_id,
|
||||
function_call=function_call,
|
||||
additional_properties={"request_id": request_id},
|
||||
)
|
||||
request_content = self._process_request_info_event(event)
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=[function_call, approval_request],
|
||||
contents=[request_content],
|
||||
role="assistant",
|
||||
author_name=self.name,
|
||||
response_id=response_id,
|
||||
message_id=str(uuid.uuid4()),
|
||||
created_at=datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
|
||||
raw_representation=event,
|
||||
)
|
||||
]
|
||||
|
||||
# Ignore workflow-internal events
|
||||
return []
|
||||
|
||||
def _process_request_info_event(
|
||||
self,
|
||||
event: WorkflowEvent[Any],
|
||||
) -> Content:
|
||||
"""Convert a request_info event to a function approval request or function call content.
|
||||
|
||||
Args:
|
||||
event: A WorkflowEvent with type='request_info'.
|
||||
|
||||
Returns:
|
||||
A content object representing the request info. The content can be a `function_approval_request`
|
||||
or a `function_call` depending on the structure of the event data.
|
||||
|
||||
Note:
|
||||
If the event data is already a FunctionApprovalRequestContent, it will be returned as-is.
|
||||
"""
|
||||
if isinstance(event.data, Content) and event.data.user_input_request:
|
||||
# Return the event data as-is if it's already a properly formed FunctionApprovalRequestContent
|
||||
return event.data
|
||||
|
||||
request_id = event.request_id
|
||||
args = self.RequestInfoFunctionArgs(request_id=request_id, request_event=event).to_dict()
|
||||
|
||||
return Content.from_function_call(
|
||||
call_id=request_id,
|
||||
name=self.REQUEST_INFO_FUNCTION_NAME,
|
||||
arguments=args,
|
||||
)
|
||||
|
||||
def _extract_function_responses(self, input_messages: Sequence[Message]) -> dict[str, Any]:
|
||||
"""Extract function responses from input messages."""
|
||||
"""Extract function responses from input messages.
|
||||
|
||||
The responses are for pending requests that the workflow is waiting on, and
|
||||
will be passed to the workflow. The pending requests are processed to either
|
||||
`function_approval_request` or `function_call` content by `_process_request_info_event`.
|
||||
"""
|
||||
function_responses: dict[str, Any] = {}
|
||||
for message in input_messages:
|
||||
for content in message.contents:
|
||||
if content.type == "function_approval_response":
|
||||
# Parse the function arguments to recover request payload
|
||||
arguments_payload = content.function_call.arguments # type: ignore[attr-defined, union-attr]
|
||||
if isinstance(arguments_payload, str):
|
||||
try:
|
||||
parsed_args = self.RequestInfoFunctionArgs.from_json(arguments_payload)
|
||||
except ValueError as exc:
|
||||
raise AgentInvalidResponseException(
|
||||
"FunctionApprovalResponseContent arguments must decode to a mapping."
|
||||
) from exc
|
||||
elif isinstance(arguments_payload, dict):
|
||||
parsed_args = self.RequestInfoFunctionArgs.from_dict(arguments_payload)
|
||||
else:
|
||||
raise AgentInvalidResponseException(
|
||||
"FunctionApprovalResponseContent arguments must be a mapping or JSON string."
|
||||
)
|
||||
|
||||
request_id = parsed_args.request_id or content.id # type: ignore[attr-defined]
|
||||
if not content.approved: # type: ignore[attr-defined]
|
||||
raise AgentInvalidResponseException(f"Request '{request_id}' was not approved by the caller.")
|
||||
|
||||
if request_id in self.pending_requests:
|
||||
function_responses[request_id] = parsed_args.data
|
||||
elif bool(self.pending_requests):
|
||||
raise AgentInvalidRequestException(
|
||||
"Only responses for pending requests are allowed when there are outstanding approvals."
|
||||
)
|
||||
request_id: str = content.id # type: ignore[assignment]
|
||||
function_responses[request_id] = content
|
||||
elif content.type == "function_result":
|
||||
request_id = content.call_id # type: ignore[attr-defined]
|
||||
if request_id in self.pending_requests:
|
||||
response_data = content.result if hasattr(content, "result") else str(content) # type: ignore[attr-defined]
|
||||
function_responses[request_id] = response_data
|
||||
elif bool(self.pending_requests):
|
||||
raise AgentInvalidRequestException(
|
||||
"Only function responses for pending requests are allowed while requests are outstanding."
|
||||
)
|
||||
response_data = content.result if hasattr(content, "result") else str(content) # type: ignore[attr-defined]
|
||||
function_responses[content.call_id] = response_data # type: ignore
|
||||
else:
|
||||
if bool(self.pending_requests):
|
||||
raise AgentInvalidResponseException(
|
||||
"Unexpected content type while awaiting request info responses."
|
||||
)
|
||||
raise AgentInvalidResponseException(
|
||||
"Unexpected content type while awaiting request info responses."
|
||||
)
|
||||
|
||||
return function_responses
|
||||
|
||||
def _extract_contents(self, data: Any) -> list[Content]:
|
||||
|
||||
@@ -429,15 +429,30 @@ class AgentExecutor(Executor):
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
await ctx.yield_output(response)
|
||||
|
||||
# Handle any user input requests
|
||||
if response.user_input_requests:
|
||||
user_input_request_count = len(response.user_input_requests)
|
||||
total_message_content_count = sum(len(msg.contents) for msg in response.messages)
|
||||
if user_input_request_count != total_message_content_count:
|
||||
logger.warning(
|
||||
"Response %s contains %d user input requests but total message contents are %d. "
|
||||
"This indicates the response contains both user input requests and message contents. "
|
||||
"Double check if this is the intended behavior, as non user input request contents in "
|
||||
"this response will not be emitted.",
|
||||
response.response_id,
|
||||
user_input_request_count,
|
||||
total_message_content_count,
|
||||
)
|
||||
for user_input_request in response.user_input_requests:
|
||||
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
|
||||
await ctx.request_info(user_input_request, Content)
|
||||
await ctx.request_info(user_input_request, Content, request_id=user_input_request.id)
|
||||
return None
|
||||
|
||||
# Only yield output if the response is complete and not waiting for user input.
|
||||
# This is to avoid emitting two events of different types ('output' and 'request_info')
|
||||
# that carry the same payload.
|
||||
await ctx.yield_output(response)
|
||||
return response
|
||||
|
||||
async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUpdate]) -> AgentResponse | None:
|
||||
@@ -472,9 +487,25 @@ class AgentExecutor(Executor):
|
||||
)
|
||||
async for update in stream:
|
||||
updates.append(update)
|
||||
await ctx.yield_output(update)
|
||||
if update.user_input_requests:
|
||||
user_input_request_count = len(update.user_input_requests)
|
||||
total_message_content_count = len(update.contents)
|
||||
if user_input_request_count != total_message_content_count:
|
||||
logger.warning(
|
||||
"Response update %s contains %d user input requests but total message contents are %d. "
|
||||
"This indicates the response update contains both user input requests and message contents. "
|
||||
"Double check if this is the intended behavior, as non user input request contents will "
|
||||
"not be emitted.",
|
||||
update.response_id,
|
||||
user_input_request_count,
|
||||
total_message_content_count,
|
||||
)
|
||||
streamed_user_input_requests.extend(update.user_input_requests)
|
||||
else:
|
||||
# Only yield output events for updates that do not contain user input requests.
|
||||
# This is to avoid emitting two events of different types ('output' and 'request_info')
|
||||
# that carry the same payload.
|
||||
await ctx.yield_output(update)
|
||||
|
||||
# Prefer stream finalization when available so result hooks run
|
||||
# (e.g., thread conversation updates). Fall back to reconstructing from updates
|
||||
@@ -509,7 +540,7 @@ class AgentExecutor(Executor):
|
||||
if user_input_requests:
|
||||
for user_input_request in user_input_requests:
|
||||
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
|
||||
await ctx.request_info(user_input_request, Content)
|
||||
await ctx.request_info(user_input_request, Content, request_id=user_input_request.id)
|
||||
return None
|
||||
|
||||
return response
|
||||
|
||||
@@ -360,6 +360,22 @@ class Workflow(DictConvertible):
|
||||
# Flag to prevent concurrent workflow executions
|
||||
self._is_running = False
|
||||
|
||||
# Current run-level status of this workflow instance. Updated in lockstep with
|
||||
# the status events emitted from `_run_workflow_with_tracing`. Defaults to IDLE
|
||||
# for a freshly built workflow that has not yet been run.
|
||||
self._status: WorkflowRunState = WorkflowRunState.IDLE
|
||||
|
||||
@property
|
||||
def status(self) -> WorkflowRunState:
|
||||
"""Return the current run-level status of this workflow instance.
|
||||
|
||||
Mirrors the most recent status event emitted by the workflow. Safe to read at
|
||||
any time: workflows run on a single asyncio event loop, and the underlying
|
||||
attribute is a single enum reference whose assignment is atomic under the
|
||||
CPython GIL, so no locking is required.
|
||||
"""
|
||||
return self._status
|
||||
|
||||
def _ensure_not_running(self) -> None:
|
||||
"""Ensure the workflow is not already running."""
|
||||
if self._is_running:
|
||||
@@ -513,8 +529,9 @@ class Workflow(DictConvertible):
|
||||
with _framework_event_origin():
|
||||
started = WorkflowEvent.started()
|
||||
yield started # noqa: RUF070
|
||||
self._status = WorkflowRunState.IN_PROGRESS
|
||||
with _framework_event_origin():
|
||||
in_progress = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS)
|
||||
in_progress = WorkflowEvent.status(self._status)
|
||||
yield in_progress # noqa: RUF070
|
||||
|
||||
# Per-run reset for fresh-message runs only. We deliberately
|
||||
@@ -569,17 +586,20 @@ class Workflow(DictConvertible):
|
||||
|
||||
if event.type == "request_info" and not emitted_in_progress_pending:
|
||||
emitted_in_progress_pending = True
|
||||
self._status = WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS
|
||||
with _framework_event_origin():
|
||||
pending_status = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS)
|
||||
pending_status = WorkflowEvent.status(self._status)
|
||||
yield pending_status # noqa: RUF070
|
||||
# Workflow runs until idle - emit final status based on whether requests are pending
|
||||
if saw_request:
|
||||
self._status = WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
|
||||
with _framework_event_origin():
|
||||
terminal_status = WorkflowEvent.status(WorkflowRunState.IDLE_WITH_PENDING_REQUESTS)
|
||||
terminal_status = WorkflowEvent.status(self._status)
|
||||
yield terminal_status
|
||||
else:
|
||||
self._status = WorkflowRunState.IDLE
|
||||
with _framework_event_origin():
|
||||
terminal_status = WorkflowEvent.status(WorkflowRunState.IDLE)
|
||||
terminal_status = WorkflowEvent.status(self._status)
|
||||
yield terminal_status
|
||||
|
||||
span.add_event(OtelAttr.WORKFLOW_COMPLETED)
|
||||
@@ -593,6 +613,7 @@ class Workflow(DictConvertible):
|
||||
with _framework_event_origin():
|
||||
failed_event = WorkflowEvent.failed(details)
|
||||
yield failed_event # noqa: RUF070
|
||||
self._status = WorkflowRunState.FAILED
|
||||
with _framework_event_origin():
|
||||
failed_status = WorkflowEvent.status(WorkflowRunState.FAILED)
|
||||
yield failed_status # noqa: RUF070
|
||||
|
||||
@@ -25,6 +25,7 @@ from agent_framework import (
|
||||
prepend_agent_framework_to_user_agent,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._serialization import make_json_safe
|
||||
from agent_framework.observability import (
|
||||
ROLE_EVENT_MAP,
|
||||
AgentTelemetryLayer,
|
||||
@@ -3195,17 +3196,15 @@ def test_capture_messages_with_prepared_request_info_function_call_arguments(spa
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
from agent_framework import WorkflowAgent
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HandoffRequest:
|
||||
target_agent: str
|
||||
reason: str
|
||||
|
||||
arguments = WorkflowAgent.RequestInfoFunctionArgs(
|
||||
request_id="call_dc",
|
||||
data=HandoffRequest(target_agent="helper", reason="overflow"),
|
||||
).to_dict()
|
||||
arguments = {
|
||||
"request_id": "call_dc",
|
||||
"data": make_json_safe(HandoffRequest(target_agent="helper", reason="overflow")),
|
||||
}
|
||||
msg = Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
|
||||
@@ -699,3 +699,171 @@ async def test_resolve_executor_kwargs_empty_per_executor_does_not_fallback_to_g
|
||||
resolved = {"exec_a": {}, GLOBAL_KWARGS_KEY: {"global_key": "global_val"}}
|
||||
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
|
||||
assert result == {}
|
||||
|
||||
|
||||
# region Tool approval emission
|
||||
|
||||
|
||||
class _ApprovalEmittingAgent(BaseAgent):
|
||||
"""Agent that returns a single ``function_approval_request`` Content.
|
||||
|
||||
Used to verify that ``AgentExecutor`` does *not* surface the approval
|
||||
payload via both an ``output`` event and a ``request_info`` event in the
|
||||
same superstep — only the ``request_info`` event must carry it.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
approval_request_id: str = "apr_1",
|
||||
tool_name: str = "delete_file",
|
||||
tool_arguments: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._approval_request_id = approval_request_id
|
||||
self._tool_name = tool_name
|
||||
self._tool_arguments: dict[str, Any] = tool_arguments or {"path": "/tmp/secret.txt"}
|
||||
self.run_count = 0
|
||||
|
||||
def _build_approval_content(self) -> Content:
|
||||
function_call = Content.from_function_call(
|
||||
call_id=self._approval_request_id,
|
||||
name=self._tool_name,
|
||||
arguments=self._tool_arguments,
|
||||
)
|
||||
return Content.from_function_approval_request(id=self._approval_request_id, function_call=function_call)
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
self.run_count += 1
|
||||
approval = self._build_approval_content()
|
||||
|
||||
if stream:
|
||||
|
||||
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
|
||||
yield AgentResponseUpdate(contents=[approval], role="assistant")
|
||||
|
||||
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
|
||||
|
||||
async def _run() -> AgentResponse:
|
||||
return AgentResponse(messages=[Message("assistant", [approval])])
|
||||
|
||||
return _run()
|
||||
|
||||
|
||||
def _has_approval_payload(event: WorkflowEvent[Any]) -> bool:
|
||||
"""Return True if the event's data carries a ``function_approval_request`` content."""
|
||||
data: Any = event.data
|
||||
|
||||
def _contents_of(value: Any) -> list[Content]:
|
||||
if isinstance(value, AgentResponseUpdate):
|
||||
return list(value.contents)
|
||||
if isinstance(value, AgentResponse):
|
||||
return [c for m in value.messages for c in m.contents]
|
||||
if isinstance(value, AgentExecutorResponse):
|
||||
return [c for m in value.agent_response.messages for c in m.contents]
|
||||
if isinstance(value, Message):
|
||||
return list(value.contents)
|
||||
if isinstance(value, Content):
|
||||
return [value]
|
||||
return []
|
||||
|
||||
return any(c.type == "function_approval_request" for c in _contents_of(data))
|
||||
|
||||
|
||||
async def test_agent_executor_does_not_double_emit_approval_non_streaming() -> None:
|
||||
"""Non-streaming: approval payload must only appear in the ``request_info`` event.
|
||||
|
||||
Regression test for the bug where ``AgentExecutor._run_agent`` first
|
||||
``yield_output``-ed the response (carrying the approval Content) and then
|
||||
additionally emitted a ``request_info`` event for the same payload.
|
||||
"""
|
||||
agent = _ApprovalEmittingAgent(id="approve_agent", name="ApproveAgent", approval_request_id="apr_ns_1")
|
||||
executor = AgentExecutor(agent, id="approve_exec")
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
request_info_events: list[WorkflowEvent[Any]] = []
|
||||
output_events: list[WorkflowEvent[Any]] = []
|
||||
|
||||
for event in await workflow.run("please delete it"):
|
||||
if event.type == "request_info":
|
||||
request_info_events.append(event)
|
||||
elif event.type == "output":
|
||||
output_events.append(event)
|
||||
|
||||
assert len(request_info_events) == 1
|
||||
assert _has_approval_payload(request_info_events[0])
|
||||
# The approval payload must not also be surfaced as a workflow output.
|
||||
assert not any(_has_approval_payload(e) for e in output_events)
|
||||
assert agent.run_count == 1
|
||||
|
||||
|
||||
async def test_agent_executor_does_not_double_emit_approval_streaming() -> None:
|
||||
"""Streaming: per-update approval payload must not be ``yield_output``-ed."""
|
||||
agent = _ApprovalEmittingAgent(id="approve_agent_s", name="ApproveAgentS", approval_request_id="apr_st_1")
|
||||
executor = AgentExecutor(agent, id="approve_exec_s")
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
request_info_events: list[WorkflowEvent[Any]] = []
|
||||
output_events: list[WorkflowEvent[Any]] = []
|
||||
|
||||
async for event in workflow.run("please delete it", stream=True):
|
||||
if event.type == "request_info":
|
||||
request_info_events.append(event)
|
||||
elif event.type == "output":
|
||||
output_events.append(event)
|
||||
|
||||
assert len(request_info_events) == 1
|
||||
assert _has_approval_payload(request_info_events[0])
|
||||
assert not any(_has_approval_payload(e) for e in output_events)
|
||||
assert agent.run_count == 1
|
||||
|
||||
|
||||
async def test_agent_executor_request_info_uses_user_input_request_id() -> None:
    """The emitted ``request_info`` event must reuse the agent's approval id.

    The mock agent is configured with ``approval_request_id="apr_match"``; the
    single ``request_info`` event produced by the streamed run is expected to
    expose that same value as its ``request_id``, so that a
    ``function_approval_response`` echoing the id can be matched to the
    pending request.
    """
    agent = _ApprovalEmittingAgent(id="approve_agent_id", name="ApproveAgentId", approval_request_id="apr_match")
    workflow = WorkflowBuilder(start_executor=AgentExecutor(agent, id="approve_exec_id")).build()

    # Drain the whole stream, keeping only the request_info events.
    pending_requests: list[WorkflowEvent[Any]] = [
        evt async for evt in workflow.run("please delete it", stream=True) if evt.type == "request_info"
    ]

    assert len(pending_requests) == 1
    assert pending_requests[0].request_id == "apr_match"
|
||||
|
||||
|
||||
# endregion Tool approval emission
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from dataclasses import dataclass
|
||||
@@ -30,6 +29,20 @@ from agent_framework import (
|
||||
handler,
|
||||
response_handler,
|
||||
)
|
||||
from agent_framework._workflows._typing_utils import deserialize_type
|
||||
|
||||
|
||||
@dataclass
class HandoffRequest:
    """Request payload shared by the ``request_info`` tests in this module.

    It lives at module scope rather than inside a test method so that
    ``serialize_type``/``deserialize_type`` can round-trip the request type via
    the importable qualified name ``tests.workflow.test_workflow_agent.HandoffRequest``.
    """

    # Name of the agent the request hands off to.
    target_agent: str
    # Free-form explanation attached to the handoff.
    reason: str
|
||||
|
||||
|
||||
class SimpleExecutor(Executor):
|
||||
@@ -240,52 +253,45 @@ class TestWorkflowAgent:
|
||||
# Should have received an approval request for the request info
|
||||
assert len(updates) > 0
|
||||
|
||||
approval_update: AgentResponseUpdate | None = None
|
||||
request_update: AgentResponseUpdate | None = None
|
||||
for update in updates:
|
||||
if any(content.type == "function_approval_request" for content in update.contents):
|
||||
approval_update = update
|
||||
if any(content.type == "function_call" for content in update.contents):
|
||||
request_update = update
|
||||
break
|
||||
|
||||
assert approval_update is not None, "Should have received a request_info approval request"
|
||||
assert request_update is not None, "Should have received a request_info wrapped in a function_call content"
|
||||
|
||||
function_call = next(content for content in approval_update.contents if content.type == "function_call")
|
||||
approval_request = next(
|
||||
content for content in approval_update.contents if content.type == "function_approval_request"
|
||||
)
|
||||
request_function_call = next(content for content in request_update.contents if content.type == "function_call")
|
||||
assert request_function_call.call_id is not None
|
||||
|
||||
# Verify the function call has expected structure
|
||||
assert function_call.call_id is not None
|
||||
assert function_call.name == "request_info"
|
||||
assert isinstance(function_call.arguments, dict)
|
||||
assert function_call.arguments.get("request_id") == approval_request.id
|
||||
assert request_function_call.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME
|
||||
assert isinstance(request_function_call.arguments, dict)
|
||||
assert request_function_call.arguments.get("request_id") is not None
|
||||
assert request_function_call.arguments.get("request_event") is not None
|
||||
request_event = request_function_call.arguments["request_event"]
|
||||
assert request_event.get("type") == "request_info"
|
||||
assert deserialize_type(request_event.get("response_type")) is str
|
||||
|
||||
# Approval request should reference the same function call
|
||||
assert approval_request.id is not None
|
||||
assert approval_request.function_call is not None
|
||||
assert approval_request.function_call.call_id == function_call.call_id
|
||||
assert approval_request.function_call.name == function_call.name
|
||||
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments)
|
||||
assert deserialized_args.request_id == request_function_call.call_id
|
||||
assert isinstance(deserialized_args.request_event, WorkflowEvent)
|
||||
assert deserialized_args.request_event.type == "request_info"
|
||||
assert deserialized_args.request_event.data == "Mock request data"
|
||||
assert deserialized_args.request_event.response_type is str
|
||||
|
||||
# Verify the request is tracked in pending_requests
|
||||
assert len(agent.pending_requests) == 1
|
||||
assert function_call.call_id in agent.pending_requests
|
||||
pending_requests = await workflow._runner_context.get_pending_request_info_events()
|
||||
assert len(pending_requests) == 1
|
||||
assert request_function_call.call_id in pending_requests
|
||||
|
||||
# Now provide an approval response with updated arguments to test continuation
|
||||
response_args = WorkflowAgent.RequestInfoFunctionArgs(
|
||||
request_id=approval_request.id,
|
||||
data="User provided answer",
|
||||
).to_dict()
|
||||
|
||||
approval_response = Content.from_function_approval_response(
|
||||
approved=True,
|
||||
id=approval_request.id,
|
||||
function_call=Content.from_function_call(
|
||||
call_id=function_call.call_id,
|
||||
name=function_call.name,
|
||||
arguments=response_args,
|
||||
),
|
||||
# Now provide a function result response with updated arguments to test continuation
|
||||
function_result = Content.from_function_result(
|
||||
call_id=request_function_call.call_id,
|
||||
result="Mock response to request info",
|
||||
)
|
||||
|
||||
response_message = Message(role="user", contents=[approval_response])
|
||||
response_message = Message(role="user", contents=[function_result])
|
||||
|
||||
# Continue the workflow with the response
|
||||
continuation_result = await agent.run(response_message)
|
||||
@@ -294,16 +300,11 @@ class TestWorkflowAgent:
|
||||
assert isinstance(continuation_result, AgentResponse)
|
||||
|
||||
# Verify cleanup - pending requests should be cleared after function response handling
|
||||
assert len(agent.pending_requests) == 0
|
||||
pending_requests = await workflow._runner_context.get_pending_request_info_events()
|
||||
assert len(pending_requests) == 0
|
||||
|
||||
def test_request_info_dataclass_arguments_are_serialized_when_content_is_created(self) -> None:
|
||||
"""Test WorkflowAgent prepares request_info arguments before observability captures messages."""
|
||||
|
||||
@dataclass
|
||||
class HandoffRequest:
|
||||
target_agent: str
|
||||
reason: str
|
||||
|
||||
executor = SimpleExecutor(id="executor1", response_text="Response")
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
|
||||
@@ -314,14 +315,367 @@ class TestWorkflowAgent:
|
||||
response_type=str,
|
||||
)
|
||||
|
||||
function_call, approval_request = agent._process_request_info_event(event) # pyright: ignore[reportPrivateUsage]
|
||||
request_function_call = agent._process_request_info_event(event) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert function_call.arguments == {
|
||||
"request_id": "request_123",
|
||||
"data": {"target_agent": "helper", "reason": "overflow"},
|
||||
}
|
||||
assert approval_request.function_call is function_call
|
||||
assert json.loads(json.dumps(function_call.arguments)) == function_call.arguments
|
||||
assert request_function_call.call_id == "request_123"
|
||||
assert isinstance(request_function_call.arguments, dict)
|
||||
assert request_function_call.arguments.get("request_event") is not None
|
||||
request_event = request_function_call.arguments["request_event"]
|
||||
assert request_event.get("type") == "request_info"
|
||||
assert request_event.get("request_id") == "request_123"
|
||||
assert request_event.get("source_executor_id") == "executor1"
|
||||
assert deserialize_type(request_event.get("response_type")) is str
|
||||
assert request_event.get("data") == HandoffRequest(target_agent="helper", reason="overflow")
|
||||
|
||||
deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments)
|
||||
assert deserialized_args.request_id == "request_123"
|
||||
assert isinstance(deserialized_args.request_event, WorkflowEvent)
|
||||
assert deserialized_args.request_event.type == "request_info"
|
||||
assert deserialized_args.request_event.data == HandoffRequest(target_agent="helper", reason="overflow")
|
||||
assert deserialized_args.request_event.response_type is str
|
||||
|
||||
def test_process_request_info_event_passes_through_function_approval_request(self) -> None:
    """A request_info event whose data is already an approval request is returned untouched.

    The event is built around a ``function_approval_request`` content for the
    ``delete_file`` tool. ``WorkflowAgent._process_request_info_event`` is
    expected to hand back that very object instead of wrapping it in a
    synthesized ``REQUEST_INFO_FUNCTION_NAME`` function call, so callers can
    answer with a matching ``function_approval_response``.
    """
    workflow = WorkflowBuilder(start_executor=SimpleExecutor(id="executor1", response_text="Response")).build()
    agent = WorkflowAgent(workflow=workflow, name="Approval Passthrough Agent")

    approval_id = "approval-passthrough-1"
    tool_call = Content.from_function_call(
        call_id="tool-call-1",
        name="delete_file",
        arguments={"path": "/tmp/x"},
    )
    approval_request = Content.from_function_approval_request(id=approval_id, function_call=tool_call)

    surfaced = agent._process_request_info_event(  # pyright: ignore[reportPrivateUsage]
        WorkflowEvent.request_info(
            request_id=approval_id,
            source_executor_id="executor1",
            request_data=approval_request,
            response_type=Content,
        )
    )

    # Identity, not just equality: the original content instance must come back.
    assert surfaced is approval_request
    assert surfaced.type == "function_approval_request"
    assert surfaced.id == approval_id
    assert surfaced.user_input_request is True

    # The wrapped tool call keeps its own name rather than the synthesized one.
    assert surfaced.function_call is tool_call  # type: ignore[attr-defined]
    assert surfaced.function_call.name == "delete_file"  # type: ignore[attr-defined]
    assert surfaced.function_call.name != WorkflowAgent.REQUEST_INFO_FUNCTION_NAME  # type: ignore[attr-defined]
|
||||
|
||||
def test_extract_function_responses_passes_through_approval_response_approved(self) -> None:
    """An ``approved=True`` approval response is keyed by its own id and returned as-is.

    ``WorkflowAgent._extract_function_responses`` is given one user message
    holding a ``function_approval_response``. The resulting mapping must have
    exactly one key, the response's ``id``, mapped to the very same content
    object, with ``approved`` still ``True``.
    """
    workflow = WorkflowBuilder(start_executor=SimpleExecutor(id="executor1", response_text="Response")).build()
    agent = WorkflowAgent(workflow=workflow, name="Approval Response Agent")

    approval_id = "approval-response-approved-1"
    tool_call = Content.from_function_call(
        call_id="tool-call-1",
        name="delete_file",
        arguments={"path": "/tmp/x"},
    )
    approval_request = Content.from_function_approval_request(id=approval_id, function_call=tool_call)
    approval_response = approval_request.to_function_approval_response(approved=True)  # type: ignore[attr-defined]

    responses = agent._extract_function_responses(  # pyright: ignore[reportPrivateUsage]
        [Message(role="user", contents=[approval_response])]
    )

    assert set(responses.keys()) == {approval_id}
    routed = responses[approval_id]
    assert routed is approval_response
    assert routed.approved is True  # type: ignore[attr-defined]
|
||||
|
||||
def test_extract_function_responses_passes_through_approval_response_denied(self) -> None:
    """An ``approved=False`` approval response is routed exactly like an approval.

    Same setup as the approved variant, but for a ``send_email`` tool call and
    a rejection: the mapping returned by
    ``WorkflowAgent._extract_function_responses`` must contain only the
    response's ``id``, mapped to the same content object with ``approved``
    equal to ``False``.
    """
    workflow = WorkflowBuilder(start_executor=SimpleExecutor(id="executor1", response_text="Response")).build()
    agent = WorkflowAgent(workflow=workflow, name="Approval Response Agent")

    approval_id = "approval-response-denied-1"
    tool_call = Content.from_function_call(
        call_id="tool-call-2",
        name="send_email",
        arguments={"to": "alice@example.com"},
    )
    approval_request = Content.from_function_approval_request(id=approval_id, function_call=tool_call)
    approval_response = approval_request.to_function_approval_response(approved=False)  # type: ignore[attr-defined]

    responses = agent._extract_function_responses(  # pyright: ignore[reportPrivateUsage]
        [Message(role="user", contents=[approval_response])]
    )

    assert set(responses.keys()) == {approval_id}
    routed = responses[approval_id]
    assert routed is approval_response
    assert routed.approved is False  # type: ignore[attr-defined]
|
||||
|
||||
async def test_function_approval_request_flows_end_to_end_approved(self) -> None:
    """End-to-end approval round trip with ``approved=True``.

    Flow exercised:

    1. An executor calls ``ctx.request_info`` with a ``function_approval_request``
       content, using the approval id as the request id.
    2. The first ``WorkflowAgent.run`` surfaces that same content object and the
       request is listed among the runner context's pending request events.
    3. The caller answers with the matching ``function_approval_response``; the
       executor's ``@response_handler`` receives it and yields a text output.
    4. The pending request is gone afterwards.
    """
    approval_id = "e2e-approval-1"
    approval_request = Content.from_function_approval_request(
        id=approval_id,
        function_call=Content.from_function_call(
            call_id="tool-call-e2e-1",
            name="delete_file",
            arguments={"path": "/tmp/x"},
        ),
    )

    class ApprovalRequestingExecutor(Executor):
        @handler
        async def handle_message(self, _: list[Message], ctx: WorkflowContext) -> None:
            # Pause the workflow on the pre-built approval request.
            await ctx.request_info(approval_request, Content, request_id=approval_id)

        @response_handler
        async def handle_response(
            self,
            original_request: Content,
            response: Content,
            ctx: WorkflowContext[Never, AgentResponse],
        ) -> None:
            assert response.type == "function_approval_response"
            assert response.id == approval_id  # type: ignore[attr-defined]
            approved = bool(response.approved)  # type: ignore[attr-defined]
            tool_name = original_request.function_call.name  # type: ignore[attr-defined]
            # Echo the decision so the outer test can assert on the final text.
            reply = Message(
                role="assistant",
                contents=[Content.from_text(text=f"{tool_name} approved={approved}")],
            )
            await ctx.yield_output(AgentResponse(messages=[reply]))

    workflow = WorkflowBuilder(start_executor=ApprovalRequestingExecutor(id="approval_requester")).build()
    agent = WorkflowAgent(workflow=workflow, name="E2E Approval Agent")

    # First run: workflow pauses with the approval request.
    paused = await agent.run("please delete it")
    assert isinstance(paused, AgentResponse)

    surfaced = None
    for message in paused.messages:
        for content in message.contents:
            if content.type == "function_approval_request" and content.id == approval_id:
                surfaced = content
                break
        if surfaced is not None:
            break
    assert surfaced is approval_request, "Approval request must surface unchanged"

    pending = await workflow._runner_context.get_pending_request_info_events()
    assert approval_id in pending

    # Respond with approved=True.
    approval_response = approval_request.to_function_approval_response(approved=True)  # type: ignore[attr-defined]
    resumed = await agent.run(Message(role="user", contents=[approval_response]))

    assert isinstance(resumed, AgentResponse)
    resumed_text = " ".join([message.text or "" for message in resumed.messages])
    assert "delete_file approved=True" in resumed_text

    pending = await workflow._runner_context.get_pending_request_info_events()
    assert approval_id not in pending
|
||||
|
||||
async def test_function_approval_request_flows_end_to_end_denied(self) -> None:
    """End-to-end approval round trip with ``approved=False``.

    Mirrors the approved variant for a ``send_email`` tool call: the approval
    request surfaces unchanged on the first run, the rejection is delivered to
    the executor's ``@response_handler`` (which reports it in its text output),
    and the pending request is cleared afterwards.
    """
    approval_id = "e2e-approval-deny-1"
    approval_request = Content.from_function_approval_request(
        id=approval_id,
        function_call=Content.from_function_call(
            call_id="tool-call-e2e-deny-1",
            name="send_email",
            arguments={"to": "alice@example.com"},
        ),
    )

    class ApprovalRequestingExecutor(Executor):
        @handler
        async def handle_message(self, _: list[Message], ctx: WorkflowContext) -> None:
            # Pause the workflow on the pre-built approval request.
            await ctx.request_info(approval_request, Content, request_id=approval_id)

        @response_handler
        async def handle_response(
            self,
            original_request: Content,
            response: Content,
            ctx: WorkflowContext[Never, AgentResponse],
        ) -> None:
            assert response.type == "function_approval_response"
            assert response.id == approval_id  # type: ignore[attr-defined]
            approved = bool(response.approved)  # type: ignore[attr-defined]
            tool_name = original_request.function_call.name  # type: ignore[attr-defined]
            # Echo the decision so the outer test can assert on the final text.
            reply = Message(
                role="assistant",
                contents=[Content.from_text(text=f"{tool_name} approved={approved}")],
            )
            await ctx.yield_output(AgentResponse(messages=[reply]))

    workflow = WorkflowBuilder(start_executor=ApprovalRequestingExecutor(id="approval_requester_deny")).build()
    agent = WorkflowAgent(workflow=workflow, name="E2E Approval Deny Agent")

    paused = await agent.run("please send")
    assert isinstance(paused, AgentResponse)

    surfaced = None
    for message in paused.messages:
        for content in message.contents:
            if content.type == "function_approval_request" and content.id == approval_id:
                surfaced = content
                break
        if surfaced is not None:
            break
    assert surfaced is approval_request

    # Respond with approved=False.
    approval_response = approval_request.to_function_approval_response(approved=False)  # type: ignore[attr-defined]
    resumed = await agent.run(Message(role="user", contents=[approval_response]))

    assert isinstance(resumed, AgentResponse)
    resumed_text = " ".join([message.text or "" for message in resumed.messages])
    assert "send_email approved=False" in resumed_text

    pending = await workflow._runner_context.get_pending_request_info_events()
    assert approval_id not in pending
|
||||
|
||||
async def test_request_info_non_approval_flows_end_to_end(self) -> None:
    """End-to-end round trip for request data that is not an approval content.

    The executor requests info with a ``HandoffRequest`` payload and a ``str``
    response type. The agent is expected to surface a ``function_call`` named
    ``WorkflowAgent.REQUEST_INFO_FUNCTION_NAME`` whose ``call_id`` equals the
    request id, and to deliver a ``function_result`` with that ``call_id`` back
    to the executor's ``@response_handler``.
    """
    captured: dict[str, Any] = {}

    class HandoffRequestingExecutor(Executor):
        @handler
        async def handle_message(self, _: list[Message], ctx: WorkflowContext) -> None:
            await ctx.request_info(
                HandoffRequest(target_agent="helper", reason="overflow"),
                str,
            )

        @response_handler
        async def handle_response(
            self,
            original_request: HandoffRequest,
            response: str,
            ctx: WorkflowContext[Never, AgentResponse],
        ) -> None:
            # Record what the handler received so the outer test can inspect it.
            captured["original"] = original_request
            captured["response"] = response
            reply = Message(
                role="assistant",
                contents=[Content.from_text(text=f"handoff to {original_request.target_agent}: {response}")],
            )
            await ctx.yield_output(AgentResponse(messages=[reply]))

    workflow = WorkflowBuilder(start_executor=HandoffRequestingExecutor(id="handoff_requester")).build()
    agent = WorkflowAgent(workflow=workflow, name="E2E Handoff Agent")
    expected_request = HandoffRequest(target_agent="helper", reason="overflow")

    # First run: workflow pauses with a synthesized request_info function_call.
    paused = await agent.run("start handoff")
    assert isinstance(paused, AgentResponse)

    function_call = None
    for message in paused.messages:
        for content in message.contents:
            if content.type == "function_call" and content.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME:
                function_call = content
                break
        if function_call is not None:
            break
    assert function_call is not None, "Expected a synthesized request_info function_call"
    assert function_call.call_id is not None
    assert isinstance(function_call.arguments, dict)

    request_id = function_call.arguments["request_id"]
    assert function_call.call_id == request_id
    request_payload = function_call.arguments["request_event"]
    assert request_payload.get("type") == "request_info"
    assert request_payload.get("data") == expected_request

    # The arguments must also round-trip through the typed wrapper.
    parsed_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(function_call.arguments)
    assert parsed_args.request_id == request_id
    parsed_event = parsed_args.request_event
    assert isinstance(parsed_event, WorkflowEvent)
    assert parsed_event.type == "request_info"
    assert parsed_event.data == expected_request
    assert parsed_event.response_type is str

    pending = await workflow._runner_context.get_pending_request_info_events()
    assert request_id in pending

    # Respond with a function_result keyed by the call_id.
    function_result = Content.from_function_result(call_id=request_id, result="ok-do-it")
    resumed = await agent.run(Message(role="user", contents=[function_result]))

    assert isinstance(resumed, AgentResponse)
    resumed_text = " ".join([message.text or "" for message in resumed.messages])
    assert "handoff to helper: ok-do-it" in resumed_text

    # The executor's response handler received the original request and the response.
    assert isinstance(captured.get("original"), HandoffRequest)
    assert captured["original"].target_agent == "helper"
    assert captured["response"] == "ok-do-it"

    pending = await workflow._runner_context.get_pending_request_info_events()
    assert request_id not in pending
|
||||
|
||||
def test_workflow_as_agent_method(self) -> None:
|
||||
"""Test that Workflow.as_agent() creates a properly configured WorkflowAgent."""
|
||||
@@ -1592,3 +1946,406 @@ class TestWorkflowAgentMergeUpdates:
|
||||
|
||||
# Order: text (user), text (assistant), function_result (orphan at end)
|
||||
assert content_types == ["text", "text", "function_result"]
|
||||
|
||||
|
||||
class _ToolApprovalMockAgent(SupportsAgentRun):
    """Test double for an agent whose tool call needs user approval.

    A run whose input contains no ``function_approval_response`` content
    answers with a single ``function_approval_request`` for the configured
    tool. A run whose input does contain approval responses answers with an
    assistant text message listing each response's ``approved`` flag and id.
    """

    def __init__(
        self,
        name: str,
        *,
        tool_name: str = "delete_file",
        tool_arguments: dict[str, Any] | None = None,
        approval_request_ids: Sequence[str] | None = None,
    ) -> None:
        self.id = str(uuid.uuid4())
        self.name = name
        self.description: str | None = None
        self._tool_name = tool_name
        self._tool_arguments = tool_arguments or {"path": "/tmp/example"}
        # Ids handed out (in order) to approval requests, so a test can predict
        # what the WorkflowAgent forwards; a random id is used once exhausted.
        self._approval_request_ids: list[str] = list(approval_request_ids) if approval_request_ids else []
        self.run_count = 0
        # Normalized input of the most recent run, kept for assertions.
        self.last_run_messages: list[Message] = []

    def create_session(self, **kwargs: Any) -> AgentSession:
        return AgentSession()

    def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
        return AgentSession()

    def _next_request_id(self) -> str:
        """Return the next pre-allocated request id, or a fresh UUID when none remain."""
        if not self._approval_request_ids:
            return str(uuid.uuid4())
        return self._approval_request_ids.pop(0)

    def _build_approval_request(self) -> Content:
        """Build an approval request whose id doubles as the inner call id."""
        request_id = self._next_request_id()
        return Content.from_function_approval_request(
            id=request_id,
            function_call=Content.from_function_call(
                call_id=request_id,
                name=self._tool_name,
                arguments=self._tool_arguments,
            ),
        )

    @overload
    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = ...,
        *,
        stream: Literal[False] = ...,
        session: AgentSession | None = ...,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]]: ...

    @overload
    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = ...,
        *,
        stream: Literal[True],
        session: AgentSession | None = ...,
        **kwargs: Any,
    ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        stream: bool = False,
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
        """Dispatch to the streaming or non-streaming implementation."""
        runner = self._run_stream if stream else self._run
        return runner(messages=messages, session=session, **kwargs)

    def _normalize(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None,
    ) -> list[Message]:
        """Convert any accepted input shape into a list of ``Message`` objects."""

        def as_message(item: str | Content | Message) -> Message:
            # Messages pass through; bare content and text are wrapped as user messages.
            if isinstance(item, Message):
                return item
            if isinstance(item, Content):
                return Message(role="user", contents=[item])
            return Message(role="user", contents=[Content.from_text(text=item)])

        if messages is None:
            return []
        if isinstance(messages, (str, Message, Content)):
            return [as_message(messages)]
        return [as_message(item) for item in messages]

    def _approval_responses_in(self, messages: list[Message]) -> list[Content]:
        """Collect every ``function_approval_response`` content, in message order."""
        return [
            content
            for message in messages
            for content in message.contents
            if content.type == "function_approval_response"
        ]

    async def _run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> AgentResponse:
        """Non-streaming run: request approval, or acknowledge received approvals."""
        normalized = self._normalize(messages)
        self.last_run_messages = normalized
        self.run_count += 1

        approvals = self._approval_responses_in(normalized)
        if not approvals:
            # First run: ask for tool approval.
            return AgentResponse(messages=[Message("assistant", [self._build_approval_request()])])

        # Continuation: report each approval decision in the final response text.
        approved_text = "; ".join(
            f"approved={a.approved} id={a.id}"  # type: ignore[attr-defined]
            for a in approvals
        )
        return AgentResponse(messages=[Message("assistant", [Content.from_text(text=f"done ({approved_text})")])])

    def _run_stream(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
        """Streaming run: same behaviour as ``_run``, delivered as a single update."""
        # Bookkeeping happens when the stream is created, not when it is consumed.
        normalized = self._normalize(messages)
        self.last_run_messages = normalized
        self.run_count += 1
        approvals = self._approval_responses_in(normalized)

        async def _iter():
            if approvals:
                approved_text = "; ".join(
                    f"approved={a.approved} id={a.id}"  # type: ignore[attr-defined]
                    for a in approvals
                )
                contents = [Content.from_text(text=f"done ({approved_text})")]
            else:
                # The approval request is only built once the stream is iterated.
                contents = [self._build_approval_request()]
            yield AgentResponseUpdate(
                contents=contents,
                role="assistant",
                author_name=self.name,
            )

        return ResponseStream(_iter(), finalizer=AgentResponse.from_updates)
|
||||
|
||||
|
||||
class TestWorkflowAgentToolApproval:
|
||||
"""Tests for tool-approval requests bubbling through WorkflowAgent.
|
||||
|
||||
Covers the case where a workflow contains an AgentExecutor whose underlying
|
||||
agent emits a FunctionApprovalRequestContent (tool needing user approval).
|
||||
The WorkflowAgent must:
|
||||
* forward the original FunctionApprovalRequestContent unchanged (no
|
||||
wrapping inside a synthesized 'request_info' function call), and
|
||||
* route a subsequent FunctionApprovalResponseContent back to the
|
||||
AgentExecutor so the agent can resume.
|
||||
"""
|
||||
|
||||
def _find_approval_request(
|
||||
self,
|
||||
contents: Sequence[Content],
|
||||
tool_name: str,
|
||||
) -> Content | None:
|
||||
for content in contents:
|
||||
if (
|
||||
content.type == "function_approval_request"
|
||||
and getattr(content.function_call, "name", None) == tool_name # type: ignore[attr-defined]
|
||||
):
|
||||
return content
|
||||
return None
|
||||
|
||||
async def test_tool_approval_request_forwarded_unchanged(self) -> None:
    """Non-streaming run: the inner agent's approval request surfaces verbatim.

    The mock agent is given a fixed approval id and tool arguments. After one
    ``WorkflowAgent.run`` the response must contain an approval request with
    that id, the original tool name and arguments (not the synthesized
    ``REQUEST_INFO_FUNCTION_NAME``), and the id must be listed among the runner
    context's pending request events.
    """
    approval_id = "approval-abc-123"
    mock_agent = _ToolApprovalMockAgent(
        name="approval-agent",
        tool_name="delete_file",
        tool_arguments={"path": "/tmp/secret.txt"},
        approval_request_ids=[approval_id],
    )

    @executor
    async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
        await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))

    builder = WorkflowBuilder(start_executor=start).add_edge(start, mock_agent)
    workflow = builder.build()
    agent = WorkflowAgent(workflow=workflow, name="Approval Test Agent")

    result = await agent.run("please delete the file")
    assert isinstance(result, AgentResponse)

    # Locate the approval request emitted by the WorkflowAgent.
    flattened: list[Content] = []
    for message in result.messages:
        flattened.extend(message.contents)
    approval = self._find_approval_request(flattened, tool_name="delete_file")
    assert approval is not None, "WorkflowAgent did not forward the tool approval request"

    # Id and inner function_call must be exactly what the underlying agent
    # produced, i.e. not re-wrapped inside a synthesized 'request_info' request.
    assert approval.id == approval_id
    function_call = approval.function_call  # type: ignore[attr-defined]
    assert function_call is not None
    assert function_call.name == "delete_file"
    assert function_call.name != WorkflowAgent.REQUEST_INFO_FUNCTION_NAME
    assert function_call.arguments == {"path": "/tmp/secret.txt"}

    # The agent must be paused awaiting the approval response.
    pending = await workflow._runner_context.get_pending_request_info_events()
    assert approval_id in pending
|
||||
|
||||
async def test_tool_approval_request_forwarded_unchanged_streaming(self) -> None:
|
||||
"""Streaming variant: the approval request is forwarded as-is in updates."""
|
||||
approval_id = "approval-stream-1"
|
||||
mock_agent = _ToolApprovalMockAgent(
|
||||
name="approval-agent-stream",
|
||||
tool_name="send_email",
|
||||
tool_arguments={"to": "alice@example.com"},
|
||||
approval_request_ids=[approval_id],
|
||||
)
|
||||
|
||||
@executor
|
||||
async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=start).add_edge(start, mock_agent).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="Approval Stream Agent")
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in agent.run("hi", stream=True):
|
||||
updates.append(update)
|
||||
|
||||
approval_updates = [u for u in updates if any(c.type == "function_approval_request" for c in u.contents)]
|
||||
assert approval_updates, "Streaming did not surface a tool approval request"
|
||||
|
||||
approval = self._find_approval_request(approval_updates[-1].contents, tool_name="send_email")
|
||||
assert approval is not None
|
||||
assert approval.id == approval_id
|
||||
function_call = approval.function_call # type: ignore[attr-defined]
|
||||
assert function_call is not None
|
||||
assert function_call.name == "send_email"
|
||||
assert function_call.name != WorkflowAgent.REQUEST_INFO_FUNCTION_NAME
|
||||
assert function_call.arguments == {"to": "alice@example.com"}
|
||||
|
||||
async def test_tool_approval_response_resumes_agent(self) -> None:
|
||||
"""Sending the approval response back resumes the agent and clears pending requests."""
|
||||
approval_id = "approval-resume-1"
|
||||
mock_agent = _ToolApprovalMockAgent(
|
||||
name="approval-resume-agent",
|
||||
tool_name="delete_file",
|
||||
tool_arguments={"path": "/tmp/x"},
|
||||
approval_request_ids=[approval_id],
|
||||
)
|
||||
|
||||
@executor
|
||||
async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=start).add_edge(start, mock_agent).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="Approval Resume Agent")
|
||||
|
||||
first_result = await agent.run("delete it")
|
||||
approval = self._find_approval_request(
|
||||
[c for m in first_result.messages for c in m.contents],
|
||||
tool_name="delete_file",
|
||||
)
|
||||
assert approval is not None
|
||||
assert mock_agent.run_count == 1
|
||||
|
||||
# Build the approval response. NOTE: the inner function_call's name is the
|
||||
# original tool name ('delete_file'), NOT 'request_info'. This exercises the
|
||||
# branch in WorkflowAgent._extract_function_responses that routes raw
|
||||
# tool-approval responses straight through using content.id.
|
||||
approval_response = approval.to_function_approval_response(approved=True) # type: ignore[attr-defined]
|
||||
response_message = Message(role="user", contents=[approval_response])
|
||||
|
||||
final_result = await agent.run(response_message)
|
||||
assert isinstance(final_result, AgentResponse)
|
||||
|
||||
# The mock agent should have been invoked a second time and seen the
|
||||
# approval response in its inputs.
|
||||
assert mock_agent.run_count == 2
|
||||
approvals_seen = [
|
||||
c for m in mock_agent.last_run_messages for c in m.contents if c.type == "function_approval_response"
|
||||
]
|
||||
assert len(approvals_seen) == 1
|
||||
assert approvals_seen[0].id == approval_id # type: ignore[attr-defined]
|
||||
assert approvals_seen[0].approved is True # type: ignore[attr-defined]
|
||||
|
||||
# The pending approval should now be cleared.
|
||||
pending = await workflow._runner_context.get_pending_request_info_events()
|
||||
assert approval_id not in pending
|
||||
|
||||
# The final assistant message reflects the resumption.
|
||||
final_text = " ".join(m.text or "" for m in final_result.messages)
|
||||
assert "done" in final_text
|
||||
assert approval_id in final_text
|
||||
|
||||
async def test_tool_approval_response_rejected_resumes_agent(self) -> None:
|
||||
"""Rejection path: ``approved=False`` is forwarded to the inner agent and clears the pending request.
|
||||
|
||||
The WorkflowAgent must route a rejection response back to the paused
|
||||
``AgentExecutor`` exactly the same way as an approval — only the
|
||||
``approved`` flag differs. The inner agent decides what to do with it.
|
||||
"""
|
||||
approval_id = "approval-reject-1"
|
||||
mock_agent = _ToolApprovalMockAgent(
|
||||
name="approval-reject-agent",
|
||||
tool_name="delete_file",
|
||||
tool_arguments={"path": "/tmp/x"},
|
||||
approval_request_ids=[approval_id],
|
||||
)
|
||||
|
||||
@executor
|
||||
async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=start).add_edge(start, mock_agent).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="Approval Reject Agent")
|
||||
|
||||
first_result = await agent.run("delete it")
|
||||
approval = self._find_approval_request(
|
||||
[c for m in first_result.messages for c in m.contents],
|
||||
tool_name="delete_file",
|
||||
)
|
||||
assert approval is not None
|
||||
assert mock_agent.run_count == 1
|
||||
|
||||
# Reject the tool invocation.
|
||||
approval_response = approval.to_function_approval_response(approved=False) # type: ignore[attr-defined]
|
||||
response_message = Message(role="user", contents=[approval_response])
|
||||
|
||||
final_result = await agent.run(response_message)
|
||||
assert isinstance(final_result, AgentResponse)
|
||||
|
||||
# The inner agent must have been resumed and seen ``approved=False``.
|
||||
assert mock_agent.run_count == 2
|
||||
approvals_seen = [
|
||||
c for m in mock_agent.last_run_messages for c in m.contents if c.type == "function_approval_response"
|
||||
]
|
||||
assert len(approvals_seen) == 1
|
||||
assert approvals_seen[0].id == approval_id # type: ignore[attr-defined]
|
||||
assert approvals_seen[0].approved is False # type: ignore[attr-defined]
|
||||
|
||||
# Pending approval cleared regardless of approve/reject.
|
||||
pending = await workflow._runner_context.get_pending_request_info_events()
|
||||
assert approval_id not in pending
|
||||
|
||||
# The final assistant message reflects the rejection.
|
||||
final_text = " ".join(m.text or "" for m in final_result.messages)
|
||||
assert "approved=False" in final_text
|
||||
assert approval_id in final_text
|
||||
|
||||
async def test_tool_approval_request_id_matches_pending_request(self) -> None:
|
||||
"""The approval request id surfaced by WorkflowAgent matches the workflow's pending request id.
|
||||
|
||||
This guards the AgentExecutor change that forwards
|
||||
request_id=user_input_request.id to ctx.request_info(...), which is what
|
||||
allows the response routed back via WorkflowAgent to resolve the pending
|
||||
request without an id-mismatch error.
|
||||
"""
|
||||
approval_id = "approval-id-match-1"
|
||||
mock_agent = _ToolApprovalMockAgent(
|
||||
name="approval-id-match-agent",
|
||||
approval_request_ids=[approval_id],
|
||||
)
|
||||
|
||||
@executor
|
||||
async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=start).add_edge(start, mock_agent).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="Approval Id Agent")
|
||||
|
||||
await agent.run("go")
|
||||
|
||||
pending = await workflow._runner_context.get_pending_request_info_events()
|
||||
# The agent's approval id is used as the workflow's pending request id.
|
||||
assert list(pending.keys()) == [approval_id]
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for the ``Workflow.status`` property."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
Executor,
|
||||
Workflow,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
WorkflowEvent,
|
||||
WorkflowRunState,
|
||||
handler,
|
||||
response_handler,
|
||||
)
|
||||
from agent_framework._workflows._executor import Executor as _Executor
|
||||
from agent_framework._workflows._request_info_mixin import RequestInfoMixin
|
||||
|
||||
|
||||
class PassThroughExecutor(Executor):
    """Executor whose only handler yields the incoming string as workflow output."""

    @handler
    async def passthrough(self, msg: str, ctx: WorkflowContext[str, str]) -> None:
        """Yield ``msg`` unchanged as a workflow output."""
        await ctx.yield_output(msg)
|
||||
|
||||
|
||||
class FailingExecutor(Executor):
    """Executor whose handler always raises; used to drive the FAILED status."""

    @handler
    async def fail(self, msg: int, ctx: WorkflowContext) -> None:  # pragma: no cover - invoked via workflow
        """Raise ``RuntimeError("boom")`` for any input."""
        raise RuntimeError("boom")
|
||||
|
||||
|
||||
@dataclass
class _ApprovalRequest:
    """Request payload used by ``ApprovalExecutor`` when calling ``request_info``."""

    # Text shown to whoever answers the request.
    prompt: str
    # Identifier of the request; an empty value is replaced with a fresh UUID4 string.
    request_id: str = ""

    def __post_init__(self) -> None:
        """Fill in ``request_id`` with a random UUID when the caller left it empty."""
        if self.request_id:
            return

        import uuid

        self.request_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
class ApprovalExecutor(_Executor, RequestInfoMixin):
    """Executor that raises one ``request_info`` call and yields an output once it is answered."""

    def __init__(self, id: str = "approval"):
        """Create the executor under the given id (defaults to ``"approval"``)."""
        super().__init__(id=id)

    @handler
    async def start(self, message: str, ctx: WorkflowContext[str, str]) -> None:
        """Ask for a ``bool`` answer, using ``message`` as the request prompt."""
        pending_request = _ApprovalRequest(prompt=message)
        await ctx.request_info(pending_request, bool)

    @response_handler
    async def on_response(
        self, original_request: _ApprovalRequest, approved: bool, ctx: WorkflowContext[str, str]
    ) -> None:
        """Yield ``approved=<value>`` as the workflow output for the received answer."""
        outcome = f"approved={approved}"
        await ctx.yield_output(outcome)
|
||||
|
||||
|
||||
def _build_passthrough_workflow() -> Workflow:
    """Build a workflow with one ``PassThroughExecutor`` (id ``"p"``) as start and output source."""
    passthrough = PassThroughExecutor(id="p")
    builder = WorkflowBuilder(start_executor=passthrough, output_from=[passthrough])
    return builder.build()
|
||||
|
||||
|
||||
def _build_failing_workflow() -> Workflow:
    """Build a workflow whose start executor is a ``FailingExecutor`` (id ``"f"``)."""
    # No ``output_from`` here: FailingExecutor declares no workflow_output_types,
    # so the designation stays implicit. Call sites that care filter the
    # resulting deprecation warning themselves.
    failing = FailingExecutor(id="f")
    return WorkflowBuilder(start_executor=failing).build()
|
||||
|
||||
|
||||
def _build_approval_workflow() -> Workflow:
    """Build a workflow with one ``ApprovalExecutor`` (id ``"approval"``) as start and output source."""
    approver = ApprovalExecutor(id="approval")
    builder = WorkflowBuilder(start_executor=approver, output_from=[approver])
    return builder.build()
|
||||
|
||||
|
||||
async def test_status_default_is_idle_before_first_run():
    """A freshly built workflow reports ``IDLE`` before anything has run."""
    fresh_workflow = _build_passthrough_workflow()
    assert fresh_workflow.status is WorkflowRunState.IDLE
|
||||
|
||||
|
||||
async def test_status_is_idle_after_successful_run():
    """After a run that completes normally the status is ``IDLE``."""
    workflow = _build_passthrough_workflow()

    await workflow.run("hello")

    assert workflow.status is WorkflowRunState.IDLE
|
||||
|
||||
|
||||
async def test_status_is_failed_after_failure():
    """A run whose executor raises propagates the error and leaves the status ``FAILED``."""
    workflow = _build_failing_workflow()

    with pytest.raises(RuntimeError, match="boom"):
        await workflow.run(0)

    # Checked outside the ``raises`` block so it actually executes.
    assert workflow.status is WorkflowRunState.FAILED
|
||||
|
||||
|
||||
async def test_status_transitions_during_streaming_run():
    """``Workflow.status`` tracks the latest status event seen while streaming."""
    workflow = _build_passthrough_workflow()
    states_seen: list[WorkflowRunState] = []

    async for event in workflow.run("hi", stream=True):
        if not (isinstance(event, WorkflowEvent) and event.type == "status"):
            continue
        # When the consumer receives a status event the property must already
        # hold that state, i.e. it is updated together with the emission.
        assert workflow.status == event.state
        states_seen.append(event.state)  # type: ignore

    # IN_PROGRESS has to show up, and the run has to finish on IDLE.
    assert WorkflowRunState.IN_PROGRESS in states_seen
    assert states_seen[-1] is WorkflowRunState.IDLE
    assert workflow.status is WorkflowRunState.IDLE
|
||||
|
||||
|
||||
async def test_status_idle_with_pending_requests_then_resolves_to_idle():
    """Status is ``IDLE_WITH_PENDING_REQUESTS`` while a request is open and ``IDLE`` once answered."""
    workflow = _build_approval_workflow()

    # Remember the last request_info event emitted by the first run.
    request_event: WorkflowEvent | None = None
    async for event in workflow.run("please approve", stream=True):
        is_request = isinstance(event, WorkflowEvent) and event.type == "request_info"
        if is_request:
            request_event = event

    assert request_event is not None
    assert workflow.status is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS

    # Answer the request and drain the resumed run.
    answers = {request_event.request_id: True}
    async for _ in workflow.run(stream=True, responses=answers):
        pass

    assert workflow.status is WorkflowRunState.IDLE
|
||||
|
||||
|
||||
async def test_status_in_progress_pending_requests_observed_mid_run():
    """A streaming run with a request_info call passes through ``IN_PROGRESS_PENDING_REQUESTS``."""
    workflow = _build_approval_workflow()

    # Collect the state carried by every status event of the run.
    status_history: list[WorkflowRunState] = [
        event.state  # type: ignore
        async for event in workflow.run("please approve", stream=True)
        if isinstance(event, WorkflowEvent) and event.type == "status"
    ]

    assert WorkflowRunState.IN_PROGRESS in status_history
    assert WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS in status_history
    assert status_history[-1] is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
    assert workflow.status is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
|
||||
@@ -26,7 +26,7 @@ dependencies = [
|
||||
"agent-framework-core>=1.8.0,<2",
|
||||
"agent-framework-openai>=1.8.0,<2",
|
||||
"azure-ai-inference>=1.0.0b9,<1.0.0b10",
|
||||
"azure-ai-projects>=2.1.0,<3.0",
|
||||
"azure-ai-projects>=2.2.0,<3.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
||||
@@ -567,7 +567,7 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
by the hosting infrastructure or files will be preserved upon deactivation.
|
||||
"""
|
||||
input_items = await context.get_input_items()
|
||||
input_messages = await _items_to_messages(input_items)
|
||||
input_messages = await _items_to_messages(input_items, approval_storage=self._approval_storage)
|
||||
is_streaming_request = request.stream is not None and request.stream is True
|
||||
|
||||
_, are_options_set = _to_chat_options(request)
|
||||
@@ -664,7 +664,11 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
checkpoint_storage=write_storage,
|
||||
)
|
||||
|
||||
async for item in _to_outputs_for_messages(response_event_stream, response.messages):
|
||||
async for item in _to_outputs_for_messages(
|
||||
response_event_stream,
|
||||
response.messages,
|
||||
approval_storage=self._approval_storage,
|
||||
):
|
||||
yield item
|
||||
|
||||
await self._delete_not_latest_checkpoints(write_storage, self._agent.workflow.name)
|
||||
@@ -685,7 +689,9 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
for event in tracker.handle(content):
|
||||
yield event
|
||||
if tracker.needs_async:
|
||||
async for item in _to_outputs(response_event_stream, content):
|
||||
async for item in _to_outputs(
|
||||
response_event_stream, content, approval_storage=self._approval_storage
|
||||
):
|
||||
yield item
|
||||
tracker.needs_async = False
|
||||
|
||||
|
||||
@@ -11,24 +11,33 @@ the registered _handle_create handler.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, overload
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from agent_framework import (
|
||||
AgentExecutorRequest,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentSession,
|
||||
Content,
|
||||
FileCheckpointStorage,
|
||||
HistoryProvider,
|
||||
Message,
|
||||
RawAgent,
|
||||
ResponseStream,
|
||||
SupportsAgentRun,
|
||||
WorkflowAgent,
|
||||
WorkflowBuilder,
|
||||
WorkflowCheckpoint,
|
||||
WorkflowCheckpointException,
|
||||
WorkflowContext,
|
||||
WorkflowMessage,
|
||||
executor,
|
||||
)
|
||||
from azure.ai.agentserver.responses import InMemoryResponseProvider
|
||||
from mcp import McpError
|
||||
@@ -102,7 +111,7 @@ def _make_agent(
|
||||
return agent
|
||||
|
||||
|
||||
def _make_server(agent: MagicMock, **kwargs: Any) -> ResponsesHostServer:
|
||||
def _make_server(agent: Any, **kwargs: Any) -> ResponsesHostServer:
|
||||
"""Create a ResponsesHostServer with an in-memory store."""
|
||||
return ResponsesHostServer(agent, store=InMemoryResponseProvider(), **kwargs)
|
||||
|
||||
@@ -3469,3 +3478,498 @@ class TestOAuthConsentSurfacing:
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Workflow agent hosting (end-to-end)
|
||||
|
||||
|
||||
class _ToolApprovalWorkflowAgentMock(SupportsAgentRun):
    """Scripted inner agent for a hosted ``WorkflowAgent`` with a tool-approval flow.

    A run whose input carries no ``function_approval_response`` content replies
    with a function approval request for the configured tool. A run whose input
    does carry one replies with ``final_text``. ``run_count`` and
    ``last_run_messages`` record what the mock was called with, so tests can
    drive the HTTP pipeline of ``ResponsesHostServer`` and inspect the result.
    """

    def __init__(
        self,
        name: str,
        *,
        tool_name: str = "delete_file",
        tool_arguments: dict[str, Any] | None = None,
        approval_request_ids: Sequence[str] | None = None,
        final_text: str = "done",
    ) -> None:
        """Store the scripted tool call, the approval ids and the final reply text."""
        # Identity attributes.
        self.id = str(uuid.uuid4())
        self.name = name
        self.description: str | None = None
        # Scripted behaviour.
        self._tool_name = tool_name
        self._tool_arguments = tool_arguments or {"path": "/tmp/example"}
        self._approval_request_ids: list[str] = list(approval_request_ids or [])
        self._final_text = final_text
        # Invocation bookkeeping read by the tests.
        self.run_count = 0
        self.last_run_messages: list[Message] = []

    def create_session(self, **kwargs: Any) -> AgentSession:
        """Return a new ``AgentSession``; keyword arguments are ignored."""
        return AgentSession()

    def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
        """Return a new ``AgentSession``; the service session id is ignored."""
        return AgentSession()

    def _next_request_id(self) -> str:
        """Return the first configured approval id, or a random UUID when none was configured."""
        # Deliberately the same id on every call: after a checkpoint restore the
        # ``AgentExecutor`` re-invokes this agent during replay, and the pending
        # request id has to match the id the test sends back as
        # ``mcp_approval_response``.
        configured = self._approval_request_ids
        return configured[0] if configured else str(uuid.uuid4())

    def _build_approval_request(self) -> Content:
        """Build an approval request wrapping a function call to the configured tool."""
        request_id = self._next_request_id()
        tool_call = Content.from_function_call(
            call_id=request_id,
            name=self._tool_name,
            arguments=self._tool_arguments,
            additional_properties={"server_label": "test_server"},
        )
        return Content.from_function_approval_request(id=request_id, function_call=tool_call)

    @overload
    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = ...,
        *,
        stream: Literal[False] = ...,
        session: AgentSession | None = ...,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]]: ...

    @overload
    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = ...,
        *,
        stream: Literal[True],
        session: AgentSession | None = ...,
        **kwargs: Any,
    ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        stream: bool = False,
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
        """Dispatch to the streaming or non-streaming implementation; ``session`` is not forwarded."""
        if not stream:
            return self._run(messages=messages, **kwargs)
        return self._run_stream(messages=messages, **kwargs)

    @staticmethod
    def _normalize(
        messages: str | Content | Message | Sequence[str | Content | Message] | None,
    ) -> list[Message]:
        """Turn any accepted input shape into a list of ``Message`` objects.

        ``Message`` instances pass through; strings and ``Content`` items are
        wrapped in a user-role message; ``None`` yields an empty list.
        """

        def _as_message(item: str | Content | Message) -> Message:
            if isinstance(item, Message):
                return item
            if isinstance(item, Content):
                return Message(role="user", contents=[item])
            return Message(role="user", contents=[Content.from_text(text=item)])

        if messages is None:
            return []
        # A lone value (a str must not be iterated character by character).
        if isinstance(messages, (str, Message, Content)):
            return [_as_message(messages)]
        return [_as_message(item) for item in messages]

    @staticmethod
    def _approval_responses_in(messages: list[Message]) -> list[Content]:
        """Collect every ``function_approval_response`` content across ``messages``."""
        found: list[Content] = []
        for message in messages:
            found.extend(item for item in message.contents if item.type == "function_approval_response")
        return found

    async def _run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        **kwargs: Any,
    ) -> AgentResponse:
        """Non-streaming run: reply with the final text if approved input is present, else ask for approval."""
        received = self._normalize(messages)
        self.run_count += 1
        self.last_run_messages = received

        if self._approval_responses_in(received):
            reply = Content.from_text(text=self._final_text)
        else:
            reply = self._build_approval_request()
        return AgentResponse(messages=[Message("assistant", [reply])])

    def _run_stream(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        **kwargs: Any,
    ) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
        """Streaming run: a single update carrying either the final text or an approval request."""
        received = self._normalize(messages)
        self.run_count += 1
        self.last_run_messages = received
        has_approval_response = bool(self._approval_responses_in(received))

        async def _updates() -> AsyncIterator[AgentResponseUpdate]:
            # The approval request is only built once the stream is consumed.
            if has_approval_response:
                payload = Content.from_text(text=self._final_text)
            else:
                payload = self._build_approval_request()
            yield AgentResponseUpdate(
                contents=[payload],
                role="assistant",
                author_name=self.name,
            )

        return ResponseStream(_updates(), finalizer=AgentResponse.from_updates)
|
||||
|
||||
|
||||
def _build_text_workflow_agent(text: str) -> WorkflowAgent:
    """Return a ``WorkflowAgent`` whose single inner agent always answers with ``text``."""

    class _TextAgent(SupportsAgentRun):
        """Inner agent that replies with one fixed assistant text, streamed or not."""

        def __init__(self, name: str, text: str) -> None:
            self.id = str(uuid.uuid4())
            self.name = name
            self.description: str | None = None
            self._text = text

        def create_session(self, **kwargs: Any) -> AgentSession:
            return AgentSession()

        def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
            return AgentSession()

        @overload
        def run(
            self,
            messages: Any = ...,
            *,
            stream: Literal[False] = ...,
            session: AgentSession | None = ...,
            **kwargs: Any,
        ) -> Awaitable[AgentResponse[Any]]: ...

        @overload
        def run(
            self,
            messages: Any = ...,
            *,
            stream: Literal[True],
            session: AgentSession | None = ...,
            **kwargs: Any,
        ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

        def run(
            self,
            messages: Any = None,
            *,
            stream: bool = False,
            session: AgentSession | None = None,
            **kwargs: Any,
        ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
            """Ignore the input and answer with the fixed text."""
            reply_text, author = self._text, self.name

            async def _respond() -> AgentResponse:
                return AgentResponse(messages=[Message("assistant", [Content.from_text(text=reply_text)])])

            if not stream:
                return _respond()

            async def _updates() -> AsyncIterator[AgentResponseUpdate]:
                yield AgentResponseUpdate(
                    contents=[Content.from_text(text=reply_text)],
                    role="assistant",
                    author_name=author,
                )

            return ResponseStream(_updates(), finalizer=AgentResponse.from_updates)

    @executor
    async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
        request = AgentExecutorRequest(messages=messages, should_respond=True)
        await ctx.send_message(request)

    text_agent = _TextAgent("text-agent", text)
    builder = WorkflowBuilder(start_executor=start).add_edge(start, text_agent)
    return WorkflowAgent(workflow=builder.build(), name="Text Workflow Agent")
|
||||
|
||||
|
||||
def _build_approval_workflow_agent(
    *,
    approval_request_id: str,
    tool_name: str = "delete_file",
    tool_arguments: dict[str, Any] | None = None,
    final_text: str = "done",
) -> tuple[WorkflowAgent, _ToolApprovalWorkflowAgentMock]:
    """Return a ``WorkflowAgent`` plus the approval-requesting mock agent it wraps.

    The mock is returned alongside the agent so tests can inspect its
    ``run_count`` and the messages it last received.
    """
    inner_agent = _ToolApprovalWorkflowAgentMock(
        name="approval-agent",
        tool_name=tool_name,
        tool_arguments=tool_arguments or {"path": "/tmp/secret.txt"},
        approval_request_ids=[approval_request_id],
        final_text=final_text,
    )

    @executor
    async def start(messages: list[Message], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
        request = AgentExecutorRequest(messages=messages, should_respond=True)
        await ctx.send_message(request)

    builder = WorkflowBuilder(start_executor=start).add_edge(start, inner_agent)
    hosted_agent = WorkflowAgent(workflow=builder.build(), name="Approval Workflow Agent")
    return hosted_agent, inner_agent
|
||||
|
||||
|
||||
class TestWorkflowAgentHosting:
|
||||
"""End-to-end HTTP tests for ``ResponsesHostServer`` hosting a ``WorkflowAgent``.
|
||||
|
||||
These tests drive ``_handle_inner_workflow`` through the ASGI stack:
|
||||
they exercise checkpoint write/restore (multi-turn) and the
|
||||
tool-approval round-trip path, which is the primary differentiator
|
||||
relative to the regular agent path.
|
||||
"""
|
||||
|
||||
async def test_basic_text_response(self) -> None:
    """Non-streaming request: the workflow's text appears as an ``output_text`` part."""
    server = _make_server(_build_text_workflow_agent("hello from workflow"))

    resp = await _post(server, input_text="hi", stream=False)
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "completed"

    # Lazily walk the content parts of every ``message`` output item.
    message_parts = (
        part for item in body["output"] if item["type"] == "message" for part in item.get("content", [])
    )
    text_found = any(
        part.get("type") == "output_text" and part.get("text") == "hello from workflow"
        for part in message_parts
    )
    assert text_found, f"Expected workflow output text in {body['output']}"
|
||||
|
||||
async def test_basic_text_response_streaming(self) -> None:
    """Streaming request: SSE events run from created to completed and carry the text."""
    server = _make_server(_build_text_workflow_agent("hello stream"))

    resp = await _post(server, input_text="hi", stream=True)
    assert resp.status_code == 200

    events = _parse_sse_events(resp.text)
    event_types = _sse_event_types(events)
    assert event_types[0] == "response.created"
    assert event_types[-1] == "response.completed"
    assert "response.output_text.delta" in event_types

    finished_texts = [event for event in events if event["event"] == "response.output_text.done"]
    assert any(event["data"]["text"] == "hello stream" for event in finished_texts)
|
||||
|
||||
async def test_non_streaming_emits_mcp_approval_request_and_persists_to_storage(self) -> None:
    """Non-streaming request: one ``mcp_approval_request`` is emitted and stored under its wire id."""
    workflow_agent, inner_agent = _build_approval_workflow_agent(approval_request_id="apr_wf_ns")
    server = _make_server(workflow_agent)

    resp = await _post(server, stream=False)
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "completed"

    approval_items = [item for item in body["output"] if item["type"] == "mcp_approval_request"]
    assert len(approval_items) == 1
    emitted = approval_items[0]
    assert emitted["name"] == "delete_file"
    assert emitted["server_label"] == "test_server"
    wire_id = emitted["id"]

    # The wire id comes from the response stream builder. The original approval
    # ``Content`` (with its inner ``function_call``) has to be stored under that
    # id so the following turn can rebuild it.
    stored = await server._approval_storage.load_approval_request(  # pyright: ignore[reportPrivateUsage]
        wire_id
    )
    assert stored.type == "function_approval_request"
    assert stored.function_call.name == "delete_file"  # type: ignore[attr-defined]
    assert inner_agent.run_count == 1
|
||||
|
||||
async def test_streaming_emits_mcp_approval_request_and_persists_to_storage(self) -> None:
    """Check that a streaming turn announces an ``mcp_approval_request`` item and stores the request."""
    workflow_agent, mock_agent = _build_approval_workflow_agent(approval_request_id="apr_wf_st")
    server = _make_server(workflow_agent)

    resp = await _post(server, stream=True)
    assert resp.status_code == 200

    events = _parse_sse_events(resp.text)
    types = _sse_event_types(events)
    assert types[0] == "response.created"
    assert types[-1] == "response.completed"

    # Lazily walk the ``output_item.added`` events and take the id of the first
    # approval-request item; ``None`` means no such item was streamed.
    added_items = (
        evt["data"].get("item") or {} for evt in events if evt["event"] == "response.output_item.added"
    )
    approval_request_id: str | None = next(
        (item.get("id") for item in added_items if item.get("type") == "mcp_approval_request"),
        None,
    )
    assert approval_request_id is not None

    loaded = await server._approval_storage.load_approval_request(  # pyright: ignore[reportPrivateUsage]
        approval_request_id
    )
    assert loaded.type == "function_approval_request"
    assert mock_agent.run_count == 1
|
||||
|
||||
async def test_round_trip_approval_response_resumes_workflow_agent(self) -> None:
    """Exercise the two-turn approval round-trip over HTTP.

    Turn 1 emits ``mcp_approval_request`` and writes a workflow checkpoint
    under the response id. Turn 2 posts the matching
    ``mcp_approval_response`` with ``previous_response_id`` set, so the host
    restores the checkpoint, the WorkflowAgent routes the approval response
    back to the paused inner agent, and the inner agent emits the final
    assistant text.
    """
    workflow_agent, mock_agent = _build_approval_workflow_agent(
        approval_request_id="apr_wf_rt",
        final_text="done with approval",
    )
    server = _make_server(workflow_agent)

    # --- Turn 1: the workflow pauses on an approval request. ---
    first = await _post(server, stream=False)
    assert first.status_code == 200
    first_body = first.json()
    first_response_id = first_body["id"]
    approval_items = [entry for entry in first_body["output"] if entry["type"] == "mcp_approval_request"]
    assert len(approval_items) == 1
    approval_request_id = approval_items[0]["id"]
    assert mock_agent.run_count == 1

    # --- Turn 2: approve and resume from the previous response. ---
    approval_input = {
        "type": "mcp_approval_response",
        "approval_request_id": approval_request_id,
        "approve": True,
    }
    second_payload: dict[str, Any] = {
        "model": "test-model",
        "input": [approval_input],
        "stream": False,
        "previous_response_id": first_response_id,
    }
    second = await _post_json(server, second_payload)
    assert second.status_code == 200
    second_body = second.json()
    assert second_body["status"] == "completed"

    # The inner agent must have been resumed (restore replay + new turn).
    # Restore call is a no-op for the mock (no input); the new-turn call
    # delivers the approval response, so run_count grows by at least 1.
    assert mock_agent.run_count >= 2

    # The final assistant text from the resumed inner agent surfaces in the
    # HTTP output.
    text_pieces = []
    for item in second_body["output"]:
        if item["type"] != "message":
            continue
        for part in item.get("content", []):
            if part.get("type") == "output_text":
                text_pieces.append(part.get("text", ""))
    assert any("done with approval" in piece for piece in text_pieces), (
        f"expected resumed workflow output, got {second_body['output']}"
    )

    # The new-turn invocation of the inner agent must have received the
    # approval response routed back through WorkflowAgent.
    approval_responses = []
    for message in mock_agent.last_run_messages:
        for content in message.contents:
            if content.type == "function_approval_response":
                approval_responses.append(content)
    assert len(approval_responses) == 1
    assert approval_responses[0].approved is True  # type: ignore[attr-defined]
|
||||
|
||||
async def test_round_trip_approval_response_streaming(self) -> None:
    """Streaming variant of the approval round-trip.

    Turn 1 is a non-streaming request that yields an
    ``mcp_approval_request``. Turn 2 is requested with ``stream=true`` and
    must surface the resumed text as SSE events.
    """
    workflow_agent, mock_agent = _build_approval_workflow_agent(
        approval_request_id="apr_wf_rt_st",
        final_text="streamed-done",
    )
    server = _make_server(workflow_agent)

    first = await _post(server, stream=False)
    # Fail fast with a clear message if turn 1 itself failed, rather than
    # with a KeyError while digging through an error body below.
    assert first.status_code == 200
    first_body = first.json()
    first_response_id = first_body["id"]

    # Collect into a list and assert instead of calling bare ``next()``: a
    # StopIteration escaping a coroutine is re-raised as an opaque
    # RuntimeError, which hides the real cause (no approval item emitted).
    approval_items = [it for it in first_body["output"] if it["type"] == "mcp_approval_request"]
    assert approval_items, f"expected an mcp_approval_request in turn 1, got {first_body['output']}"
    approval_request_id = approval_items[0]["id"]

    second = await _post_json(
        server,
        {
            "model": "test-model",
            "input": [
                {
                    "type": "mcp_approval_response",
                    "approval_request_id": approval_request_id,
                    "approve": True,
                }
            ],
            "stream": True,
            "previous_response_id": first_response_id,
        },
    )
    assert second.status_code == 200

    events = _parse_sse_events(second.text)
    types = _sse_event_types(events)
    assert types[0] == "response.created"
    assert types[-1] == "response.completed"

    text_done = [e for e in events if e["event"] == "response.output_text.done"]
    assert any("streamed-done" in e["data"]["text"] for e in text_done)
    assert mock_agent.run_count >= 2
|
||||
|
||||
async def test_round_trip_approval_response_rejected(self) -> None:
    """Sending ``approve=False`` must reach the inner agent as ``approved=False`` on resume.

    Turn 1 obtains the approval request id. Turn 2 posts a rejecting
    ``mcp_approval_response`` tied to the first response, and the messages
    recorded by the mock agent are inspected for the routed response.
    """
    workflow_agent, mock_agent = _build_approval_workflow_agent(
        approval_request_id="apr_wf_reject",
        final_text="acknowledged",
    )
    server = _make_server(workflow_agent)

    first = await _post(server, stream=False)
    # Fail fast with a clear message if turn 1 itself failed, rather than
    # with a KeyError while digging through an error body below.
    assert first.status_code == 200
    first_body = first.json()
    first_response_id = first_body["id"]

    # Collect into a list and assert instead of calling bare ``next()``: a
    # StopIteration escaping a coroutine is re-raised as an opaque
    # RuntimeError, which hides the real cause (no approval item emitted).
    approval_items = [it for it in first_body["output"] if it["type"] == "mcp_approval_request"]
    assert approval_items, f"expected an mcp_approval_request in turn 1, got {first_body['output']}"
    approval_request_id = approval_items[0]["id"]

    second = await _post_json(
        server,
        {
            "model": "test-model",
            "input": [
                {
                    "type": "mcp_approval_response",
                    "approval_request_id": approval_request_id,
                    "approve": False,
                }
            ],
            "stream": False,
            "previous_response_id": first_response_id,
        },
    )
    assert second.status_code == 200

    # Exactly one approval response must have been delivered to the inner
    # agent, and it must carry the rejection.
    approval_responses = [
        c for m in mock_agent.last_run_messages for c in m.contents if c.type == "function_approval_response"
    ]
    assert len(approval_responses) == 1
    assert approval_responses[0].approved is False  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -134,15 +134,11 @@ def handle_response_and_requests(response: AgentResponse) -> dict[str, HandoffAg
|
||||
if message.text:
|
||||
print(f"- {message.author_name or message.role}: {message.text}")
|
||||
for content in message.contents:
|
||||
if content.type == "function_call":
|
||||
if isinstance(content.arguments, dict):
|
||||
request = WorkflowAgent.RequestInfoFunctionArgs.from_dict(content.arguments)
|
||||
elif isinstance(content.arguments, str):
|
||||
request = WorkflowAgent.RequestInfoFunctionArgs.from_json(content.arguments)
|
||||
else:
|
||||
raise ValueError("Invalid arguments type. Expecting a request info structure for this sample.")
|
||||
if isinstance(request.data, HandoffAgentUserRequest):
|
||||
pending_requests[request.request_id] = request.data
|
||||
if content.type == "function_call" and content.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME:
|
||||
request_function_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(content.arguments) # type: ignore
|
||||
request_id = request_function_args.request_id
|
||||
request_event = request_function_args.request_event
|
||||
pending_requests[request_id] = request_event.data
|
||||
|
||||
return pending_requests
|
||||
|
||||
|
||||
@@ -3,10 +3,8 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
@@ -141,28 +139,14 @@ async def main() -> None:
|
||||
# Handle the human review if required.
|
||||
if human_review_function_call:
|
||||
# Parse the human review request arguments.
|
||||
human_request_args = human_review_function_call.arguments
|
||||
if isinstance(human_request_args, str):
|
||||
request: WorkflowAgent.RequestInfoFunctionArgs = WorkflowAgent.RequestInfoFunctionArgs.from_json(
|
||||
human_request_args
|
||||
)
|
||||
elif isinstance(human_request_args, Mapping):
|
||||
request = WorkflowAgent.RequestInfoFunctionArgs.from_dict(dict(human_request_args))
|
||||
else:
|
||||
raise TypeError("Unexpected argument type for human review function call.")
|
||||
|
||||
request_payload: Any = request.data
|
||||
human_request_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(human_review_function_call.arguments) # type: ignore
|
||||
request_payload = human_request_args.request_event.data
|
||||
if not isinstance(request_payload, HumanReviewRequest):
|
||||
raise ValueError("Human review request payload must be a HumanReviewRequest.")
|
||||
|
||||
agent_request = request_payload.agent_request
|
||||
if agent_request is None:
|
||||
raise ValueError("Human review request must include agent_request.")
|
||||
|
||||
request_id = agent_request.request_id
|
||||
if not request_payload.agent_request:
|
||||
raise ValueError("Human review request must contain an agent_request.")
|
||||
# Mock a human response approval for demonstration purposes.
|
||||
human_response = ReviewResponse(request_id=request_id, feedback="", approved=True)
|
||||
|
||||
human_response = ReviewResponse(request_id=request_payload.agent_request.request_id, feedback="", approved=True)
|
||||
# Create the function call result object to send back to the agent.
|
||||
human_review_function_result = Content(
|
||||
"function_result",
|
||||
|
||||
+10
-4
@@ -28,6 +28,7 @@ import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
from azure.ai.projects.models import CreateSkillVersionFromFilesBody
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.identity.aio import DefaultAzureCredential
|
||||
from dotenv import load_dotenv
|
||||
@@ -68,8 +69,13 @@ async def main() -> None:
|
||||
name = skill_md.parent.name
|
||||
print(f"Provisioning skill '{name}' from {skill_md.relative_to(SKILLS_DIR.parent)}...")
|
||||
await _delete_skill_if_exists(project, name)
|
||||
imported = await project.beta.skills.create_from_package(_zip_skill_md(skill_md))
|
||||
print(f" Imported skill '{imported.name}' (id={imported.skill_id}, has_blob={imported.has_blob}).")
|
||||
imported = await project.beta.skills.create_from_files(
|
||||
name,
|
||||
content=CreateSkillVersionFromFilesBody(
|
||||
files=[(f"{name}.zip", _zip_skill_md(skill_md), "application/zip")]
|
||||
),
|
||||
)
|
||||
print(f" Imported skill '{imported.name}' (id={imported.skill_id}, version={imported.version}).")
|
||||
|
||||
print("Verifying skills via project.beta.skills.list()...")
|
||||
listed = {skill.name: skill async for skill in project.beta.skills.list()}
|
||||
@@ -79,8 +85,8 @@ async def main() -> None:
|
||||
if skill is None:
|
||||
raise RuntimeError(f"Skill '{name}' was imported but is not present in the project listing.")
|
||||
print(
|
||||
f" OK '{skill.name}': id={skill.skill_id}, "
|
||||
f"description={skill.description!r}, has_blob={skill.has_blob}"
|
||||
f" OK '{skill.name}': id={skill.id}, "
|
||||
f"description={skill.description!r}, default_version={skill.default_version}"
|
||||
)
|
||||
|
||||
print("Done.")
|
||||
|
||||
Reference in New Issue
Block a user