diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index 92e83c5df4..d12f127a25 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -100,6 +100,23 @@ agent_framework/ - **`FileSearchResult`** / **`FileSearchMatch`** - `SerializationMixin` DTOs returned by `search_files`, carrying the matching file name, a context snippet, and the matching lines with 1-based line numbers. - **`FileAccessProvider`** - `ContextProvider` that adds shared file-access tools (`file_access_save_file`, `file_access_read_file`, `file_access_delete_file`, `file_access_list_files`, `file_access_search_files`) plus default usage instructions to each invocation. Unlike `MemoryContextProvider`, the store is intentionally shared across sessions and agents. +### Tool Approval Harness (`_harness/_tool_approval.py`) + +- **`ToolApprovalMiddleware`** - Experimental opt-in agent middleware that coordinates session-backed approval + rules, heuristic `auto_approval_rules`, queued approval requests, collected approval responses, and + streaming/non-streaming approval prompts. Heuristic callbacks receive the underlying `function_call` content. +- **`ToolApprovalRule`** / **`ToolApprovalState`** - Serializable state models for standing approvals and queued + approval flow. `ToolApprovalRule.arguments is None` means a tool-wide rule; an empty dict `{}` means an exact + no-argument call for `create_always_approve_tool_with_arguments_response`. +- **`create_always_approve_tool_response`** / **`create_always_approve_tool_with_arguments_response`** - Helpers + that return normal `function_approval_response` content with `additional_properties` metadata consumed by + `ToolApprovalMiddleware`. Standing rules for hosted tools include the `server_label` boundary, so same-named tools + on different hosted servers do not share approvals. 
+- Mixed tool-call batches use a default .NET-style bypass in the function invocation loop: when a session is + available, approval requests for known non-approval-required tools are treated as already approved, hidden, stored + in session state keyed to the visible approval request ids from that batch, and reinjected only when that visible + approval flow resumes. + ### Workflows (`_workflows/`) - **`Workflow`** - Graph-based workflow definition diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index b287b4be57..478adb2dc7 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -125,6 +125,15 @@ from ._harness._todo import ( TodoStore, ) from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback +from ._harness._tool_approval import ( + DEFAULT_TOOL_APPROVAL_SOURCE_ID, + ToolApprovalMiddleware, + ToolApprovalRule, + ToolApprovalRuleCallback, + ToolApprovalState, + create_always_approve_tool_response, + create_always_approve_tool_with_arguments_response, +) from ._middleware import ( AgentContext, AgentMiddleware, @@ -330,6 +339,7 @@ __all__ = [ "DEFAULT_MEMORY_SOURCE_ID", "DEFAULT_MODE_SOURCE_ID", "DEFAULT_TODO_SOURCE_ID", + "DEFAULT_TOOL_APPROVAL_SOURCE_ID", "EXCLUDED_KEY", "EXCLUDE_REASON_KEY", "GROUP_ANNOTATION_KEY", @@ -509,6 +519,10 @@ __all__ = [ "TodoStore", "TokenBudgetComposedStrategy", "TokenizerProtocol", + "ToolApprovalMiddleware", + "ToolApprovalRule", + "ToolApprovalRuleCallback", + "ToolApprovalState", "ToolMode", "ToolResultCompactionStrategy", "ToolTypes", @@ -543,6 +557,8 @@ __all__ = [ "annotate_message_groups", "apply_compaction", "chat_middleware", + "create_always_approve_tool_response", + "create_always_approve_tool_with_arguments_response", "create_edge_runner", "create_harness_agent", "detect_media_type_from_base64", diff --git 
a/python/packages/core/agent_framework/_harness/_tool_approval.py b/python/packages/core/agent_framework/_harness/_tool_approval.py new file mode 100644 index 0000000000..56595278d0 --- /dev/null +++ b/python/packages/core/agent_framework/_harness/_tool_approval.py @@ -0,0 +1,632 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import copy +import inspect +import json +from asyncio import sleep +from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Mapping, MutableMapping, Sequence +from typing import Any, Literal, cast + +from .._feature_stage import ExperimentalFeature, experimental +from .._middleware import AgentContext, AgentMiddleware +from .._serialization import SerializationMixin +from .._sessions import AgentSession +from .._types import ( + AgentResponse, + AgentResponseUpdate, + Content, + FinishReason, + FinishReasonLiteral, + Message, + ResponseStream, +) + +DEFAULT_TOOL_APPROVAL_SOURCE_ID = "tool_approval" +_FUNCTION_INVOCATION_BUDGET_STATE_KEY = "_function_invocation_budget_state" +ALWAYS_APPROVE_PROPERTY = "tool_approval" +ALWAYS_APPROVE_SCOPE_PROPERTY = "always_approve" +ALWAYS_APPROVE_TOOL: Literal["tool"] = "tool" +ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS: Literal["tool_with_arguments"] = "tool_with_arguments" + +_RULES_KEY = "rules" +_QUEUED_APPROVAL_REQUESTS_KEY = "queued_approval_requests" +_COLLECTED_APPROVAL_RESPONSES_KEY = "collected_approval_responses" + +ToolApprovalScope = Literal["tool", "tool_with_arguments"] +ToolApprovalRuleCallback = Callable[[Content], bool | Awaitable[bool]] + + +def _parse_function_arguments(function_call: Content) -> dict[str, Any]: + arguments = function_call.parse_arguments() + return dict(arguments or {}) + + +def _serialize_argument_value(value: Any) -> str: + return json.dumps(value, sort_keys=True, separators=(",", ":"), default=str) + + +def _serialize_arguments(function_call: Content) -> dict[str, str]: + """Serialize arguments for exact 
@experimental(feature_id=ExperimentalFeature.HARNESS)
class ToolApprovalRule(SerializationMixin):
    """Standing approval that lets matching future tool calls proceed.

    ``arguments`` set to ``None`` makes the rule tool-wide. A mapping
    (including an empty one) restricts the rule to the stored argument values.
    """

    tool_name: str
    arguments: dict[str, str] | None
    server_label: str | None

    def __init__(
        self,
        tool_name: str,
        arguments: Mapping[str, str] | None = None,
        *,
        server_label: str | None = None,
    ) -> None:
        """Build a rule for ``tool_name``.

        Args:
            tool_name: Name of the function tool the rule covers. Surrounding
                whitespace is stripped before it is stored.
            arguments: Canonicalized argument values to match, or ``None`` for
                a rule that covers every call to the tool. An empty mapping is
                kept as an empty dict rather than treated as ``None``.

        Keyword Args:
            server_label: Optional hosted-tool server label stored on the rule.

        Raises:
            ValueError: If ``tool_name`` is empty after stripping whitespace.
        """
        stripped = tool_name.strip()
        if stripped == "":
            raise ValueError("Tool approval rule tool_name must be a non-empty string.")
        self.tool_name = stripped
        # Copy so later mutation of the caller's mapping cannot alter the rule.
        self.arguments = None if arguments is None else dict(arguments)
        self.server_label = server_label

    @classmethod
    def from_dict(
        cls,
        value: MutableMapping[str, Any],
        /,
        *,
        dependencies: MutableMapping[str, Any] | None = None,
    ) -> ToolApprovalRule:
        """Rebuild a rule from its serialized mapping.

        Fields are validated in the order ``tool_name``, ``arguments``,
        ``server_label``; the first invalid one raises ``ValueError``.
        Argument keys and values are coerced to ``str``.
        """
        del dependencies  # accepted for signature compatibility; not used here

        name = value.get("tool_name")
        if not isinstance(name, str):
            raise ValueError("Tool approval rule tool_name must be a string.")

        stored_arguments = value.get("arguments")
        arguments_is_mapping = isinstance(stored_arguments, Mapping)
        if not (stored_arguments is None or arguments_is_mapping):
            raise ValueError("Tool approval rule arguments must be a mapping or None.")

        label = value.get("server_label")
        if not (label is None or isinstance(label, str)):
            raise ValueError("Tool approval rule server_label must be a string or None.")

        normalized: dict[str, str] | None = None
        if arguments_is_mapping:
            normalized = {
                str(name_key): str(raw_value)
                for name_key, raw_value in cast(Mapping[str, Any], stored_arguments).items()
            }
        return cls(tool_name=name, arguments=normalized, server_label=label)

    def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
        """Serialize the rule to a plain dict.

        ``tool_name`` is always written. ``type`` is written unless ``exclude``
        contains ``"type"``. ``arguments`` and ``server_label`` are written when
        they are not ``None`` or when ``exclude_none`` is false.
        """
        payload: dict[str, Any] = {"tool_name": self.tool_name}
        if not exclude or "type" not in exclude:
            payload["type"] = self._get_type_identifier()
        for field_name in ("arguments", "server_label"):
            field_value = getattr(self, field_name)
            if not exclude_none or field_value is not None:
                payload[field_name] = field_value
        return payload
"""Session-backed state used by :class:`ToolApprovalMiddleware`.""" + + rules: list[ToolApprovalRule] + queued_approval_requests: list[Content] + collected_approval_responses: list[Content] + + def __init__( + self, + *, + rules: Sequence[ToolApprovalRule | Mapping[str, Any]] | None = None, + queued_approval_requests: Sequence[Content | Mapping[str, Any]] | None = None, + collected_approval_responses: Sequence[Content | Mapping[str, Any]] | None = None, + ) -> None: + """Initialize approval state.""" + self.rules = [ + rule if isinstance(rule, ToolApprovalRule) else ToolApprovalRule.from_dict(dict(rule)) + for rule in (rules or []) + ] + self.queued_approval_requests = [ + item if isinstance(item, Content) else Content.from_dict(item) for item in (queued_approval_requests or []) + ] + self.collected_approval_responses = [ + item if isinstance(item, Content) else Content.from_dict(item) + for item in (collected_approval_responses or []) + ] + + @classmethod + def from_dict( + cls, + value: MutableMapping[str, Any], + /, + *, + dependencies: MutableMapping[str, Any] | None = None, + ) -> ToolApprovalState: + """Create state from serialized state.""" + del dependencies + return cls( + rules=cast(Sequence[Mapping[str, Any]], value.get("rules", [])), + queued_approval_requests=cast(Sequence[Mapping[str, Any]], value.get("queued_approval_requests", [])), + collected_approval_responses=cast( + Sequence[Mapping[str, Any]], + value.get("collected_approval_responses", []), + ), + ) + + def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: + """Serialize state.""" + del exclude_none + exclude = exclude or set() + payload: dict[str, Any] = { + "rules": [rule.to_dict() for rule in self.rules], + "queued_approval_requests": [_content_to_state(item) for item in self.queued_approval_requests], + "collected_approval_responses": [_content_to_state(item) for item in self.collected_approval_responses], + } + if "type" not in exclude: + 
def _create_always_approve_response(request: Content, scope: ToolApprovalScope, *, reason: str | None) -> Content:
    """Approve ``request`` and tag the response with always-approve metadata.

    The metadata dict holds the scope under ``ALWAYS_APPROVE_SCOPE_PROPERTY``
    and, when ``reason`` is not ``None``, the reason under ``"reason"``. It is
    stored on the response's ``additional_properties`` under
    ``ALWAYS_APPROVE_PROPERTY``.
    """
    scope_entry: dict[str, Any] = {ALWAYS_APPROVE_SCOPE_PROPERTY: scope}
    reason_entry: dict[str, Any] = {} if reason is None else {"reason": reason}
    approval = request.to_function_approval_response(approved=True)
    approval.additional_properties[ALWAYS_APPROVE_PROPERTY] = {**scope_entry, **reason_entry}
    return approval
None: + raise TypeError(f"Session state for {source_id!r} must be a mapping, got {type(raw_state).__name__}.") + state = ToolApprovalState() + session.state[source_id] = state.to_dict(exclude={"type"}) + return state + + +def _save_state(session: AgentSession, state: ToolApprovalState, *, source_id: str) -> None: + serialized = state.to_dict(exclude={"type"}) + existing = session.state.get(source_id) + if isinstance(existing, MutableMapping): + for key, value in cast(MutableMapping[str, Any], existing).items(): + if key not in serialized and key != "type": + serialized[key] = value + session.state[source_id] = serialized + + +def _rule_exists(rules: Sequence[ToolApprovalRule], new_rule: ToolApprovalRule) -> bool: + for rule in rules: + if rule.tool_name != new_rule.tool_name: + continue + if rule.server_label != new_rule.server_label: + continue + if rule.arguments == new_rule.arguments: + return True + return False + + +def _add_rule_if_missing(state: ToolApprovalState, rule: ToolApprovalRule) -> None: + if not _rule_exists(state.rules, rule): + state.rules.append(rule) + + +def _function_call_from_request(request: Content) -> Content | None: + function_call = request.function_call + if function_call is None or function_call.type != "function_call" or function_call.name is None: + return None + return function_call + + +def _arguments_match(rule_arguments: Mapping[str, str], function_call: Content) -> bool: + call_arguments = _serialize_arguments(function_call) or {} + if len(rule_arguments) != len(call_arguments): + return False + return all(call_arguments.get(key) == value for key, value in rule_arguments.items()) + + +def _matches_rule(request: Content, rules: Sequence[ToolApprovalRule]) -> bool: + function_call = _function_call_from_request(request) + if function_call is None: + return False + for rule in rules: + if rule.tool_name != function_call.name: + continue + if rule.server_label != _server_label(function_call): + continue + if rule.arguments is None: 
def _get_always_approve_scope(response: Content) -> ToolApprovalScope | None:
    """Read the always-approve scope recorded on an approval response.

    Returns the matching scope constant when ``additional_properties`` holds a
    mapping under ``ALWAYS_APPROVE_PROPERTY`` whose
    ``ALWAYS_APPROVE_SCOPE_PROPERTY`` entry equals a known scope. Returns
    ``None`` in every other case.
    """
    marker = response.additional_properties.get(ALWAYS_APPROVE_PROPERTY)
    if isinstance(marker, Mapping):
        requested = cast(Mapping[str, Any], marker).get(ALWAYS_APPROVE_SCOPE_PROPERTY)
        for known_scope in (ALWAYS_APPROVE_TOOL, ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS):
            if requested == known_scope:
                return known_scope
    return None
    async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
        """Process one agent invocation.

        Loads approval state from the session, moves inbound approval responses
        out of the messages and into that state, then either returns the next
        queued approval request or runs the downstream agent. The downstream
        call repeats for as long as every approval request it produces is
        auto-approved.

        Args:
            context: The invocation context; ``messages`` and ``result`` are
                rewritten in place.
            call_next: Awaitable that runs the rest of the pipeline.

        Raises:
            RuntimeError: If ``context.session`` is ``None``.
        """
        if context.session is None:
            raise RuntimeError("ToolApprovalMiddleware requires an AgentSession.")

        state = _get_state(context.session, source_id=self.source_id)
        # Ensure a budget-state dict exists in client_kwargs without replacing one
        # the caller supplied. NOTE(review): presumably shared with the function
        # invocation loop across the re-entries below - confirm against _tools.py.
        context.client_kwargs.setdefault(_FUNCTION_INVOCATION_BUDGET_STATE_KEY, {})
        # Approval responses are removed from the messages and collected in state.
        context.messages = self._prepare_inbound_messages(context.messages, state)
        # Queued requests that now match a rule or callback are approved here.
        await self._drain_auto_approvable_queue(state)
        if next_queued := self._pop_next_queued_request(state):
            # A prompt is still queued: persist state and return that prompt as the
            # result without calling the downstream agent.
            _save_state(context.session, state, source_id=self.source_id)
            context.result = self._response_for_queued_request(next_queued, stream=context.stream)
            return
        if context.stream:
            # Streaming runs the same loop lazily inside the returned stream.
            context.result = self._process_stream(context, call_next, state)
            return

        while True:
            context.messages = self._inject_collected_responses(context.messages, state)
            # Clear after injecting so the same responses are not sent again on
            # the next pass; save only when something was actually cleared.
            state_changed = bool(state.collected_approval_responses)
            state.collected_approval_responses.clear()
            if state_changed:
                _save_state(context.session, state, source_id=self.source_id)

            await call_next()
            if isinstance(context.result, ResponseStream):
                # A stream result is handed back untouched.
                return
            if context.result is None:
                _save_state(context.session, state, source_id=self.source_id)
                return

            all_auto_approved = await self._process_outbound_messages(context.result.messages, state)
            _save_state(context.session, state, source_id=self.source_id)
            if not all_auto_approved:
                return
            # Every request was auto-approved: run again with no new messages. The
            # approvals just collected in state are injected at the top of the loop.
            context.messages = []
            context.result = None
    def _process_stream(
        self,
        context: AgentContext,
        call_next: Callable[[], Awaitable[None]],
        state: ToolApprovalState,
    ) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
        """Build the streaming counterpart of the approval loop.

        The returned stream wraps an async generator, so the downstream agent
        only runs once the stream is iterated. Updates without approval
        requests pass through unchanged. Approval requests are withheld,
        evaluated after the inner stream ends, and only the ones left in the
        response messages afterwards are yielded.

        Args:
            context: The invocation context; ``messages`` and ``result`` are
                rewritten while the stream is consumed.
            call_next: Awaitable that runs the rest of the pipeline.
            state: Approval state for the current session.

        Returns:
            A ``ResponseStream`` finalized with ``AgentResponse.from_updates``.
        """
        async def _stream() -> AsyncIterable[AgentResponseUpdate]:
            # Checked again inside the generator, which runs later than the caller's
            # check. NOTE(review): presumably also narrows the Optional type below.
            if context.session is None:
                raise RuntimeError("ToolApprovalMiddleware requires an AgentSession.")
            while True:
                context.messages = self._inject_collected_responses(context.messages, state)
                # Clear after injecting so the same responses are not sent again.
                state_changed = bool(state.collected_approval_responses)
                state.collected_approval_responses.clear()
                if state_changed:
                    _save_state(context.session, state, source_id=self.source_id)

                await call_next()
                if not isinstance(context.result, ResponseStream):
                    raise ValueError("Streaming ToolApprovalMiddleware requires a ResponseStream result.")

                approval_requests: list[Content] = []
                async for update in context.result:
                    approval_contents = [
                        content for content in update.contents if content.type == "function_approval_request"
                    ]
                    if not approval_contents:
                        # Nothing to intercept: forward the update as-is.
                        yield update
                        continue
                    # Hold approval requests back until the inner stream has ended.
                    approval_requests.extend(approval_contents)
                    remaining_contents = [
                        content for content in update.contents if content.type != "function_approval_request"
                    ]
                    if remaining_contents:
                        # Re-emit the other contents in a new update that copies the
                        # original update's metadata fields.
                        raw_finish_reason = update.finish_reason
                        finish_reason: FinishReasonLiteral | FinishReason | None
                        if isinstance(raw_finish_reason, str):
                            finish_reason = FinishReason(raw_finish_reason)
                        else:
                            finish_reason = cast(FinishReasonLiteral | FinishReason | None, raw_finish_reason)
                        yield AgentResponseUpdate(
                            contents=remaining_contents,
                            role=update.role,
                            author_name=update.author_name,
                            agent_id=update.agent_id,
                            response_id=update.response_id,
                            message_id=update.message_id,
                            created_at=update.created_at,
                            finish_reason=finish_reason,
                            continuation_token=update.continuation_token,
                            additional_properties=update.additional_properties,
                            raw_representation=update.raw_representation,
                        )
                # The final response is awaited and its value discarded.
                # NOTE(review): presumably this lets the inner stream run its
                # finalizer - confirm against ResponseStream.
                await context.result.get_final_response()
                if not approval_requests:
                    return

                # Wrap the withheld requests in one assistant message for evaluation.
                response_messages = [Message(role="assistant", contents=approval_requests)]
                all_auto_approved = await self._process_outbound_messages(response_messages, state)
                _save_state(context.session, state, source_id=self.source_id)
                if not all_auto_approved:
                    # Yield whatever requests are still present in the messages.
                    for message in response_messages:
                        if message.contents:
                            yield AgentResponseUpdate(role=message.role, contents=message.contents)
                    return
                # Every request was auto-approved: run again with no new messages.
                context.messages = []
                context.result = None

        return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
state, + ToolApprovalRule( + tool_name=function_call.name, + arguments=_serialize_arguments(function_call), + server_label=_server_label(function_call), + ), + ) + return _clone_without_always_approve_metadata(response) + + def _inject_collected_responses(self, messages: Sequence[Message], state: ToolApprovalState) -> list[Message]: + if not state.collected_approval_responses: + return list(messages) + return [Message(role="user", contents=list(state.collected_approval_responses)), *messages] + + async def _drain_auto_approvable_queue(self, state: ToolApprovalState) -> None: + remaining: list[Content] = [] + for request in state.queued_approval_requests: + if _matches_rule(request, state.rules) or await self._matches_auto_rule(request): + state.collected_approval_responses.append(request.to_function_approval_response(approved=True)) + continue + remaining.append(request) + state.queued_approval_requests = remaining + + def _pop_next_queued_request(self, state: ToolApprovalState) -> Content | None: + if not state.queued_approval_requests: + return None + return state.queued_approval_requests.pop(0) + + async def _process_outbound_messages(self, messages: list[Message], state: ToolApprovalState) -> bool: + approval_requests = [ + content + for message in messages + for content in message.contents + if content.type == "function_approval_request" + ] + if not approval_requests: + return False + + auto_approved: set[int] = set() + unresolved: list[Content] = [] + for request in approval_requests: + if _matches_rule(request, state.rules) or await self._matches_auto_rule(request): + state.collected_approval_responses.append(request.to_function_approval_response(approved=True)) + auto_approved.add(id(request)) + else: + unresolved.append(request) + + if not auto_approved and len(unresolved) <= 1: + return False + + queued_ids: set[int] = set() + for request in unresolved[1:]: + queued_ids.add(id(request)) + state.queued_approval_requests.append(request) + + remove_ids = 
auto_approved | queued_ids + self._remove_approval_requests(messages, remove_ids) + return not unresolved + + @staticmethod + def _remove_approval_requests(messages: list[Message], remove_ids: set[int]) -> None: + for message_index in range(len(messages) - 1, -1, -1): + message = messages[message_index] + filtered = [ + content + for content in message.contents + if content.type != "function_approval_request" or id(content) not in remove_ids + ] + if len(filtered) == len(message.contents): + continue + if filtered: + message.contents = filtered + else: + messages.pop(message_index) + + async def _matches_auto_rule(self, request: Content) -> bool: + function_call = _function_call_from_request(request) + if function_call is None: + return False + for rule in self.auto_approval_rules: + result = rule(function_call) + if inspect.isawaitable(result): + result = await result + if result: + return True + return False + + +__all__ = [ + "ALWAYS_APPROVE_PROPERTY", + "ALWAYS_APPROVE_SCOPE_PROPERTY", + "ALWAYS_APPROVE_TOOL", + "ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS", + "DEFAULT_TOOL_APPROVAL_SOURCE_ID", + "ToolApprovalMiddleware", + "ToolApprovalRule", + "ToolApprovalRuleCallback", + "ToolApprovalState", + "create_always_approve_tool_response", + "create_always_approve_tool_with_arguments_response", +] diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 7bb54ee2c9..ad232ffeb4 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -90,6 +90,9 @@ logger = logging.getLogger("agent_framework") DEFAULT_MAX_ITERATIONS: Final[int] = 40 DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3 SHELL_TOOL_KIND_VALUE: Final[str] = "shell" +_TOOL_APPROVAL_STATE_KEY: Final[str] = "tool_approval" +_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY: Final[str] = "already_approved_approval_request_groups" +_FUNCTION_INVOCATION_BUDGET_STATE_KEY: Final[str] = 
"_function_invocation_budget_state" ApprovalMode: TypeAlias = Literal["always_require", "never_require"] ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]") ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) @@ -1685,15 +1688,15 @@ async def _try_execute_function_calls( # The live tools list (when tools is the run-local list) is exposed on the # FunctionInvocationContext so tools can add/remove tools during the run. live_tools: list[ToolTypes] | None = cast("list[ToolTypes]", tools) if isinstance(tools, list) else None - approval_tools = [tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"] + approval_tools = {tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"} logger.debug( "_try_execute_function_calls: tool_map keys=%s, approval_tools=%s", list(tool_map.keys()), approval_tools, ) - declaration_only = [tool_name for tool_name, tool in tool_map.items() if tool.declaration_only] + declaration_only = {tool_name for tool_name, tool in tool_map.items() if tool.declaration_only} configured_additional_tools = config.get("additional_tools") or [] - additional_tool_names = [tool.name for tool in configured_additional_tools] + additional_tool_names = {tool.name for tool in configured_additional_tools} # check if any are calling functions that need approval # if so, we return approval request for all approval_needed = False @@ -1719,15 +1722,39 @@ async def _try_execute_function_calls( raise KeyError(f'Error: Requested function "{fcc.name}" not found.') # type: ignore[attr-defined] if approval_needed: # approval can only be needed for Function Call Content, not Approval Responses. 
- logger.debug("Returning function_approval_request contents") - return ( - [ - Content.from_function_approval_request(id=fcc.call_id, function_call=fcc) # type: ignore[attr-defined, arg-type] - for fcc in function_calls - if fcc.type == "function_call" - ], - False, + logger.debug("Returning visible function_approval_request contents and storing already-approved requests") + visible_requests: list[Content] = [] + already_approved_requests: list[Content] = [] + for fcc in function_calls: + if fcc.type != "function_call": + continue + approval_request = Content.from_function_approval_request( + id=fcc.call_id, # type: ignore[arg-type] + function_call=fcc, + ) + tool_name = fcc.name # type: ignore[attr-defined] + if tool_name is None: + visible_requests.append(approval_request) + continue + tool = tool_map.get(tool_name) + if ( + tool_name in approval_tools + or tool is None + or tool_name in declaration_only + or tool_name in additional_tool_names + ): + visible_requests.append(approval_request) + continue + if invocation_session is None: + visible_requests.append(approval_request) + continue + already_approved_requests.append(approval_request) + _store_already_approved_approval_requests( + invocation_session, + visible_requests, + already_approved_requests, ) + return (visible_requests, False) if declaration_only_flag: # return the declaration only tools to the user, since we cannot execute them. # Mark as user_input_request so AgentExecutor emits request_info events and pauses the workflow. 
@@ -1912,6 +1939,168 @@ def _is_hosted_tool_approval(content: Any) -> bool:
     return bool(ap and ap.get("server_label"))
 
 
+def _get_tool_approval_state(invocation_session: AgentSession | None) -> dict[str, Any] | None:
+    """Return the shared tool-approval state bag for the invocation session.
+
+    The bag is kept in ``invocation_session.state`` under ``_TOOL_APPROVAL_STATE_KEY``. A stored
+    ``ToolApprovalState`` is replaced by its dict form and a missing entry by a new empty dict, so the
+    returned dict is always the object held by the session and can be mutated in place.
+
+    Args:
+        invocation_session: Session that owns the state, or ``None`` when no session is available.
+
+    Returns:
+        The state dict stored in the session, or ``None`` when ``invocation_session`` is ``None``.
+
+    Raises:
+        TypeError: If the stored value is not a dict, a ``ToolApprovalState``, or ``None``.
+    """
+    if invocation_session is None:
+        return None
+    raw_state = invocation_session.state.get(_TOOL_APPROVAL_STATE_KEY)
+    if isinstance(raw_state, dict):
+        return cast(dict[str, Any], raw_state)
+    # NOTE(review): function-scope import, presumably to avoid an import cycle - confirm.
+    from ._harness._tool_approval import ToolApprovalState
+
+    if isinstance(raw_state, ToolApprovalState):
+        # Swap the model for its dict form so this module only ever mutates a plain dict.
+        serialized_state = raw_state.to_dict(exclude={"type"})
+        invocation_session.state[_TOOL_APPROVAL_STATE_KEY] = serialized_state
+        return serialized_state
+    if raw_state is not None:
+        raise TypeError(
+            f"Session state for {_TOOL_APPROVAL_STATE_KEY!r} must be a dict or ToolApprovalState, "
+            f"got {type(raw_state).__name__}."
+        )
+    new_state: dict[str, Any] = {}
+    invocation_session.state[_TOOL_APPROVAL_STATE_KEY] = new_state
+    return new_state
+
+
+def _content_from_state(value: Any) -> Content | None:
+    """Restore a Content item stored in session state.
+
+    Args:
+        value: A ``Content`` instance, a mapping accepted by ``Content.from_dict``, or anything else.
+
+    Returns:
+        ``value`` itself when it already is a ``Content``, the result of ``Content.from_dict`` for a
+        mapping, and ``None`` for any other type.
+    """
+    from ._types import Content
+
+    if isinstance(value, Content):
+        return value
+    if isinstance(value, Mapping):
+        return Content.from_dict(cast(Mapping[str, Any], value))
+    return None
+
+
+def _store_already_approved_approval_requests(
+    invocation_session: AgentSession | None,
+    visible_approval_requests: Sequence[Content],
+    already_approved_requests: Sequence[Content],
+) -> None:
+    """Store hidden already-approved requests keyed by the visible approvals that resume the batch.
+
+    One group is appended under ``_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY`` holding the ids of the
+    visible requests and the serialized hidden requests. Nothing is stored when there are no hidden
+    requests, no session, or no visible request with an id.
+
+    Args:
+        invocation_session: Session whose tool-approval state receives the group.
+        visible_approval_requests: Requests returned to the caller; their ids key the group.
+        already_approved_requests: Requests withheld from the caller, stored via ``to_dict()``.
+
+    Raises:
+        TypeError: If the session holds an unsupported tool-approval state value.
+    """
+    if not already_approved_requests:
+        return
+    state = _get_tool_approval_state(invocation_session)
+    if state is None:
+        return
+    visible_ids = [request.id for request in visible_approval_requests if request.id]
+    if not visible_ids:
+        # Without a visible id there is nothing to key the group on, so the hidden requests cannot be
+        # replayed later. Say so instead of discarding them silently.
+        logger.warning(
+            "Dropping %d already-approved approval request(s): no visible approval request id to key them on.",
+            len(already_approved_requests),
+        )
+        return
+
+    existing_groups = state.get(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY)
+    # Copy before appending so the list previously stored in the state is not mutated in place.
+    pending_groups: list[Any] = (
+        list(cast(Iterable[Any], existing_groups)) if isinstance(existing_groups, list) else []
+    )
+    pending_groups.append({
+        "approval_request_ids": visible_ids,
+        "approval_requests": [request.to_dict() for request in already_approved_requests],
+    })
+    state[_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY] = pending_groups
+
+
+def _pop_already_approved_approval_responses(
+    invocation_session: AgentSession | None,
+    approval_response_ids: set[str],
+) -> list[Content]:
+    """Pop already-approved requests for the visible approval ids being answered.
+
+    Every stored group sharing at least one id with ``approval_response_ids`` is removed from the state
+    and each of its requests is converted into an approved ``function_approval_response``. Groups
+    with no matching id stay stored; entries that are not mappings are discarded.
+
+    Args:
+        invocation_session: Session whose tool-approval state holds the stored groups.
+        approval_response_ids: Ids of the approval responses present in the current messages.
+
+    Returns:
+        Approved responses for the hidden requests of all matched groups, possibly empty.
+
+    Raises:
+        TypeError: If the session holds an unsupported tool-approval state value.
+    """
+    if not approval_response_ids:
+        return []
+    state = _get_tool_approval_state(invocation_session)
+    if state is None:
+        return []
+    raw_groups = state.get(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY, [])
+    if not isinstance(raw_groups, list):
+        return []
+
+    responses: list[Content] = []
+    remaining_groups: list[Any] = []
+    for raw_group in cast(Iterable[Any], raw_groups):
+        if not isinstance(raw_group, Mapping):
+            continue
+        group = cast(Mapping[str, Any], raw_group)
+        raw_ids = group.get("approval_request_ids")
+        raw_group_ids: Iterable[Any] = cast(Iterable[Any], raw_ids) if isinstance(raw_ids, list) else ()
+        # str() normalizes the stored ids before they are compared with the str ids being answered.
+        group_ids = {str(item) for item in raw_group_ids}
+        if group_ids.isdisjoint(approval_response_ids):
+            remaining_groups.append(raw_group)
+            continue
+        # The group is consumed from here on, even if its request list turns out to be unusable.
+        raw_requests = group.get("approval_requests")
+        if not isinstance(raw_requests, list):
+            continue
+        for raw_request in cast(Iterable[Any], raw_requests):
+            request = _content_from_state(raw_request)
+            if request is None or request.type != "function_approval_request":
+                continue
+            responses.append(request.to_function_approval_response(approved=True))
+    if remaining_groups:
+        state[_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY] = remaining_groups
+    else:
+        state.pop(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY, None)
+    return responses
+
+
 def _collect_approval_responses(
     messages: list[Message],
 ) -> dict[str, 
Content]: @@ -2157,8 +2286,24 @@ async def _process_function_requests( errors_in_a_row: int, max_errors: int, execute_function_calls: Callable[..., Awaitable[tuple[list[Content], bool, bool]]], + invocation_session: AgentSession | None = None, ) -> FunctionRequestResult: + from ._types import Message + if prepped_messages is not None: + explicit_approval_response_ids = { + content.id + for message in prepped_messages + if isinstance(message, Message) + for content in message.contents + if content.type == "function_approval_response" and content.id + } + already_approved_responses = _pop_already_approved_approval_responses( + invocation_session, + explicit_approval_response_ids, + ) + if already_approved_responses: + prepped_messages.append(Message(role="user", contents=already_approved_responses)) fcc_todo = _collect_approval_responses(prepped_messages) if not fcc_todo: fcc_todo = {} @@ -2362,6 +2507,10 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): function_middleware_pipeline = self._get_function_middleware_pipeline(runtime_middleware["function"]) if runtime_middleware["chat"]: effective_client_kwargs["middleware"] = runtime_middleware["chat"] + raw_budget_state = effective_client_kwargs.pop(_FUNCTION_INVOCATION_BUDGET_STATE_KEY, None) + budget_state: dict[str, Any] = ( + cast(dict[str, Any], raw_budget_state) if isinstance(raw_budget_state, dict) else {} + ) max_errors = self.function_invocation_configuration.get( "max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST ) @@ -2411,7 +2560,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): nonlocal mutable_options nonlocal filtered_kwargs errors_in_a_row: int = 0 - total_function_calls: int = 0 + total_function_calls = int(budget_state.get("total_function_calls", 0) or 0) max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls") prepped_messages = list(messages) fcc_messages: list[Message] = [] @@ -2420,7 +2569,9 @@ class 
FunctionInvocationLayer(Generic[OptionsCoT]): loop_enabled = self.function_invocation_configuration.get("enabled", True) max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS) - for attempt_idx in range(max_iterations if loop_enabled else 0): + attempt_start = int(budget_state.get("attempt_count", 0) or 0) + for attempt_idx in range(attempt_start, max_iterations if loop_enabled else 0): + budget_state["attempt_count"] = attempt_idx + 1 approval_result = await _process_function_requests( response=None, prepped_messages=prepped_messages, @@ -2430,12 +2581,21 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): errors_in_a_row=errors_in_a_row, max_errors=max_errors, execute_function_calls=execute_function_calls, + invocation_session=invocation_session, ) if approval_result.get("action") == "stop": response = ChatResponse(messages=prepped_messages) break errors_in_a_row = approval_result.get("errors_in_a_row", errors_in_a_row) total_function_calls += approval_result.get("function_call_count", 0) + budget_state["total_function_calls"] = total_function_calls + if max_function_calls is not None and total_function_calls >= max_function_calls: + logger.info( + "Maximum function calls reached (%d/%d). 
Stopping further function calls for this request.", + total_function_calls, + max_function_calls, + ) + mutable_options["tool_choice"] = "none" response = cast( ChatResponse[Any], @@ -2468,11 +2628,13 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): errors_in_a_row=errors_in_a_row, max_errors=max_errors, execute_function_calls=execute_function_calls, + invocation_session=invocation_session, ) if result.get("action") == "return": response.usage_details = aggregated_usage return _clear_internal_conversation_id(response) total_function_calls += result.get("function_call_count", 0) + budget_state["total_function_calls"] = total_function_calls if result.get("action") == "stop": # Error threshold reached: force a final non-tool turn so # function_call_output items are submitted before exit. @@ -2549,7 +2711,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): nonlocal mutable_options nonlocal stream_result_hooks errors_in_a_row: int = 0 - total_function_calls: int = 0 + total_function_calls = int(budget_state.get("total_function_calls", 0) or 0) max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls") prepped_messages = list(messages) fcc_messages: list[Message] = [] @@ -2557,7 +2719,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): loop_enabled = self.function_invocation_configuration.get("enabled", True) max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS) - for attempt_idx in range(max_iterations if loop_enabled else 0): + attempt_start = int(budget_state.get("attempt_count", 0) or 0) + for attempt_idx in range(attempt_start, max_iterations if loop_enabled else 0): + budget_state["attempt_count"] = attempt_idx + 1 approval_result = await _process_function_requests( response=None, prepped_messages=prepped_messages, @@ -2567,9 +2731,18 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): errors_in_a_row=errors_in_a_row, max_errors=max_errors, 
execute_function_calls=execute_function_calls, + invocation_session=invocation_session, ) errors_in_a_row = approval_result.get("errors_in_a_row", errors_in_a_row) total_function_calls += approval_result.get("function_call_count", 0) + budget_state["total_function_calls"] = total_function_calls + if max_function_calls is not None and total_function_calls >= max_function_calls: + logger.info( + "Maximum function calls reached (%d/%d). Stopping further function calls for this request.", + total_function_calls, + max_function_calls, + ) + mutable_options["tool_choice"] = "none" if approval_result.get("action") == "stop": mutable_options["tool_choice"] = "none" return @@ -2622,9 +2795,11 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): errors_in_a_row=errors_in_a_row, max_errors=max_errors, execute_function_calls=execute_function_calls, + invocation_session=invocation_session, ) errors_in_a_row = result.get("errors_in_a_row", errors_in_a_row) total_function_calls += result.get("function_call_count", 0) + budget_state["total_function_calls"] = total_function_calls if role := result.get("update_role"): yield ChatResponseUpdate( contents=result.get("function_call_results") or [], diff --git a/python/packages/core/tests/core/test_harness_tool_approval.py b/python/packages/core/tests/core/test_harness_tool_approval.py new file mode 100644 index 0000000000..9bc03c839e --- /dev/null +++ b/python/packages/core/tests/core/test_harness_tool_approval.py @@ -0,0 +1,817 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +from __future__ import annotations + +from agent_framework import ( + DEFAULT_TOOL_APPROVAL_SOURCE_ID, + Agent, + AgentSession, + ChatResponse, + ChatResponseUpdate, + Content, + Message, + SupportsChatGetResponse, + ToolApprovalMiddleware, + ToolApprovalState, + create_always_approve_tool_response, + create_always_approve_tool_with_arguments_response, + tool, +) + + +def _approval_requests(messages: list[Message]) -> list[Content]: + return [ + content for message in messages for content in message.contents if content.type == "function_approval_request" + ] + + +async def test_mixed_batch_hides_already_approved_request_until_approval_replay( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Mixed batches should only show real approval requests when a session can store hidden requests.""" + no_approval_calls = 0 + approval_calls = 0 + + @tool(name="lookup_work_items", approval_mode="never_require") + def lookup_work_items(query: str) -> str: + nonlocal no_approval_calls + no_approval_calls += 1 + return f"found {query}" + + @tool(name="add_comment", approval_mode="always_require") + def add_comment(comment: str) -> str: + nonlocal approval_calls + approval_calls += 1 + return f"added {comment}" + + agent = Agent(client=chat_client_base, tools=[lookup_work_items, add_comment]) + session = AgentSession(session_id="approval-session") + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_lookup", + name="lookup_work_items", + arguments='{"query": "mine"}', + ), + Content.from_function_call( + call_id="call_comment", + name="add_comment", + arguments='{"comment": "done"}', + ), + ], + ) + ) + ] + + first_response = await agent.run("update work item", session=session) + + requests = _approval_requests(first_response.messages) + assert [request.function_call.name for request in requests] == ["add_comment"] + assert no_approval_calls == 0 + assert approval_calls 
== 0 + + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["complete"]))] + second_response = await agent.run(requests[0].to_function_approval_response(approved=True), session=session) + + assert second_response.text == "complete" + assert no_approval_calls == 1 + assert approval_calls == 1 + + +async def test_mixed_batch_accepts_restored_tool_approval_state( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Mixed-batch bypass should work when session state contains ToolApprovalState.""" + safe_calls = 0 + risky_calls = 0 + + @tool(name="safe_read", approval_mode="never_require") + def safe_read() -> str: + nonlocal safe_calls + safe_calls += 1 + return "safe" + + @tool(name="risky_write", approval_mode="always_require") + def risky_write() -> str: + nonlocal risky_calls + risky_calls += 1 + return "risky" + + agent = Agent(client=chat_client_base, tools=[safe_read, risky_write]) + session = AgentSession(session_id="restored-state-session") + session.state[DEFAULT_TOOL_APPROVAL_SOURCE_ID] = ToolApprovalState() + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_safe", name="safe_read", arguments="{}"), + Content.from_function_call(call_id="call_risky", name="risky_write", arguments="{}"), + ], + ) + ) + ] + + first_response = await agent.run("read and write", session=session) + requests = _approval_requests(first_response.messages) + + assert [request.function_call.name for request in requests] == ["risky_write"] + assert safe_calls == 0 + assert risky_calls == 0 + + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))] + final_response = await agent.run(requests[0].to_function_approval_response(approved=True), session=session) + + assert final_response.text == "done" + assert safe_calls == 1 + assert risky_calls == 1 + + +async def 
test_hidden_mixed_batch_requests_do_not_replay_on_unrelated_turn( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Stored hidden approvals should only replay when an approval response resumes the flow.""" + safe_calls = 0 + risky_calls = 0 + + @tool(name="safe_lookup", approval_mode="never_require") + def safe_lookup() -> str: + nonlocal safe_calls + safe_calls += 1 + return "safe" + + @tool(name="risky_update", approval_mode="always_require") + def risky_update() -> str: + nonlocal risky_calls + risky_calls += 1 + return "risky" + + agent = Agent(client=chat_client_base, tools=[safe_lookup, risky_update]) + session = AgentSession(session_id="stale-hidden-session") + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_safe", name="safe_lookup", arguments="{}"), + Content.from_function_call(call_id="call_risky", name="risky_update", arguments="{}"), + ], + ) + ) + ] + + first_response = await agent.run("lookup and update", session=session) + request = _approval_requests(first_response.messages)[0] + + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["unrelated"]))] + unrelated_response = await agent.run("never mind, answer something else", session=session) + + assert unrelated_response.text == "unrelated" + assert safe_calls == 0 + assert risky_calls == 0 + + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))] + final_response = await agent.run(request.to_function_approval_response(approved=True), session=session) + + assert final_response.text == "done" + assert safe_calls == 1 + assert risky_calls == 1 + + +async def test_hidden_mixed_batch_requests_replay_only_for_matching_visible_approval( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Approving one mixed batch must not replay hidden calls from another abandoned batch.""" + safe_a_calls = 0 + 
safe_b_calls = 0 + risky_a_calls = 0 + risky_b_calls = 0 + + @tool(name="safe_a", approval_mode="never_require") + def safe_a() -> str: + nonlocal safe_a_calls + safe_a_calls += 1 + return "safe-a" + + @tool(name="safe_b", approval_mode="never_require") + def safe_b() -> str: + nonlocal safe_b_calls + safe_b_calls += 1 + return "safe-b" + + @tool(name="risky_a", approval_mode="always_require") + def risky_a() -> str: + nonlocal risky_a_calls + risky_a_calls += 1 + return "risky-a" + + @tool(name="risky_b", approval_mode="always_require") + def risky_b() -> str: + nonlocal risky_b_calls + risky_b_calls += 1 + return "risky-b" + + agent = Agent(client=chat_client_base, tools=[safe_a, safe_b, risky_a, risky_b]) + session = AgentSession(session_id="grouped-hidden-session") + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_safe_a", name="safe_a", arguments="{}"), + Content.from_function_call(call_id="call_risky_a", name="risky_a", arguments="{}"), + ], + ) + ) + ] + + first_response = await agent.run("batch a", session=session) + assert [request.function_call.name for request in _approval_requests(first_response.messages)] == ["risky_a"] + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_safe_b", name="safe_b", arguments="{}"), + Content.from_function_call(call_id="call_risky_b", name="risky_b", arguments="{}"), + ], + ) + ) + ] + + second_response = await agent.run("batch b", session=session) + second_request = _approval_requests(second_response.messages)[0] + + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))] + final_response = await agent.run(second_request.to_function_approval_response(approved=True), session=session) + + assert final_response.text == "done" + assert safe_a_calls == 0 + assert risky_a_calls == 0 + 
assert safe_b_calls == 1 + assert risky_b_calls == 1 + + +async def test_tool_approval_middleware_queues_multiple_approval_requests( + chat_client_base: SupportsChatGetResponse, +) -> None: + """The opt-in middleware should present multiple unresolved approvals one at a time.""" + first_calls = 0 + second_calls = 0 + + @tool(name="first_tool", approval_mode="always_require") + def first_tool() -> str: + nonlocal first_calls + first_calls += 1 + return "first" + + @tool(name="second_tool", approval_mode="always_require") + def second_tool() -> str: + nonlocal second_calls + second_calls += 1 + return "second" + + agent = Agent( + client=chat_client_base, + tools=[first_tool, second_tool], + middleware=[ToolApprovalMiddleware()], + ) + session = AgentSession(session_id="queue-session") + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_first", name="first_tool", arguments="{}"), + Content.from_function_call(call_id="call_second", name="second_tool", arguments="{}"), + ], + ) + ) + ] + + first_response = await agent.run("call both", session=session) + + first_requests = _approval_requests(first_response.messages) + assert [request.function_call.name for request in first_requests] == ["first_tool"] + assert first_calls == 0 + assert second_calls == 0 + + second_response = await agent.run(first_requests[0].to_function_approval_response(approved=True), session=session) + + second_requests = _approval_requests(second_response.messages) + assert [request.function_call.name for request in second_requests] == ["second_tool"] + assert first_calls == 0 + assert second_calls == 0 + + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))] + final_response = await agent.run(second_requests[0].to_function_approval_response(approved=True), session=session) + + assert final_response.text == "done" + assert first_calls == 1 + assert 
second_calls == 1 + + +async def test_tool_approval_middleware_preserves_hidden_mixed_batch_requests( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Middleware state saves should not discard core hidden already-approved requests.""" + lookup_calls = 0 + write_calls = 0 + + @tool(name="lookup_records", approval_mode="never_require") + def lookup_records() -> str: + nonlocal lookup_calls + lookup_calls += 1 + return "records" + + @tool(name="write_record", approval_mode="always_require") + def write_record() -> str: + nonlocal write_calls + write_calls += 1 + return "written" + + agent = Agent( + client=chat_client_base, + tools=[lookup_records, write_record], + middleware=[ToolApprovalMiddleware()], + ) + session = AgentSession(session_id="mixed-middleware-session") + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_lookup", name="lookup_records", arguments="{}"), + Content.from_function_call(call_id="call_write", name="write_record", arguments="{}"), + ], + ) + ) + ] + + first_response = await agent.run("lookup and write", session=session) + request = _approval_requests(first_response.messages)[0] + + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))] + second_response = await agent.run(request.to_function_approval_response(approved=True), session=session) + + assert second_response.text == "done" + assert lookup_calls == 1 + assert write_calls == 1 + + +async def test_tool_approval_middleware_auto_approval_rule_receives_function_call( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Heuristic auto-approval callbacks should receive function-call content and approve matching calls.""" + auto_calls = 0 + manual_calls = 0 + seen_calls: list[tuple[str, str | None]] = [] + + @tool(name="auto_write", approval_mode="always_require") + def auto_write() -> str: + nonlocal auto_calls + auto_calls += 
1 + return "auto" + + @tool(name="manual_write", approval_mode="always_require") + def manual_write() -> str: + nonlocal manual_calls + manual_calls += 1 + return "manual" + + async def auto_approve_auto_write(function_call: Content) -> bool: + seen_calls.append((function_call.type, function_call.name)) + return function_call.name == "auto_write" + + agent = Agent( + client=chat_client_base, + tools=[auto_write, manual_write], + middleware=[ToolApprovalMiddleware(auto_approval_rules=[auto_approve_auto_write])], + ) + session = AgentSession(session_id="heuristic-session") + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_auto", name="auto_write", arguments="{}"), + Content.from_function_call(call_id="call_manual", name="manual_write", arguments="{}"), + ], + ) + ) + ] + + first_response = await agent.run("write both", session=session) + + requests = _approval_requests(first_response.messages) + assert [request.function_call.name for request in requests] == ["manual_write"] + assert seen_calls == [("function_call", "auto_write"), ("function_call", "manual_write")] + assert auto_calls == 0 + assert manual_calls == 0 + + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))] + final_response = await agent.run(requests[0].to_function_approval_response(approved=True), session=session) + + assert final_response.text == "done" + assert auto_calls == 1 + assert manual_calls == 1 + + +async def test_tool_approval_middleware_auto_approved_loops_share_function_call_budget( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Auto-approved re-entry should not reset max_function_calls.""" + calls = 0 + + @tool(name="budgeted_tool", approval_mode="always_require") + def budgeted_tool(value: str) -> str: + nonlocal calls + calls += 1 + return value + + def auto_approve_budgeted_tool(function_call: Content) -> bool: + return 
function_call.name == "budgeted_tool" + + chat_client_base.function_invocation_configuration["max_function_calls"] = 1 # type: ignore[attr-defined] + agent = Agent( + client=chat_client_base, + tools=[budgeted_tool], + middleware=[ToolApprovalMiddleware(auto_approval_rules=[auto_approve_budgeted_tool])], + ) + session = AgentSession(session_id="shared-budget-session") + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_first", + name="budgeted_tool", + arguments='{"value": "first"}', + ) + ], + ) + ), + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_second", + name="budgeted_tool", + arguments='{"value": "second"}', + ) + ], + ) + ), + ] + + response = await agent.run("call repeatedly", session=session) + + assert response.text == "I broke out of the function invocation loop..." + assert calls == 1 + + +async def test_tool_approval_middleware_queues_streamed_approval_requests( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Streaming approval requests should also be queued one at a time.""" + calls = 0 + + @tool(name="first_streamed_tool", approval_mode="always_require") + def first_streamed_tool() -> str: + nonlocal calls + calls += 1 + return "first" + + @tool(name="second_streamed_tool", approval_mode="always_require") + def second_streamed_tool() -> str: + nonlocal calls + calls += 1 + return "second" + + agent = Agent( + client=chat_client_base, + tools=[first_streamed_tool, second_streamed_tool], + middleware=[ToolApprovalMiddleware()], + ) + session = AgentSession(session_id="stream-queue-session") + chat_client_base.streaming_responses = [ + [ + ChatResponseUpdate( + contents=[Content.from_function_call(call_id="call_first", name="first_streamed_tool", arguments="{}")], + role="assistant", + ), + ChatResponseUpdate( + contents=[ + Content.from_function_call(call_id="call_second", 
name="second_streamed_tool", arguments="{}") + ], + role="assistant", + ), + ] + ] + + first_stream = agent.run("call both", stream=True, session=session) + first_updates = [update async for update in first_stream] + first_requests = [content for update in first_updates for content in update.user_input_requests] + assert [request.function_call.name for request in first_requests] == ["first_streamed_tool"] + assert calls == 0 + + second_stream = agent.run( + first_requests[0].to_function_approval_response(approved=True), + stream=True, + session=session, + ) + second_updates = [update async for update in second_stream] + second_requests = [content for update in second_updates for content in update.user_input_requests] + assert [request.function_call.name for request in second_requests] == ["second_streamed_tool"] + assert calls == 0 + + chat_client_base.streaming_responses = [ + [ChatResponseUpdate(contents=[Content.from_text("done")], role="assistant")] + ] + final_stream = agent.run( + second_requests[0].to_function_approval_response(approved=True), + stream=True, + session=session, + ) + final_updates = [update async for update in final_stream] + final_response = await final_stream.get_final_response() + + assert final_updates[-1].text == "done" + assert final_response.text == "done" + assert calls == 2 + + +async def test_tool_approval_middleware_always_approve_tool_rule( + chat_client_base: SupportsChatGetResponse, +) -> None: + """An always-approve response should add a standing tool-level approval rule.""" + calls = 0 + + @tool(name="dangerous_tool", approval_mode="always_require") + def dangerous_tool(value: str) -> str: + nonlocal calls + calls += 1 + return value + + agent = Agent( + client=chat_client_base, + tools=[dangerous_tool], + middleware=[ToolApprovalMiddleware()], + ) + session = AgentSession(session_id="standing-rule-session") + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + 
Content.from_function_call( + call_id="call_initial", + name="dangerous_tool", + arguments='{"value": "one"}', + ) + ], + ) + ) + ] + + first_response = await agent.run("call once", session=session) + first_request = _approval_requests(first_response.messages)[0] + + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["first done"]))] + await agent.run(create_always_approve_tool_response(first_request), session=session) + + assert calls == 1 + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_auto", + name="dangerous_tool", + arguments='{"value": "two"}', + ) + ], + ) + ), + ChatResponse(messages=Message(role="assistant", contents=["second done"])), + ] + + second_response = await agent.run("call again", session=session) + + assert second_response.text == "second done" + assert calls == 2 + + +async def test_tool_approval_middleware_standing_rules_include_hosted_server_boundary( + chat_client_base: SupportsChatGetResponse, +) -> None: + """A standing hosted-tool rule should only match the same server_label.""" + calls = 0 + + @tool(name="hosted_tool", approval_mode="always_require") + def hosted_tool() -> str: + nonlocal calls + calls += 1 + return "hosted" + + def hosted_call(call_id: str, server_label: str) -> Content: + return Content.from_function_call( + call_id=call_id, + name="hosted_tool", + arguments="{}", + additional_properties={"server_label": server_label}, + ) + + agent = Agent( + client=chat_client_base, + tools=[hosted_tool], + middleware=[ToolApprovalMiddleware()], + ) + session = AgentSession(session_id="hosted-boundary-session") + chat_client_base.run_responses = [ + ChatResponse(messages=Message(role="assistant", contents=[hosted_call("call_initial", "server-a")])) + ] + + first_response = await agent.run("call hosted a", session=session) + first_request = _approval_requests(first_response.messages)[0] + + 
chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["server a done"]))] + await agent.run(create_always_approve_tool_response(first_request), session=session) + + assert calls == 0 + + chat_client_base.run_responses = [ + ChatResponse(messages=Message(role="assistant", contents=[hosted_call("call_same_server", "server-a")])), + ChatResponse(messages=Message(role="assistant", contents=["same server done"])), + ] + + same_server_response = await agent.run("call hosted a again", session=session) + + assert same_server_response.text == "same server done" + assert _approval_requests(same_server_response.messages) == [] + assert calls == 0 + + chat_client_base.run_responses = [ + ChatResponse(messages=Message(role="assistant", contents=[hosted_call("call_other_server", "server-b")])) + ] + + other_server_response = await agent.run("call hosted b", session=session) + + requests = _approval_requests(other_server_response.messages) + assert [request.function_call.additional_properties["server_label"] for request in requests] == ["server-b"] + assert calls == 0 + + +async def test_tool_approval_middleware_always_approve_tool_with_arguments_rule( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Argument-scoped always-approve rules should require exact argument matches.""" + calls = 0 + + @tool(name="argument_scoped_tool", approval_mode="always_require") + def argument_scoped_tool(value: str) -> str: + nonlocal calls + calls += 1 + return value + + agent = Agent( + client=chat_client_base, + tools=[argument_scoped_tool], + middleware=[ToolApprovalMiddleware()], + ) + session = AgentSession(session_id="argument-rule-session") + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_initial", + name="argument_scoped_tool", + arguments='{"value": "same"}', + ) + ], + ) + ) + ] + + first_response = await agent.run("call with same", 
session=session) + first_request = _approval_requests(first_response.messages)[0] + + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["first done"]))] + await agent.run(create_always_approve_tool_with_arguments_response(first_request), session=session) + + assert calls == 1 + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_same", + name="argument_scoped_tool", + arguments='{"value": "same"}', + ) + ], + ) + ), + ChatResponse(messages=Message(role="assistant", contents=["same done"])), + ] + + second_response = await agent.run("call with same again", session=session) + + assert second_response.text == "same done" + assert calls == 2 + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_different", + name="argument_scoped_tool", + arguments='{"value": "different"}', + ) + ], + ) + ) + ] + + third_response = await agent.run("call with different args", session=session) + + requests = _approval_requests(third_response.messages) + assert [request.function_call.arguments for request in requests] == ['{"value": "different"}'] + assert calls == 2 + + +async def test_tool_approval_middleware_empty_arguments_rule_is_not_tool_wide( + chat_client_base: SupportsChatGetResponse, +) -> None: + """An argument-scoped no-argument approval should not become a wildcard.""" + calls = 0 + + @tool(name="optional_args_tool", approval_mode="always_require") + def optional_args_tool(value: str = "default") -> str: + nonlocal calls + calls += 1 + return value + + agent = Agent( + client=chat_client_base, + tools=[optional_args_tool], + middleware=[ToolApprovalMiddleware()], + ) + session = AgentSession(session_id="empty-arguments-rule-session") + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ 
+ Content.from_function_call( + call_id="call_empty", + name="optional_args_tool", + arguments="{}", + ) + ], + ) + ) + ] + + first_response = await agent.run("call without args", session=session) + first_request = _approval_requests(first_response.messages)[0] + + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["empty done"]))] + await agent.run(create_always_approve_tool_with_arguments_response(first_request), session=session) + + assert calls == 1 + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_non_empty", + name="optional_args_tool", + arguments='{"value": "custom"}', + ) + ], + ) + ) + ] + + second_response = await agent.run("call with args", session=session) + + requests = _approval_requests(second_response.messages) + assert [request.function_call.arguments for request in requests] == ['{"value": "custom"}'] + assert calls == 1 diff --git a/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py b/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py index 03e3c2269c..484d056a0d 100644 --- a/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py +++ b/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py @@ -978,6 +978,10 @@ async def test_sandbox_code_failure_returns_nonzero_exit(restored_sandbox) -> No @skip_if_hyperlight_integration_tests_disabled +@pytest.mark.skipif( + sys.platform == "win32" and sys.version_info < (3, 11), + reason="Hyperlight sandbox snapshot/restore crashes on Windows Python 3.10.", +) async def test_sandbox_snapshot_restore_keeps_sandbox_functional(restored_sandbox) -> None: """Verify snapshot/restore cycle leaves the sandbox in a working state.""" # Mutate the sandbox diff --git a/python/samples/02-agents/tools/README.md b/python/samples/02-agents/tools/README.md index ad07d86ecd..b24fe2223b 100644 --- 
a/python/samples/02-agents/tools/README.md +++ b/python/samples/02-agents/tools/README.md @@ -22,6 +22,7 @@ injection, and dynamic (progressive) tool exposure. |------|--------------| | [`function_tool_with_approval.py`](function_tool_with_approval.py) | Requiring human approval before a tool runs. | | [`function_tool_with_approval_and_sessions.py`](function_tool_with_approval_and_sessions.py) | Tool approvals combined with sessions. | +| [`tool_approval_middleware.py`](tool_approval_middleware.py) | Session-backed approval coordination, mixed-batch approvals, and "always approve" rules. | | [`function_invocation_configuration.py`](function_invocation_configuration.py) | Configuring function-invocation settings (e.g. max iterations). | | [`control_total_tool_executions.py`](control_total_tool_executions.py) | All the ways to cap how many times tools run. | | [`function_tool_with_max_invocations.py`](function_tool_with_max_invocations.py) | Limiting the number of invocations per tool. | diff --git a/python/samples/02-agents/tools/tool_approval_middleware.py b/python/samples/02-agents/tools/tool_approval_middleware.py new file mode 100644 index 0000000000..5b8dc4fb42 --- /dev/null +++ b/python/samples/02-agents/tools/tool_approval_middleware.py @@ -0,0 +1,191 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from typing import Annotated + +from agent_framework import ( + Agent, + AgentResponse, + AgentSession, + Content, + Message, + ToolApprovalMiddleware, + create_always_approve_tool_response, + create_always_approve_tool_with_arguments_response, + tool, +) +from agent_framework.foundry import FoundryChatClient +from azure.identity import AzureCliCredential +from dotenv import load_dotenv + +""" +This sample demonstrates how a host application can decide which approval +requests may run now, which must be rejected, and which can be remembered for +future runs. + +The model may not request every tool on every run. 
The important part is the +approval mechanism: + +1. Tools that are safe to run immediately use ``approval_mode="never_require"``. +2. Sensitive tools use ``approval_mode="always_require"``. +3. ``ToolApprovalMiddleware`` coordinates approval prompts and standing rules. +4. The host turns user policy into ``function_approval_response`` content: + - approve for this request only; + - reject for this request; + - approve and remember the tool for future requests; + - approve and remember the tool only when called again with the same arguments. +5. Heuristic auto-approval rules can approve low-risk function calls before + the user is prompted. +""" + +# Load environment variables from .env file +load_dotenv() + + +@tool(approval_mode="never_require") +def lookup_ticket(ticket_id: Annotated[str, "Support ticket id, for example T-123"]) -> str: + """Look up a support ticket. This read-only tool runs without approval.""" + return f"Ticket {ticket_id}: customer confirmed the issue can be closed." 
# NOTE(review): the docstrings and ``Annotated`` descriptions below are presumably
# surfaced to the model by ``@tool`` — their wording is kept unchanged on purpose.


@tool(approval_mode="always_require")
def close_ticket(
    ticket_id: Annotated[str, "Support ticket id, for example T-123"],
    resolution: Annotated[str, "Short resolution text"],
) -> str:
    """Close a support ticket."""
    # The sample performs no real work; it only reports the simulated outcome.
    outcome = f"Ticket {ticket_id} closed with resolution: {resolution}"
    return outcome


@tool(approval_mode="always_require")
def notify_customer(
    ticket_id: Annotated[str, "Support ticket id, for example T-123"],
    message: Annotated[str, "Message to send to the customer"],
) -> str:
    """Notify the customer about a ticket update."""
    # Nothing is actually sent; the returned text stands in for the notification.
    outcome = f"Customer notified for {ticket_id}: {message}"
    return outcome


@tool(approval_mode="always_require")
def add_internal_note(
    ticket_id: Annotated[str, "Support ticket id, for example T-123"],
    note: Annotated[str, "Internal note text"],
) -> str:
    """Add an internal note to a support ticket."""
    # The note is echoed back rather than stored anywhere.
    outcome = f"Internal note added to {ticket_id}: {note}"
    return outcome


@tool(approval_mode="always_require")
def delete_attachment(
    ticket_id: Annotated[str, "Support ticket id, for example T-123"],
    attachment_name: Annotated[str, "Attachment file name"],
) -> str:
    """Delete an attachment from a support ticket."""
    # No file is touched; the string only describes the simulated deletion.
    outcome = f"Deleted {attachment_name} from ticket {ticket_id}."
    return outcome
def auto_approve_low_risk_notes(function_call: Content) -> bool:
    """Decide whether an ``add_internal_note`` call may skip the approval prompt.

    Args:
        function_call: The function-call content being evaluated.

    Returns:
        ``True`` only when the call targets ``add_internal_note``, its ``ticket_id``
        argument equals ``"T-123"``, and the note text is at most 120 characters.
    """
    if function_call.name == "add_internal_note":
        # A call whose arguments parse to nothing is judged against an empty mapping.
        parsed = function_call.parse_arguments() or {}
        note_text = str(parsed.get("note", ""))
        on_target_ticket = parsed.get("ticket_id") == "T-123"
        return on_target_ticket and len(note_text) <= 120
    return False


def approval_response_for_user_policy(request: Content) -> Content:
    """Answer a single approval request according to the sample's host policy.

    Policy applied here:
      * ``close_ticket`` - approve and remember for calls with the same arguments.
      * ``notify_customer`` - approve and remember for every future call.
      * anything else (including ``delete_attachment``) - reject.

    Args:
        request: The approval request raised for one tool call.

    Returns:
        The approval-response content to send back to the agent.
    """
    function_call = request.function_call
    tool_name = None if function_call is None else function_call.name
    if tool_name is None:
        # Without a named function call there is nothing to approve.
        return request.to_function_approval_response(approved=False)

    print(f"Approval requested: {tool_name}({function_call.arguments})")

    if tool_name == "close_ticket":
        print(f"Decision: approve and remember future {tool_name} calls with these exact arguments")
        return create_always_approve_tool_with_arguments_response(request)
    if tool_name == "notify_customer":
        print(f"Decision: approve and remember all future {tool_name} calls")
        return create_always_approve_tool_response(request)

    # Every remaining tool is rejected; only the printed explanation differs.
    if tool_name == "delete_attachment":
        print(f"Decision: reject {tool_name} for this run")
    else:
        print(f"Decision: reject {tool_name}; no policy allowed it")
    return request.to_function_approval_response(approved=False)


async def resolve_approval_requests(agent: Agent, response: AgentResponse, session: AgentSession) -> AgentResponse:
    """Keep answering approval prompts until the agent stops asking.

    Args:
        agent: The agent that is re-run with the approval decisions.
        response: The response that may carry pending user-input requests.
        session: The session passed to every follow-up run.

    Returns:
        The first response that has no pending user-input requests.
    """
    current = response
    while pending := current.user_input_requests:
        decisions = [approval_response_for_user_policy(item) for item in pending]
        current = await agent.run(Message(role="user", contents=decisions), session=session)
    return current


async def main() -> None:
    """Drive the tool approval middleware sample end to end."""
    # Step 1: an ordinary chat client.
    client = FoundryChatClient(credential=AzureCliCredential())

    # Step 2: an agent that owns the sensitive tools and opts in to ToolApprovalMiddleware.
    instructions = (
        "You are a support agent. Use tools when useful. "
        "Look up ticket T-123, close it if the customer confirmed, notify the customer, "
        "add a short internal note, and do not delete attachments unless the tool is approved."
    )
    tools = [lookup_ticket, close_ticket, notify_customer, add_internal_note, delete_attachment]
    middleware = [ToolApprovalMiddleware(auto_approval_rules=[auto_approve_low_risk_notes])]
    agent = Agent(
        client=client,
        name="SupportAgent",
        instructions=instructions,
        tools=tools,
        middleware=middleware,
    )
    session = agent.create_session()

    # Step 3: a request that can produce a mixed batch of safe and sensitive tool calls.
    query = (
        "Please process ticket T-123: check the ticket, close it as resolved, "
        "notify the customer, add a short internal note, and remove debug.log if it is attached."
    )
    print(f"User: {query}")
    first_pass = await agent.run(query, session=session)

    # Step 4: turn each approval request into an approve / reject / always-approve answer.
    result = await resolve_approval_requests(agent, first_pass, session)
    print(f"Agent: {result.text}")

    # Step 5: a later run in the same session can rely on what was remembered:
    #   - notify_customer: every future call to the tool.
    #   - close_ticket: only future calls with the same arguments.
    #   - add_internal_note: matching low-risk calls pass through the heuristic callback.
    follow_up = "Send the customer a short follow-up for ticket T-123."
    print(f"\nUser: {follow_up}")
    second_pass = await agent.run(follow_up, session=session)
    result = await resolve_approval_requests(agent, second_pass, session)
    print(f"Agent: {result.text}")


if __name__ == "__main__":
    asyncio.run(main())

"""
Sample output:
User: Please process ticket T-123: check the ticket, close it as resolved,
notify the customer, add a short internal note, and remove debug.log if it is attached.
Approval requested: close_ticket({"ticket_id": "T-123", "resolution": "resolved"})
Decision: approve and remember future close_ticket calls with these exact arguments
Approval requested: notify_customer({"ticket_id": "T-123", "message": "Your ticket has been resolved."})
Decision: approve and remember all future notify_customer calls
Approval requested: delete_attachment({"ticket_id": "T-123", "attachment_name": "debug.log"})
Decision: reject delete_attachment for this run
Agent: Ticket T-123 was closed, the customer was notified, and a short internal note was added.
I did not delete debug.log.

User: Send the customer a short follow-up for ticket T-123.
Agent: The customer was sent a short follow-up for ticket T-123.
"""