Python: Add tool approval middleware (#6414)

* Add Python tool approval middleware

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix tool approval restored state handling

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Gate hidden approvals on explicit approval responses

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Handle string inputs in approval replay scan

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Cover argument-scoped approval rules

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Refine tool approval state and budgets

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix tool approval PR CI failures

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Revert DevUI Aspire README link change

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-06-11 19:35:44 +02:00
committed by GitHub
Unverified
parent c79f886dc3
commit df29af611c
8 changed files with 1868 additions and 15 deletions
+17
View File
@@ -100,6 +100,23 @@ agent_framework/
- **`FileSearchResult`** / **`FileSearchMatch`** - `SerializationMixin` DTOs returned by `search_files`, carrying the matching file name, a context snippet, and the matching lines with 1-based line numbers.
- **`FileAccessProvider`** - `ContextProvider` that adds shared file-access tools (`file_access_save_file`, `file_access_read_file`, `file_access_delete_file`, `file_access_list_files`, `file_access_search_files`) plus default usage instructions to each invocation. Unlike `MemoryContextProvider`, the store is intentionally shared across sessions and agents.
### Tool Approval Harness (`_harness/_tool_approval.py`)
- **`ToolApprovalMiddleware`** - Experimental opt-in agent middleware that coordinates session-backed approval
rules, heuristic `auto_approval_rules`, queued approval requests, collected approval responses, and
streaming/non-streaming approval prompts. Heuristic callbacks receive the underlying `function_call` content.
- **`ToolApprovalRule`** / **`ToolApprovalState`** - Serializable state models for standing approvals and queued
approval flow. `ToolApprovalRule.arguments is None` means a tool-wide rule; an empty dict `{}` means an exact
no-argument call for `create_always_approve_tool_with_arguments_response`.
- **`create_always_approve_tool_response`** / **`create_always_approve_tool_with_arguments_response`** - Helpers
that return normal `function_approval_response` content with `additional_properties` metadata consumed by
`ToolApprovalMiddleware`. Standing rules for hosted tools include the `server_label` boundary, so same-named tools
on different hosted servers do not share approvals.
- Mixed tool-call batches use a default .NET-style bypass in the function invocation loop: when a session is
available, approval requests for known non-approval-required tools are treated as already approved, hidden, stored
in session state keyed to the visible approval request ids from that batch, and reinjected only when that visible
approval flow resumes.
### Workflows (`_workflows/`)
- **`Workflow`** - Graph-based workflow definition
@@ -125,6 +125,15 @@ from ._harness._todo import (
TodoStore,
)
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback
from ._harness._tool_approval import (
DEFAULT_TOOL_APPROVAL_SOURCE_ID,
ToolApprovalMiddleware,
ToolApprovalRule,
ToolApprovalRuleCallback,
ToolApprovalState,
create_always_approve_tool_response,
create_always_approve_tool_with_arguments_response,
)
from ._middleware import (
AgentContext,
AgentMiddleware,
@@ -330,6 +339,7 @@ __all__ = [
"DEFAULT_MEMORY_SOURCE_ID",
"DEFAULT_MODE_SOURCE_ID",
"DEFAULT_TODO_SOURCE_ID",
"DEFAULT_TOOL_APPROVAL_SOURCE_ID",
"EXCLUDED_KEY",
"EXCLUDE_REASON_KEY",
"GROUP_ANNOTATION_KEY",
@@ -509,6 +519,10 @@ __all__ = [
"TodoStore",
"TokenBudgetComposedStrategy",
"TokenizerProtocol",
"ToolApprovalMiddleware",
"ToolApprovalRule",
"ToolApprovalRuleCallback",
"ToolApprovalState",
"ToolMode",
"ToolResultCompactionStrategy",
"ToolTypes",
@@ -543,6 +557,8 @@ __all__ = [
"annotate_message_groups",
"apply_compaction",
"chat_middleware",
"create_always_approve_tool_response",
"create_always_approve_tool_with_arguments_response",
"create_edge_runner",
"create_harness_agent",
"detect_media_type_from_base64",
@@ -0,0 +1,632 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import copy
import inspect
import json
from asyncio import sleep
from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Mapping, MutableMapping, Sequence
from typing import Any, Literal, cast
from .._feature_stage import ExperimentalFeature, experimental
from .._middleware import AgentContext, AgentMiddleware
from .._serialization import SerializationMixin
from .._sessions import AgentSession
from .._types import (
AgentResponse,
AgentResponseUpdate,
Content,
FinishReason,
FinishReasonLiteral,
Message,
ResponseStream,
)
DEFAULT_TOOL_APPROVAL_SOURCE_ID = "tool_approval"
_FUNCTION_INVOCATION_BUDGET_STATE_KEY = "_function_invocation_budget_state"
ALWAYS_APPROVE_PROPERTY = "tool_approval"
ALWAYS_APPROVE_SCOPE_PROPERTY = "always_approve"
ALWAYS_APPROVE_TOOL: Literal["tool"] = "tool"
ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS: Literal["tool_with_arguments"] = "tool_with_arguments"
_RULES_KEY = "rules"
_QUEUED_APPROVAL_REQUESTS_KEY = "queued_approval_requests"
_COLLECTED_APPROVAL_RESPONSES_KEY = "collected_approval_responses"
ToolApprovalScope = Literal["tool", "tool_with_arguments"]
ToolApprovalRuleCallback = Callable[[Content], bool | Awaitable[bool]]
def _parse_function_arguments(function_call: Content) -> dict[str, Any]:
arguments = function_call.parse_arguments()
return dict(arguments or {})
def _serialize_argument_value(value: Any) -> str:
return json.dumps(value, sort_keys=True, separators=(",", ":"), default=str)
def _serialize_arguments(function_call: Content) -> dict[str, str]:
"""Serialize arguments for exact matching.
``None`` is reserved on :class:`ToolApprovalRule` for tool-wide rules.
An argument-scoped rule for a no-argument call stores ``{}``, so it only
matches future no-argument calls and never becomes a wildcard.
"""
arguments = _parse_function_arguments(function_call)
return {key: _serialize_argument_value(value) for key, value in arguments.items()}
def _server_label(function_call: Content) -> str | None:
"""Return the hosted-tool server boundary for a function call, if present."""
value = function_call.additional_properties.get("server_label")
return value if isinstance(value, str) else None
def _content_from_state(value: Any) -> Content:
if isinstance(value, Content):
return value
if isinstance(value, Mapping):
return Content.from_dict(cast(Mapping[str, Any], value))
raise TypeError(f"Expected Content or mapping state item, got {type(value).__name__}.")
def _contents_from_state(values: Any) -> list[Content]:
if not isinstance(values, list):
return []
state_items = list(cast(Iterable[Any], values))
return [_content_from_state(value) for value in state_items]
def _content_to_state(content: Content) -> dict[str, Any]:
return content.to_dict()
@experimental(feature_id=ExperimentalFeature.HARNESS)
class ToolApprovalRule(SerializationMixin):
"""A standing rule for approving future matching tool calls."""
tool_name: str
arguments: dict[str, str] | None
server_label: str | None
def __init__(
self,
tool_name: str,
arguments: Mapping[str, str] | None = None,
*,
server_label: str | None = None,
) -> None:
"""Initialize a tool approval rule.
Args:
tool_name: The function tool name this rule applies to.
arguments: Optional canonicalized argument values. When omitted, the
rule applies to every call to the tool. Use an empty mapping to
match only no-argument calls.
Keyword Args:
server_label: Optional hosted-tool server boundary. Hosted approvals
only match future approvals from the same server label.
"""
normalized_name = tool_name.strip()
if not normalized_name:
raise ValueError("Tool approval rule tool_name must be a non-empty string.")
self.tool_name = normalized_name
self.arguments = dict(arguments) if arguments is not None else None
self.server_label = server_label
@classmethod
def from_dict(
cls,
value: MutableMapping[str, Any],
/,
*,
dependencies: MutableMapping[str, Any] | None = None,
) -> ToolApprovalRule:
"""Create a rule from serialized state."""
del dependencies
tool_name = value.get("tool_name")
if not isinstance(tool_name, str):
raise ValueError("Tool approval rule tool_name must be a string.")
raw_arguments = value.get("arguments")
if raw_arguments is not None and not isinstance(raw_arguments, Mapping):
raise ValueError("Tool approval rule arguments must be a mapping or None.")
server_label = value.get("server_label")
if server_label is not None and not isinstance(server_label, str):
raise ValueError("Tool approval rule server_label must be a string or None.")
arguments = (
{str(key): str(argument_value) for key, argument_value in cast(Mapping[str, Any], raw_arguments).items()}
if isinstance(raw_arguments, Mapping)
else None
)
return cls(tool_name=tool_name, arguments=arguments, server_label=server_label)
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
"""Serialize the rule."""
exclude = exclude or set()
payload: dict[str, Any] = {"tool_name": self.tool_name}
if "type" not in exclude:
payload["type"] = self._get_type_identifier()
if self.arguments is not None or not exclude_none:
payload["arguments"] = self.arguments
if self.server_label is not None or not exclude_none:
payload["server_label"] = self.server_label
return payload
@experimental(feature_id=ExperimentalFeature.HARNESS)
class ToolApprovalState(SerializationMixin):
"""Session-backed state used by :class:`ToolApprovalMiddleware`."""
rules: list[ToolApprovalRule]
queued_approval_requests: list[Content]
collected_approval_responses: list[Content]
def __init__(
self,
*,
rules: Sequence[ToolApprovalRule | Mapping[str, Any]] | None = None,
queued_approval_requests: Sequence[Content | Mapping[str, Any]] | None = None,
collected_approval_responses: Sequence[Content | Mapping[str, Any]] | None = None,
) -> None:
"""Initialize approval state."""
self.rules = [
rule if isinstance(rule, ToolApprovalRule) else ToolApprovalRule.from_dict(dict(rule))
for rule in (rules or [])
]
self.queued_approval_requests = [
item if isinstance(item, Content) else Content.from_dict(item) for item in (queued_approval_requests or [])
]
self.collected_approval_responses = [
item if isinstance(item, Content) else Content.from_dict(item)
for item in (collected_approval_responses or [])
]
@classmethod
def from_dict(
cls,
value: MutableMapping[str, Any],
/,
*,
dependencies: MutableMapping[str, Any] | None = None,
) -> ToolApprovalState:
"""Create state from serialized state."""
del dependencies
return cls(
rules=cast(Sequence[Mapping[str, Any]], value.get("rules", [])),
queued_approval_requests=cast(Sequence[Mapping[str, Any]], value.get("queued_approval_requests", [])),
collected_approval_responses=cast(
Sequence[Mapping[str, Any]],
value.get("collected_approval_responses", []),
),
)
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
"""Serialize state."""
del exclude_none
exclude = exclude or set()
payload: dict[str, Any] = {
"rules": [rule.to_dict() for rule in self.rules],
"queued_approval_requests": [_content_to_state(item) for item in self.queued_approval_requests],
"collected_approval_responses": [_content_to_state(item) for item in self.collected_approval_responses],
}
if "type" not in exclude:
payload["type"] = self._get_type_identifier()
return payload
def create_always_approve_tool_response(request: Content, *, reason: str | None = None) -> Content:
"""Create an approval response that records a standing rule for the whole tool.
Args:
request: The ``function_approval_request`` content to approve.
Keyword Args:
reason: Optional approval reason stored in ``additional_properties``.
Returns:
A ``function_approval_response`` with metadata consumed by
:class:`ToolApprovalMiddleware`.
"""
return _create_always_approve_response(request, ALWAYS_APPROVE_TOOL, reason=reason)
def create_always_approve_tool_with_arguments_response(request: Content, *, reason: str | None = None) -> Content:
"""Create an approval response that records a standing rule for the tool and exact arguments."""
return _create_always_approve_response(request, ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS, reason=reason)
def _create_always_approve_response(request: Content, scope: ToolApprovalScope, *, reason: str | None) -> Content:
response = request.to_function_approval_response(approved=True)
metadata: dict[str, Any] = {ALWAYS_APPROVE_SCOPE_PROPERTY: scope}
if reason is not None:
metadata["reason"] = reason
response.additional_properties[ALWAYS_APPROVE_PROPERTY] = metadata
return response
def _get_state(session: AgentSession, *, source_id: str) -> ToolApprovalState:
raw_state = session.state.get(source_id)
if isinstance(raw_state, ToolApprovalState):
return raw_state
if isinstance(raw_state, MutableMapping):
raw_state_mapping = cast(MutableMapping[str, Any], raw_state)
return ToolApprovalState(
rules=cast(Sequence[Mapping[str, Any]], raw_state_mapping.get(_RULES_KEY, [])),
queued_approval_requests=_contents_from_state(raw_state_mapping.get(_QUEUED_APPROVAL_REQUESTS_KEY, [])),
collected_approval_responses=_contents_from_state(
raw_state_mapping.get(_COLLECTED_APPROVAL_RESPONSES_KEY, []),
),
)
if raw_state is not None:
raise TypeError(f"Session state for {source_id!r} must be a mapping, got {type(raw_state).__name__}.")
state = ToolApprovalState()
session.state[source_id] = state.to_dict(exclude={"type"})
return state
def _save_state(session: AgentSession, state: ToolApprovalState, *, source_id: str) -> None:
serialized = state.to_dict(exclude={"type"})
existing = session.state.get(source_id)
if isinstance(existing, MutableMapping):
for key, value in cast(MutableMapping[str, Any], existing).items():
if key not in serialized and key != "type":
serialized[key] = value
session.state[source_id] = serialized
def _rule_exists(rules: Sequence[ToolApprovalRule], new_rule: ToolApprovalRule) -> bool:
for rule in rules:
if rule.tool_name != new_rule.tool_name:
continue
if rule.server_label != new_rule.server_label:
continue
if rule.arguments == new_rule.arguments:
return True
return False
def _add_rule_if_missing(state: ToolApprovalState, rule: ToolApprovalRule) -> None:
if not _rule_exists(state.rules, rule):
state.rules.append(rule)
def _function_call_from_request(request: Content) -> Content | None:
function_call = request.function_call
if function_call is None or function_call.type != "function_call" or function_call.name is None:
return None
return function_call
def _arguments_match(rule_arguments: Mapping[str, str], function_call: Content) -> bool:
call_arguments = _serialize_arguments(function_call) or {}
if len(rule_arguments) != len(call_arguments):
return False
return all(call_arguments.get(key) == value for key, value in rule_arguments.items())
def _matches_rule(request: Content, rules: Sequence[ToolApprovalRule]) -> bool:
function_call = _function_call_from_request(request)
if function_call is None:
return False
for rule in rules:
if rule.tool_name != function_call.name:
continue
if rule.server_label != _server_label(function_call):
continue
if rule.arguments is None:
return True
if _arguments_match(rule.arguments, function_call):
return True
return False
def _get_always_approve_scope(response: Content) -> ToolApprovalScope | None:
metadata = response.additional_properties.get(ALWAYS_APPROVE_PROPERTY)
if not isinstance(metadata, Mapping):
return None
metadata_mapping = cast(Mapping[str, Any], metadata)
scope = metadata_mapping.get(ALWAYS_APPROVE_SCOPE_PROPERTY)
if scope == ALWAYS_APPROVE_TOOL:
return ALWAYS_APPROVE_TOOL
if scope == ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS:
return ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS
return None
def _clone_without_always_approve_metadata(response: Content) -> Content:
cloned = copy.deepcopy(response)
cloned.additional_properties.pop(ALWAYS_APPROVE_PROPERTY, None)
return cloned
@experimental(feature_id=ExperimentalFeature.HARNESS)
class ToolApprovalMiddleware(AgentMiddleware):
"""Coordinate standing tool approvals and queued approval prompts for an agent.
This middleware is opt-in and requires callers to run the agent with an
:class:`AgentSession`, because approval rules and queued requests are stored
in session state.
"""
def __init__(
self,
*,
source_id: str = DEFAULT_TOOL_APPROVAL_SOURCE_ID,
auto_approval_rules: Sequence[ToolApprovalRuleCallback] | None = None,
) -> None:
"""Initialize the middleware.
Keyword Args:
source_id: Session-state key used by this middleware.
auto_approval_rules: Optional callbacks that can auto-approve a
``function_call``. Each callback receives the function-call
content and returns ``True`` to approve it.
"""
self.source_id = source_id
self.auto_approval_rules = tuple(auto_approval_rules or ())
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
"""Process one agent invocation."""
if context.session is None:
raise RuntimeError("ToolApprovalMiddleware requires an AgentSession.")
state = _get_state(context.session, source_id=self.source_id)
context.client_kwargs.setdefault(_FUNCTION_INVOCATION_BUDGET_STATE_KEY, {})
context.messages = self._prepare_inbound_messages(context.messages, state)
await self._drain_auto_approvable_queue(state)
if next_queued := self._pop_next_queued_request(state):
_save_state(context.session, state, source_id=self.source_id)
context.result = self._response_for_queued_request(next_queued, stream=context.stream)
return
if context.stream:
context.result = self._process_stream(context, call_next, state)
return
while True:
context.messages = self._inject_collected_responses(context.messages, state)
state_changed = bool(state.collected_approval_responses)
state.collected_approval_responses.clear()
if state_changed:
_save_state(context.session, state, source_id=self.source_id)
await call_next()
if isinstance(context.result, ResponseStream):
return
if context.result is None:
_save_state(context.session, state, source_id=self.source_id)
return
all_auto_approved = await self._process_outbound_messages(context.result.messages, state)
_save_state(context.session, state, source_id=self.source_id)
if not all_auto_approved:
return
context.messages = []
context.result = None
def _response_for_queued_request(
self,
request: Content,
*,
stream: bool,
) -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse]:
if not stream:
return AgentResponse(messages=[Message(role="assistant", contents=[request])])
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
await sleep(0)
yield AgentResponseUpdate(role="assistant", contents=[request])
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
def _process_stream(
self,
context: AgentContext,
call_next: Callable[[], Awaitable[None]],
state: ToolApprovalState,
) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
if context.session is None:
raise RuntimeError("ToolApprovalMiddleware requires an AgentSession.")
while True:
context.messages = self._inject_collected_responses(context.messages, state)
state_changed = bool(state.collected_approval_responses)
state.collected_approval_responses.clear()
if state_changed:
_save_state(context.session, state, source_id=self.source_id)
await call_next()
if not isinstance(context.result, ResponseStream):
raise ValueError("Streaming ToolApprovalMiddleware requires a ResponseStream result.")
approval_requests: list[Content] = []
async for update in context.result:
approval_contents = [
content for content in update.contents if content.type == "function_approval_request"
]
if not approval_contents:
yield update
continue
approval_requests.extend(approval_contents)
remaining_contents = [
content for content in update.contents if content.type != "function_approval_request"
]
if remaining_contents:
raw_finish_reason = update.finish_reason
finish_reason: FinishReasonLiteral | FinishReason | None
if isinstance(raw_finish_reason, str):
finish_reason = FinishReason(raw_finish_reason)
else:
finish_reason = cast(FinishReasonLiteral | FinishReason | None, raw_finish_reason)
yield AgentResponseUpdate(
contents=remaining_contents,
role=update.role,
author_name=update.author_name,
agent_id=update.agent_id,
response_id=update.response_id,
message_id=update.message_id,
created_at=update.created_at,
finish_reason=finish_reason,
continuation_token=update.continuation_token,
additional_properties=update.additional_properties,
raw_representation=update.raw_representation,
)
await context.result.get_final_response()
if not approval_requests:
return
response_messages = [Message(role="assistant", contents=approval_requests)]
all_auto_approved = await self._process_outbound_messages(response_messages, state)
_save_state(context.session, state, source_id=self.source_id)
if not all_auto_approved:
for message in response_messages:
if message.contents:
yield AgentResponseUpdate(role=message.role, contents=message.contents)
return
context.messages = []
context.result = None
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
def _prepare_inbound_messages(self, messages: Sequence[Message], state: ToolApprovalState) -> list[Message]:
prepared: list[Message] = []
for message in messages:
replacement_contents: list[Content] = []
changed = False
for content in message.contents:
if content.type == "function_approval_response":
replacement = self._handle_inbound_approval_response(content, state)
state.collected_approval_responses.append(replacement)
changed = True
continue
replacement_contents.append(content)
if not changed:
prepared.append(message)
continue
if replacement_contents:
cloned = copy.copy(message)
cloned.contents = replacement_contents
prepared.append(cloned)
return prepared
def _handle_inbound_approval_response(self, response: Content, state: ToolApprovalState) -> Content:
scope = _get_always_approve_scope(response)
if scope is None or not response.approved:
return response
function_call = response.function_call
if function_call is not None and function_call.type == "function_call" and function_call.name is not None:
if scope == ALWAYS_APPROVE_TOOL:
_add_rule_if_missing(
state,
ToolApprovalRule(
tool_name=function_call.name,
server_label=_server_label(function_call),
),
)
else:
_add_rule_if_missing(
state,
ToolApprovalRule(
tool_name=function_call.name,
arguments=_serialize_arguments(function_call),
server_label=_server_label(function_call),
),
)
return _clone_without_always_approve_metadata(response)
def _inject_collected_responses(self, messages: Sequence[Message], state: ToolApprovalState) -> list[Message]:
if not state.collected_approval_responses:
return list(messages)
return [Message(role="user", contents=list(state.collected_approval_responses)), *messages]
async def _drain_auto_approvable_queue(self, state: ToolApprovalState) -> None:
remaining: list[Content] = []
for request in state.queued_approval_requests:
if _matches_rule(request, state.rules) or await self._matches_auto_rule(request):
state.collected_approval_responses.append(request.to_function_approval_response(approved=True))
continue
remaining.append(request)
state.queued_approval_requests = remaining
def _pop_next_queued_request(self, state: ToolApprovalState) -> Content | None:
if not state.queued_approval_requests:
return None
return state.queued_approval_requests.pop(0)
async def _process_outbound_messages(self, messages: list[Message], state: ToolApprovalState) -> bool:
approval_requests = [
content
for message in messages
for content in message.contents
if content.type == "function_approval_request"
]
if not approval_requests:
return False
auto_approved: set[int] = set()
unresolved: list[Content] = []
for request in approval_requests:
if _matches_rule(request, state.rules) or await self._matches_auto_rule(request):
state.collected_approval_responses.append(request.to_function_approval_response(approved=True))
auto_approved.add(id(request))
else:
unresolved.append(request)
if not auto_approved and len(unresolved) <= 1:
return False
queued_ids: set[int] = set()
for request in unresolved[1:]:
queued_ids.add(id(request))
state.queued_approval_requests.append(request)
remove_ids = auto_approved | queued_ids
self._remove_approval_requests(messages, remove_ids)
return not unresolved
@staticmethod
def _remove_approval_requests(messages: list[Message], remove_ids: set[int]) -> None:
for message_index in range(len(messages) - 1, -1, -1):
message = messages[message_index]
filtered = [
content
for content in message.contents
if content.type != "function_approval_request" or id(content) not in remove_ids
]
if len(filtered) == len(message.contents):
continue
if filtered:
message.contents = filtered
else:
messages.pop(message_index)
async def _matches_auto_rule(self, request: Content) -> bool:
function_call = _function_call_from_request(request)
if function_call is None:
return False
for rule in self.auto_approval_rules:
result = rule(function_call)
if inspect.isawaitable(result):
result = await result
if result:
return True
return False
__all__ = [
"ALWAYS_APPROVE_PROPERTY",
"ALWAYS_APPROVE_SCOPE_PROPERTY",
"ALWAYS_APPROVE_TOOL",
"ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS",
"DEFAULT_TOOL_APPROVAL_SOURCE_ID",
"ToolApprovalMiddleware",
"ToolApprovalRule",
"ToolApprovalRuleCallback",
"ToolApprovalState",
"create_always_approve_tool_response",
"create_always_approve_tool_with_arguments_response",
]
+190 -15
View File
@@ -90,6 +90,9 @@ logger = logging.getLogger("agent_framework")
DEFAULT_MAX_ITERATIONS: Final[int] = 40
DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
SHELL_TOOL_KIND_VALUE: Final[str] = "shell"
_TOOL_APPROVAL_STATE_KEY: Final[str] = "tool_approval"
_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY: Final[str] = "already_approved_approval_request_groups"
_FUNCTION_INVOCATION_BUDGET_STATE_KEY: Final[str] = "_function_invocation_budget_state"
ApprovalMode: TypeAlias = Literal["always_require", "never_require"]
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
@@ -1685,15 +1688,15 @@ async def _try_execute_function_calls(
# The live tools list (when tools is the run-local list) is exposed on the
# FunctionInvocationContext so tools can add/remove tools during the run.
live_tools: list[ToolTypes] | None = cast("list[ToolTypes]", tools) if isinstance(tools, list) else None
approval_tools = [tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"]
approval_tools = {tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"}
logger.debug(
"_try_execute_function_calls: tool_map keys=%s, approval_tools=%s",
list(tool_map.keys()),
approval_tools,
)
declaration_only = [tool_name for tool_name, tool in tool_map.items() if tool.declaration_only]
declaration_only = {tool_name for tool_name, tool in tool_map.items() if tool.declaration_only}
configured_additional_tools = config.get("additional_tools") or []
additional_tool_names = [tool.name for tool in configured_additional_tools]
additional_tool_names = {tool.name for tool in configured_additional_tools}
# check if any are calling functions that need approval
# if so, we return approval request for all
approval_needed = False
@@ -1719,15 +1722,39 @@ async def _try_execute_function_calls(
raise KeyError(f'Error: Requested function "{fcc.name}" not found.') # type: ignore[attr-defined]
if approval_needed:
# approval can only be needed for Function Call Content, not Approval Responses.
logger.debug("Returning function_approval_request contents")
return (
[
Content.from_function_approval_request(id=fcc.call_id, function_call=fcc) # type: ignore[attr-defined, arg-type]
for fcc in function_calls
if fcc.type == "function_call"
],
False,
logger.debug("Returning visible function_approval_request contents and storing already-approved requests")
visible_requests: list[Content] = []
already_approved_requests: list[Content] = []
for fcc in function_calls:
if fcc.type != "function_call":
continue
approval_request = Content.from_function_approval_request(
id=fcc.call_id, # type: ignore[arg-type]
function_call=fcc,
)
tool_name = fcc.name # type: ignore[attr-defined]
if tool_name is None:
visible_requests.append(approval_request)
continue
tool = tool_map.get(tool_name)
if (
tool_name in approval_tools
or tool is None
or tool_name in declaration_only
or tool_name in additional_tool_names
):
visible_requests.append(approval_request)
continue
if invocation_session is None:
visible_requests.append(approval_request)
continue
already_approved_requests.append(approval_request)
_store_already_approved_approval_requests(
invocation_session,
visible_requests,
already_approved_requests,
)
return (visible_requests, False)
if declaration_only_flag:
# return the declaration only tools to the user, since we cannot execute them.
# Mark as user_input_request so AgentExecutor emits request_info events and pauses the workflow.
@@ -1912,6 +1939,108 @@ def _is_hosted_tool_approval(content: Any) -> bool:
return bool(ap and ap.get("server_label"))
def _get_tool_approval_state(invocation_session: AgentSession | None) -> dict[str, Any] | None:
"""Return the shared tool-approval state bag for the invocation session."""
if invocation_session is None:
return None
raw_state = invocation_session.state.get(_TOOL_APPROVAL_STATE_KEY)
if isinstance(raw_state, dict):
return cast(dict[str, Any], raw_state)
from ._harness._tool_approval import ToolApprovalState
if isinstance(raw_state, ToolApprovalState):
serialized_state = raw_state.to_dict(exclude={"type"})
invocation_session.state[_TOOL_APPROVAL_STATE_KEY] = serialized_state
return serialized_state
if raw_state is not None:
raise TypeError(
f"Session state for {_TOOL_APPROVAL_STATE_KEY!r} must be a dict or ToolApprovalState, "
f"got {type(raw_state).__name__}."
)
new_state: dict[str, Any] = {}
invocation_session.state[_TOOL_APPROVAL_STATE_KEY] = new_state
return new_state
def _content_from_state(value: Any) -> Content | None:
"""Restore a Content item stored in session state."""
from ._types import Content
if isinstance(value, Content):
return value
if isinstance(value, Mapping):
return Content.from_dict(cast(Mapping[str, Any], value))
return None
def _store_already_approved_approval_requests(
invocation_session: AgentSession | None,
visible_approval_requests: Sequence[Content],
already_approved_requests: Sequence[Content],
) -> None:
"""Store hidden already-approved requests keyed by the visible approvals that resume the batch."""
if not already_approved_requests:
return
state = _get_tool_approval_state(invocation_session)
if state is None:
return
visible_ids = [request.id for request in visible_approval_requests if request.id]
if not visible_ids:
return
existing_groups = state.get(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY)
pending_groups: list[Any] = (
list(cast(Iterable[Any], existing_groups)) if isinstance(existing_groups, list) else []
)
pending_groups.append({
"approval_request_ids": visible_ids,
"approval_requests": [request.to_dict() for request in already_approved_requests],
})
state[_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY] = pending_groups
def _pop_already_approved_approval_responses(
invocation_session: AgentSession | None,
approval_response_ids: set[str],
) -> list[Content]:
"""Pop already-approved requests for the visible approval ids being answered."""
if not approval_response_ids:
return []
state = _get_tool_approval_state(invocation_session)
if state is None:
return []
raw_groups = state.get(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY, [])
if not isinstance(raw_groups, list):
return []
responses: list[Content] = []
remaining_groups: list[Any] = []
raw_group_items = list(cast(Iterable[Any], raw_groups))
for raw_group in raw_group_items:
if not isinstance(raw_group, Mapping):
continue
group = cast(Mapping[str, Any], raw_group)
raw_ids = group.get("approval_request_ids")
raw_group_ids: Iterable[Any] = cast(Iterable[Any], raw_ids) if isinstance(raw_ids, list) else ()
group_ids = {str(item) for item in raw_group_ids}
if group_ids.isdisjoint(approval_response_ids):
remaining_groups.append(raw_group)
continue
raw_requests = group.get("approval_requests")
if not isinstance(raw_requests, list):
continue
for raw_request in list(cast(Iterable[Any], raw_requests)):
request = _content_from_state(raw_request)
if request is None or request.type != "function_approval_request":
continue
responses.append(request.to_function_approval_response(approved=True))
if remaining_groups:
state[_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY] = remaining_groups
else:
state.pop(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY, None)
return responses
def _collect_approval_responses(
messages: list[Message],
) -> dict[str, Content]:
@@ -2157,8 +2286,24 @@ async def _process_function_requests(
errors_in_a_row: int,
max_errors: int,
execute_function_calls: Callable[..., Awaitable[tuple[list[Content], bool, bool]]],
invocation_session: AgentSession | None = None,
) -> FunctionRequestResult:
from ._types import Message
if prepped_messages is not None:
explicit_approval_response_ids = {
content.id
for message in prepped_messages
if isinstance(message, Message)
for content in message.contents
if content.type == "function_approval_response" and content.id
}
already_approved_responses = _pop_already_approved_approval_responses(
invocation_session,
explicit_approval_response_ids,
)
if already_approved_responses:
prepped_messages.append(Message(role="user", contents=already_approved_responses))
fcc_todo = _collect_approval_responses(prepped_messages)
if not fcc_todo:
fcc_todo = {}
@@ -2362,6 +2507,10 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
function_middleware_pipeline = self._get_function_middleware_pipeline(runtime_middleware["function"])
if runtime_middleware["chat"]:
effective_client_kwargs["middleware"] = runtime_middleware["chat"]
raw_budget_state = effective_client_kwargs.pop(_FUNCTION_INVOCATION_BUDGET_STATE_KEY, None)
budget_state: dict[str, Any] = (
cast(dict[str, Any], raw_budget_state) if isinstance(raw_budget_state, dict) else {}
)
max_errors = self.function_invocation_configuration.get(
"max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST
)
@@ -2411,7 +2560,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
nonlocal mutable_options
nonlocal filtered_kwargs
errors_in_a_row: int = 0
total_function_calls: int = 0
total_function_calls = int(budget_state.get("total_function_calls", 0) or 0)
max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls")
prepped_messages = list(messages)
fcc_messages: list[Message] = []
@@ -2420,7 +2569,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
loop_enabled = self.function_invocation_configuration.get("enabled", True)
max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS)
for attempt_idx in range(max_iterations if loop_enabled else 0):
attempt_start = int(budget_state.get("attempt_count", 0) or 0)
for attempt_idx in range(attempt_start, max_iterations if loop_enabled else 0):
budget_state["attempt_count"] = attempt_idx + 1
approval_result = await _process_function_requests(
response=None,
prepped_messages=prepped_messages,
@@ -2430,12 +2581,21 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
errors_in_a_row=errors_in_a_row,
max_errors=max_errors,
execute_function_calls=execute_function_calls,
invocation_session=invocation_session,
)
if approval_result.get("action") == "stop":
response = ChatResponse(messages=prepped_messages)
break
errors_in_a_row = approval_result.get("errors_in_a_row", errors_in_a_row)
total_function_calls += approval_result.get("function_call_count", 0)
budget_state["total_function_calls"] = total_function_calls
if max_function_calls is not None and total_function_calls >= max_function_calls:
logger.info(
"Maximum function calls reached (%d/%d). Stopping further function calls for this request.",
total_function_calls,
max_function_calls,
)
mutable_options["tool_choice"] = "none"
response = cast(
ChatResponse[Any],
@@ -2468,11 +2628,13 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
errors_in_a_row=errors_in_a_row,
max_errors=max_errors,
execute_function_calls=execute_function_calls,
invocation_session=invocation_session,
)
if result.get("action") == "return":
response.usage_details = aggregated_usage
return _clear_internal_conversation_id(response)
total_function_calls += result.get("function_call_count", 0)
budget_state["total_function_calls"] = total_function_calls
if result.get("action") == "stop":
# Error threshold reached: force a final non-tool turn so
# function_call_output items are submitted before exit.
@@ -2549,7 +2711,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
nonlocal mutable_options
nonlocal stream_result_hooks
errors_in_a_row: int = 0
total_function_calls: int = 0
total_function_calls = int(budget_state.get("total_function_calls", 0) or 0)
max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls")
prepped_messages = list(messages)
fcc_messages: list[Message] = []
@@ -2557,7 +2719,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
loop_enabled = self.function_invocation_configuration.get("enabled", True)
max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS)
for attempt_idx in range(max_iterations if loop_enabled else 0):
attempt_start = int(budget_state.get("attempt_count", 0) or 0)
for attempt_idx in range(attempt_start, max_iterations if loop_enabled else 0):
budget_state["attempt_count"] = attempt_idx + 1
approval_result = await _process_function_requests(
response=None,
prepped_messages=prepped_messages,
@@ -2567,9 +2731,18 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
errors_in_a_row=errors_in_a_row,
max_errors=max_errors,
execute_function_calls=execute_function_calls,
invocation_session=invocation_session,
)
errors_in_a_row = approval_result.get("errors_in_a_row", errors_in_a_row)
total_function_calls += approval_result.get("function_call_count", 0)
budget_state["total_function_calls"] = total_function_calls
if max_function_calls is not None and total_function_calls >= max_function_calls:
logger.info(
"Maximum function calls reached (%d/%d). Stopping further function calls for this request.",
total_function_calls,
max_function_calls,
)
mutable_options["tool_choice"] = "none"
if approval_result.get("action") == "stop":
mutable_options["tool_choice"] = "none"
return
@@ -2622,9 +2795,11 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
errors_in_a_row=errors_in_a_row,
max_errors=max_errors,
execute_function_calls=execute_function_calls,
invocation_session=invocation_session,
)
errors_in_a_row = result.get("errors_in_a_row", errors_in_a_row)
total_function_calls += result.get("function_call_count", 0)
budget_state["total_function_calls"] = total_function_calls
if role := result.get("update_role"):
yield ChatResponseUpdate(
contents=result.get("function_call_results") or [],
@@ -0,0 +1,817 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
from agent_framework import (
DEFAULT_TOOL_APPROVAL_SOURCE_ID,
Agent,
AgentSession,
ChatResponse,
ChatResponseUpdate,
Content,
Message,
SupportsChatGetResponse,
ToolApprovalMiddleware,
ToolApprovalState,
create_always_approve_tool_response,
create_always_approve_tool_with_arguments_response,
tool,
)
def _approval_requests(messages: list[Message]) -> list[Content]:
return [
content for message in messages for content in message.contents if content.type == "function_approval_request"
]
async def test_mixed_batch_hides_already_approved_request_until_approval_replay(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""Mixed batches should only show real approval requests when a session can store hidden requests."""
no_approval_calls = 0
approval_calls = 0
@tool(name="lookup_work_items", approval_mode="never_require")
def lookup_work_items(query: str) -> str:
nonlocal no_approval_calls
no_approval_calls += 1
return f"found {query}"
@tool(name="add_comment", approval_mode="always_require")
def add_comment(comment: str) -> str:
nonlocal approval_calls
approval_calls += 1
return f"added {comment}"
agent = Agent(client=chat_client_base, tools=[lookup_work_items, add_comment])
session = AgentSession(session_id="approval-session")
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_lookup",
name="lookup_work_items",
arguments='{"query": "mine"}',
),
Content.from_function_call(
call_id="call_comment",
name="add_comment",
arguments='{"comment": "done"}',
),
],
)
)
]
first_response = await agent.run("update work item", session=session)
requests = _approval_requests(first_response.messages)
assert [request.function_call.name for request in requests] == ["add_comment"]
assert no_approval_calls == 0
assert approval_calls == 0
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["complete"]))]
second_response = await agent.run(requests[0].to_function_approval_response(approved=True), session=session)
assert second_response.text == "complete"
assert no_approval_calls == 1
assert approval_calls == 1
async def test_mixed_batch_accepts_restored_tool_approval_state(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""Mixed-batch bypass should work when session state contains ToolApprovalState."""
safe_calls = 0
risky_calls = 0
@tool(name="safe_read", approval_mode="never_require")
def safe_read() -> str:
nonlocal safe_calls
safe_calls += 1
return "safe"
@tool(name="risky_write", approval_mode="always_require")
def risky_write() -> str:
nonlocal risky_calls
risky_calls += 1
return "risky"
agent = Agent(client=chat_client_base, tools=[safe_read, risky_write])
session = AgentSession(session_id="restored-state-session")
session.state[DEFAULT_TOOL_APPROVAL_SOURCE_ID] = ToolApprovalState()
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_safe", name="safe_read", arguments="{}"),
Content.from_function_call(call_id="call_risky", name="risky_write", arguments="{}"),
],
)
)
]
first_response = await agent.run("read and write", session=session)
requests = _approval_requests(first_response.messages)
assert [request.function_call.name for request in requests] == ["risky_write"]
assert safe_calls == 0
assert risky_calls == 0
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
final_response = await agent.run(requests[0].to_function_approval_response(approved=True), session=session)
assert final_response.text == "done"
assert safe_calls == 1
assert risky_calls == 1
async def test_hidden_mixed_batch_requests_do_not_replay_on_unrelated_turn(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""Stored hidden approvals should only replay when an approval response resumes the flow."""
safe_calls = 0
risky_calls = 0
@tool(name="safe_lookup", approval_mode="never_require")
def safe_lookup() -> str:
nonlocal safe_calls
safe_calls += 1
return "safe"
@tool(name="risky_update", approval_mode="always_require")
def risky_update() -> str:
nonlocal risky_calls
risky_calls += 1
return "risky"
agent = Agent(client=chat_client_base, tools=[safe_lookup, risky_update])
session = AgentSession(session_id="stale-hidden-session")
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_safe", name="safe_lookup", arguments="{}"),
Content.from_function_call(call_id="call_risky", name="risky_update", arguments="{}"),
],
)
)
]
first_response = await agent.run("lookup and update", session=session)
request = _approval_requests(first_response.messages)[0]
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["unrelated"]))]
unrelated_response = await agent.run("never mind, answer something else", session=session)
assert unrelated_response.text == "unrelated"
assert safe_calls == 0
assert risky_calls == 0
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
final_response = await agent.run(request.to_function_approval_response(approved=True), session=session)
assert final_response.text == "done"
assert safe_calls == 1
assert risky_calls == 1
async def test_hidden_mixed_batch_requests_replay_only_for_matching_visible_approval(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""Approving one mixed batch must not replay hidden calls from another abandoned batch."""
safe_a_calls = 0
safe_b_calls = 0
risky_a_calls = 0
risky_b_calls = 0
@tool(name="safe_a", approval_mode="never_require")
def safe_a() -> str:
nonlocal safe_a_calls
safe_a_calls += 1
return "safe-a"
@tool(name="safe_b", approval_mode="never_require")
def safe_b() -> str:
nonlocal safe_b_calls
safe_b_calls += 1
return "safe-b"
@tool(name="risky_a", approval_mode="always_require")
def risky_a() -> str:
nonlocal risky_a_calls
risky_a_calls += 1
return "risky-a"
@tool(name="risky_b", approval_mode="always_require")
def risky_b() -> str:
nonlocal risky_b_calls
risky_b_calls += 1
return "risky-b"
agent = Agent(client=chat_client_base, tools=[safe_a, safe_b, risky_a, risky_b])
session = AgentSession(session_id="grouped-hidden-session")
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_safe_a", name="safe_a", arguments="{}"),
Content.from_function_call(call_id="call_risky_a", name="risky_a", arguments="{}"),
],
)
)
]
first_response = await agent.run("batch a", session=session)
assert [request.function_call.name for request in _approval_requests(first_response.messages)] == ["risky_a"]
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_safe_b", name="safe_b", arguments="{}"),
Content.from_function_call(call_id="call_risky_b", name="risky_b", arguments="{}"),
],
)
)
]
second_response = await agent.run("batch b", session=session)
second_request = _approval_requests(second_response.messages)[0]
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
final_response = await agent.run(second_request.to_function_approval_response(approved=True), session=session)
assert final_response.text == "done"
assert safe_a_calls == 0
assert risky_a_calls == 0
assert safe_b_calls == 1
assert risky_b_calls == 1
async def test_tool_approval_middleware_queues_multiple_approval_requests(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""The opt-in middleware should present multiple unresolved approvals one at a time."""
first_calls = 0
second_calls = 0
@tool(name="first_tool", approval_mode="always_require")
def first_tool() -> str:
nonlocal first_calls
first_calls += 1
return "first"
@tool(name="second_tool", approval_mode="always_require")
def second_tool() -> str:
nonlocal second_calls
second_calls += 1
return "second"
agent = Agent(
client=chat_client_base,
tools=[first_tool, second_tool],
middleware=[ToolApprovalMiddleware()],
)
session = AgentSession(session_id="queue-session")
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_first", name="first_tool", arguments="{}"),
Content.from_function_call(call_id="call_second", name="second_tool", arguments="{}"),
],
)
)
]
first_response = await agent.run("call both", session=session)
first_requests = _approval_requests(first_response.messages)
assert [request.function_call.name for request in first_requests] == ["first_tool"]
assert first_calls == 0
assert second_calls == 0
second_response = await agent.run(first_requests[0].to_function_approval_response(approved=True), session=session)
second_requests = _approval_requests(second_response.messages)
assert [request.function_call.name for request in second_requests] == ["second_tool"]
assert first_calls == 0
assert second_calls == 0
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
final_response = await agent.run(second_requests[0].to_function_approval_response(approved=True), session=session)
assert final_response.text == "done"
assert first_calls == 1
assert second_calls == 1
async def test_tool_approval_middleware_preserves_hidden_mixed_batch_requests(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""Middleware state saves should not discard core hidden already-approved requests."""
lookup_calls = 0
write_calls = 0
@tool(name="lookup_records", approval_mode="never_require")
def lookup_records() -> str:
nonlocal lookup_calls
lookup_calls += 1
return "records"
@tool(name="write_record", approval_mode="always_require")
def write_record() -> str:
nonlocal write_calls
write_calls += 1
return "written"
agent = Agent(
client=chat_client_base,
tools=[lookup_records, write_record],
middleware=[ToolApprovalMiddleware()],
)
session = AgentSession(session_id="mixed-middleware-session")
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_lookup", name="lookup_records", arguments="{}"),
Content.from_function_call(call_id="call_write", name="write_record", arguments="{}"),
],
)
)
]
first_response = await agent.run("lookup and write", session=session)
request = _approval_requests(first_response.messages)[0]
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
second_response = await agent.run(request.to_function_approval_response(approved=True), session=session)
assert second_response.text == "done"
assert lookup_calls == 1
assert write_calls == 1
async def test_tool_approval_middleware_auto_approval_rule_receives_function_call(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""Heuristic auto-approval callbacks should receive function-call content and approve matching calls."""
auto_calls = 0
manual_calls = 0
seen_calls: list[tuple[str, str | None]] = []
@tool(name="auto_write", approval_mode="always_require")
def auto_write() -> str:
nonlocal auto_calls
auto_calls += 1
return "auto"
@tool(name="manual_write", approval_mode="always_require")
def manual_write() -> str:
nonlocal manual_calls
manual_calls += 1
return "manual"
async def auto_approve_auto_write(function_call: Content) -> bool:
seen_calls.append((function_call.type, function_call.name))
return function_call.name == "auto_write"
agent = Agent(
client=chat_client_base,
tools=[auto_write, manual_write],
middleware=[ToolApprovalMiddleware(auto_approval_rules=[auto_approve_auto_write])],
)
session = AgentSession(session_id="heuristic-session")
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_auto", name="auto_write", arguments="{}"),
Content.from_function_call(call_id="call_manual", name="manual_write", arguments="{}"),
],
)
)
]
first_response = await agent.run("write both", session=session)
requests = _approval_requests(first_response.messages)
assert [request.function_call.name for request in requests] == ["manual_write"]
assert seen_calls == [("function_call", "auto_write"), ("function_call", "manual_write")]
assert auto_calls == 0
assert manual_calls == 0
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
final_response = await agent.run(requests[0].to_function_approval_response(approved=True), session=session)
assert final_response.text == "done"
assert auto_calls == 1
assert manual_calls == 1
async def test_tool_approval_middleware_auto_approved_loops_share_function_call_budget(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""Auto-approved re-entry should not reset max_function_calls."""
calls = 0
@tool(name="budgeted_tool", approval_mode="always_require")
def budgeted_tool(value: str) -> str:
nonlocal calls
calls += 1
return value
def auto_approve_budgeted_tool(function_call: Content) -> bool:
return function_call.name == "budgeted_tool"
chat_client_base.function_invocation_configuration["max_function_calls"] = 1 # type: ignore[attr-defined]
agent = Agent(
client=chat_client_base,
tools=[budgeted_tool],
middleware=[ToolApprovalMiddleware(auto_approval_rules=[auto_approve_budgeted_tool])],
)
session = AgentSession(session_id="shared-budget-session")
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_first",
name="budgeted_tool",
arguments='{"value": "first"}',
)
],
)
),
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_second",
name="budgeted_tool",
arguments='{"value": "second"}',
)
],
)
),
]
response = await agent.run("call repeatedly", session=session)
assert response.text == "I broke out of the function invocation loop..."
assert calls == 1
async def test_tool_approval_middleware_queues_streamed_approval_requests(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""Streaming approval requests should also be queued one at a time."""
calls = 0
@tool(name="first_streamed_tool", approval_mode="always_require")
def first_streamed_tool() -> str:
nonlocal calls
calls += 1
return "first"
@tool(name="second_streamed_tool", approval_mode="always_require")
def second_streamed_tool() -> str:
nonlocal calls
calls += 1
return "second"
agent = Agent(
client=chat_client_base,
tools=[first_streamed_tool, second_streamed_tool],
middleware=[ToolApprovalMiddleware()],
)
session = AgentSession(session_id="stream-queue-session")
chat_client_base.streaming_responses = [
[
ChatResponseUpdate(
contents=[Content.from_function_call(call_id="call_first", name="first_streamed_tool", arguments="{}")],
role="assistant",
),
ChatResponseUpdate(
contents=[
Content.from_function_call(call_id="call_second", name="second_streamed_tool", arguments="{}")
],
role="assistant",
),
]
]
first_stream = agent.run("call both", stream=True, session=session)
first_updates = [update async for update in first_stream]
first_requests = [content for update in first_updates for content in update.user_input_requests]
assert [request.function_call.name for request in first_requests] == ["first_streamed_tool"]
assert calls == 0
second_stream = agent.run(
first_requests[0].to_function_approval_response(approved=True),
stream=True,
session=session,
)
second_updates = [update async for update in second_stream]
second_requests = [content for update in second_updates for content in update.user_input_requests]
assert [request.function_call.name for request in second_requests] == ["second_streamed_tool"]
assert calls == 0
chat_client_base.streaming_responses = [
[ChatResponseUpdate(contents=[Content.from_text("done")], role="assistant")]
]
final_stream = agent.run(
second_requests[0].to_function_approval_response(approved=True),
stream=True,
session=session,
)
final_updates = [update async for update in final_stream]
final_response = await final_stream.get_final_response()
assert final_updates[-1].text == "done"
assert final_response.text == "done"
assert calls == 2
async def test_tool_approval_middleware_always_approve_tool_rule(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""An always-approve response should add a standing tool-level approval rule."""
calls = 0
@tool(name="dangerous_tool", approval_mode="always_require")
def dangerous_tool(value: str) -> str:
nonlocal calls
calls += 1
return value
agent = Agent(
client=chat_client_base,
tools=[dangerous_tool],
middleware=[ToolApprovalMiddleware()],
)
session = AgentSession(session_id="standing-rule-session")
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_initial",
name="dangerous_tool",
arguments='{"value": "one"}',
)
],
)
)
]
first_response = await agent.run("call once", session=session)
first_request = _approval_requests(first_response.messages)[0]
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["first done"]))]
await agent.run(create_always_approve_tool_response(first_request), session=session)
assert calls == 1
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_auto",
name="dangerous_tool",
arguments='{"value": "two"}',
)
],
)
),
ChatResponse(messages=Message(role="assistant", contents=["second done"])),
]
second_response = await agent.run("call again", session=session)
assert second_response.text == "second done"
assert calls == 2
async def test_tool_approval_middleware_standing_rules_include_hosted_server_boundary(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""A standing hosted-tool rule should only match the same server_label."""
calls = 0
@tool(name="hosted_tool", approval_mode="always_require")
def hosted_tool() -> str:
nonlocal calls
calls += 1
return "hosted"
def hosted_call(call_id: str, server_label: str) -> Content:
return Content.from_function_call(
call_id=call_id,
name="hosted_tool",
arguments="{}",
additional_properties={"server_label": server_label},
)
agent = Agent(
client=chat_client_base,
tools=[hosted_tool],
middleware=[ToolApprovalMiddleware()],
)
session = AgentSession(session_id="hosted-boundary-session")
chat_client_base.run_responses = [
ChatResponse(messages=Message(role="assistant", contents=[hosted_call("call_initial", "server-a")]))
]
first_response = await agent.run("call hosted a", session=session)
first_request = _approval_requests(first_response.messages)[0]
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["server a done"]))]
await agent.run(create_always_approve_tool_response(first_request), session=session)
assert calls == 0
chat_client_base.run_responses = [
ChatResponse(messages=Message(role="assistant", contents=[hosted_call("call_same_server", "server-a")])),
ChatResponse(messages=Message(role="assistant", contents=["same server done"])),
]
same_server_response = await agent.run("call hosted a again", session=session)
assert same_server_response.text == "same server done"
assert _approval_requests(same_server_response.messages) == []
assert calls == 0
chat_client_base.run_responses = [
ChatResponse(messages=Message(role="assistant", contents=[hosted_call("call_other_server", "server-b")]))
]
other_server_response = await agent.run("call hosted b", session=session)
requests = _approval_requests(other_server_response.messages)
assert [request.function_call.additional_properties["server_label"] for request in requests] == ["server-b"]
assert calls == 0
async def test_tool_approval_middleware_always_approve_tool_with_arguments_rule(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""Argument-scoped always-approve rules should require exact argument matches."""
calls = 0
@tool(name="argument_scoped_tool", approval_mode="always_require")
def argument_scoped_tool(value: str) -> str:
nonlocal calls
calls += 1
return value
agent = Agent(
client=chat_client_base,
tools=[argument_scoped_tool],
middleware=[ToolApprovalMiddleware()],
)
session = AgentSession(session_id="argument-rule-session")
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_initial",
name="argument_scoped_tool",
arguments='{"value": "same"}',
)
],
)
)
]
first_response = await agent.run("call with same", session=session)
first_request = _approval_requests(first_response.messages)[0]
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["first done"]))]
await agent.run(create_always_approve_tool_with_arguments_response(first_request), session=session)
assert calls == 1
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_same",
name="argument_scoped_tool",
arguments='{"value": "same"}',
)
],
)
),
ChatResponse(messages=Message(role="assistant", contents=["same done"])),
]
second_response = await agent.run("call with same again", session=session)
assert second_response.text == "same done"
assert calls == 2
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_different",
name="argument_scoped_tool",
arguments='{"value": "different"}',
)
],
)
)
]
third_response = await agent.run("call with different args", session=session)
requests = _approval_requests(third_response.messages)
assert [request.function_call.arguments for request in requests] == ['{"value": "different"}']
assert calls == 2
async def test_tool_approval_middleware_empty_arguments_rule_is_not_tool_wide(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""An argument-scoped no-argument approval should not become a wildcard."""
calls = 0
@tool(name="optional_args_tool", approval_mode="always_require")
def optional_args_tool(value: str = "default") -> str:
nonlocal calls
calls += 1
return value
agent = Agent(
client=chat_client_base,
tools=[optional_args_tool],
middleware=[ToolApprovalMiddleware()],
)
session = AgentSession(session_id="empty-arguments-rule-session")
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_empty",
name="optional_args_tool",
arguments="{}",
)
],
)
)
]
first_response = await agent.run("call without args", session=session)
first_request = _approval_requests(first_response.messages)[0]
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["empty done"]))]
await agent.run(create_always_approve_tool_with_arguments_response(first_request), session=session)
assert calls == 1
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_non_empty",
name="optional_args_tool",
arguments='{"value": "custom"}',
)
],
)
)
]
second_response = await agent.run("call with args", session=session)
requests = _approval_requests(second_response.messages)
assert [request.function_call.arguments for request in requests] == ['{"value": "custom"}']
assert calls == 1
@@ -978,6 +978,10 @@ async def test_sandbox_code_failure_returns_nonzero_exit(restored_sandbox) -> No
@skip_if_hyperlight_integration_tests_disabled
@pytest.mark.skipif(
sys.platform == "win32" and sys.version_info < (3, 11),
reason="Hyperlight sandbox snapshot/restore crashes on Windows Python 3.10.",
)
async def test_sandbox_snapshot_restore_keeps_sandbox_functional(restored_sandbox) -> None:
"""Verify snapshot/restore cycle leaves the sandbox in a working state."""
# Mutate the sandbox
+1
View File
@@ -22,6 +22,7 @@ injection, and dynamic (progressive) tool exposure.
|------|--------------|
| [`function_tool_with_approval.py`](function_tool_with_approval.py) | Requiring human approval before a tool runs. |
| [`function_tool_with_approval_and_sessions.py`](function_tool_with_approval_and_sessions.py) | Tool approvals combined with sessions. |
| [`tool_approval_middleware.py`](tool_approval_middleware.py) | Session-backed approval coordination, mixed-batch approvals, and "always approve" rules. |
| [`function_invocation_configuration.py`](function_invocation_configuration.py) | Configuring function-invocation settings (e.g. max iterations). |
| [`control_total_tool_executions.py`](control_total_tool_executions.py) | All the ways to cap how many times tools run. |
| [`function_tool_with_max_invocations.py`](function_tool_with_max_invocations.py) | Limiting the number of invocations per tool. |
@@ -0,0 +1,191 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from typing import Annotated
from agent_framework import (
Agent,
AgentResponse,
AgentSession,
Content,
Message,
ToolApprovalMiddleware,
create_always_approve_tool_response,
create_always_approve_tool_with_arguments_response,
tool,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
"""
This sample demonstrates how a host application can decide which approval
requests may run now, which must be rejected, and which can be remembered for
future runs.
The model may not request every tool on every run. The important part is the
approval mechanism:
1. Tools that are safe to run immediately use ``approval_mode="never_require"``.
2. Sensitive tools use ``approval_mode="always_require"``.
3. ``ToolApprovalMiddleware`` coordinates approval prompts and standing rules.
4. The host turns user policy into ``function_approval_response`` content:
- approve for this request only;
- reject for this request;
- approve and remember the tool for future requests;
- approve and remember the tool only when called again with the same arguments.
5. Heuristic auto-approval rules can approve low-risk function calls before
the user is prompted.
"""
# Load environment variables from .env file
load_dotenv()
@tool(approval_mode="never_require")
def lookup_ticket(ticket_id: Annotated[str, "Support ticket id, for example T-123"]) -> str:
"""Look up a support ticket. This read-only tool runs without approval."""
return f"Ticket {ticket_id}: customer confirmed the issue can be closed."
@tool(approval_mode="always_require")
def close_ticket(
ticket_id: Annotated[str, "Support ticket id, for example T-123"],
resolution: Annotated[str, "Short resolution text"],
) -> str:
"""Close a support ticket."""
return f"Ticket {ticket_id} closed with resolution: {resolution}"
@tool(approval_mode="always_require")
def notify_customer(
ticket_id: Annotated[str, "Support ticket id, for example T-123"],
message: Annotated[str, "Message to send to the customer"],
) -> str:
"""Notify the customer about a ticket update."""
return f"Customer notified for {ticket_id}: {message}"
@tool(approval_mode="always_require")
def add_internal_note(
ticket_id: Annotated[str, "Support ticket id, for example T-123"],
note: Annotated[str, "Internal note text"],
) -> str:
"""Add an internal note to a support ticket."""
return f"Internal note added to {ticket_id}: {note}"
@tool(approval_mode="always_require")
def delete_attachment(
ticket_id: Annotated[str, "Support ticket id, for example T-123"],
attachment_name: Annotated[str, "Attachment file name"],
) -> str:
"""Delete an attachment from a support ticket."""
return f"Deleted {attachment_name} from ticket {ticket_id}."
def auto_approve_low_risk_notes(function_call: Content) -> bool:
"""Heuristic rule: auto-approve short internal notes for the target ticket."""
if function_call.name != "add_internal_note":
return False
arguments = function_call.parse_arguments() or {}
note = str(arguments.get("note", ""))
return arguments.get("ticket_id") == "T-123" and len(note) <= 120
def approval_response_for_user_policy(request: Content) -> Content:
"""Convert user/host policy into an approval response for one tool request."""
function_call = request.function_call
if function_call is None or function_call.name is None:
return request.to_function_approval_response(approved=False)
tool_name = function_call.name
print(f"Approval requested: {tool_name}({function_call.arguments})")
if tool_name in {"close_ticket"}:
print(f"Decision: approve and remember future {tool_name} calls with these exact arguments")
return create_always_approve_tool_with_arguments_response(request)
if tool_name in {"notify_customer"}:
print(f"Decision: approve and remember all future {tool_name} calls")
return create_always_approve_tool_response(request)
if tool_name in {"delete_attachment"}:
print(f"Decision: reject {tool_name} for this run")
return request.to_function_approval_response(approved=False)
print(f"Decision: reject {tool_name}; no policy allowed it")
return request.to_function_approval_response(approved=False)
async def resolve_approval_requests(agent: Agent, response: AgentResponse, session: AgentSession) -> AgentResponse:
"""Resolve approval prompts until the agent returns a regular answer."""
result = response
while result.user_input_requests:
approval_responses = [approval_response_for_user_policy(request) for request in result.user_input_requests]
result = await agent.run(Message(role="user", contents=approval_responses), session=session)
return result
async def main() -> None:
"""Run the tool approval middleware sample."""
# 1. Create a regular chat client.
client = FoundryChatClient(credential=AzureCliCredential())
# 2. Create an agent with sensitive tools and opt-in ToolApprovalMiddleware.
agent = Agent(
client=client,
name="SupportAgent",
instructions=(
"You are a support agent. Use tools when useful. "
"Look up ticket T-123, close it if the customer confirmed, notify the customer, "
"add a short internal note, and do not delete attachments unless the tool is approved."
),
tools=[lookup_ticket, close_ticket, notify_customer, add_internal_note, delete_attachment],
middleware=[ToolApprovalMiddleware(auto_approval_rules=[auto_approve_low_risk_notes])],
)
session = agent.create_session()
# 3. Ask for work that may trigger a mixed batch of safe and sensitive tool calls.
query = (
"Please process ticket T-123: check the ticket, close it as resolved, "
"notify the customer, add a short internal note, and remove debug.log if it is attached."
)
print(f"User: {query}")
result = await agent.run(query, session=session)
# 4. Convert approval requests into approve/reject/always-approve responses.
result = await resolve_approval_requests(agent, result, session)
print(f"Agent: {result.text}")
# 5. Later runs can use remembered approval rules:
# - notify_customer: all future calls to the tool.
# - close_ticket: only future calls with the same arguments.
# - add_internal_note: low-risk matching calls are auto-approved by the heuristic callback.
follow_up = "Send the customer a short follow-up for ticket T-123."
print(f"\nUser: {follow_up}")
result = await agent.run(follow_up, session=session)
result = await resolve_approval_requests(agent, result, session)
print(f"Agent: {result.text}")
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output:
User: Please process ticket T-123: check the ticket, close it as resolved,
notify the customer, add a short internal note, and remove debug.log if it is attached.
Approval requested: close_ticket({"ticket_id": "T-123", "resolution": "resolved"})
Decision: approve and remember future close_ticket calls with these exact arguments
Approval requested: notify_customer({"ticket_id": "T-123", "message": "Your ticket has been resolved."})
Decision: approve and remember all future notify_customer calls
Approval requested: delete_attachment({"ticket_id": "T-123", "attachment_name": "debug.log"})
Decision: reject delete_attachment for this run
Agent: Ticket T-123 was closed, the customer was notified, and a short internal note was added.
I did not delete debug.log.
User: Send the customer a short follow-up for ticket T-123.
Agent: The customer was sent a short follow-up for ticket T-123.
"""