mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add tool approval middleware (#6414)
* Add Python tool approval middleware
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Fix tool approval restored state handling
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Gate hidden approvals on explicit approval responses
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Handle string inputs in approval replay scan
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Cover argument-scoped approval rules
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Refine tool approval state and budgets
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Fix tool approval PR CI failures
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Revert DevUI Aspire README link change
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
---------
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
c79f886dc3
commit
df29af611c
@@ -100,6 +100,23 @@ agent_framework/
|
||||
- **`FileSearchResult`** / **`FileSearchMatch`** - `SerializationMixin` DTOs returned by `search_files`, carrying the matching file name, a context snippet, and the matching lines with 1-based line numbers.
|
||||
- **`FileAccessProvider`** - `ContextProvider` that adds shared file-access tools (`file_access_save_file`, `file_access_read_file`, `file_access_delete_file`, `file_access_list_files`, `file_access_search_files`) plus default usage instructions to each invocation. Unlike `MemoryContextProvider`, the store is intentionally shared across sessions and agents.
|
||||
|
||||
### Tool Approval Harness (`_harness/_tool_approval.py`)
|
||||
|
||||
- **`ToolApprovalMiddleware`** - Experimental opt-in agent middleware that coordinates session-backed approval
|
||||
rules, heuristic `auto_approval_rules`, queued approval requests, collected approval responses, and
|
||||
streaming/non-streaming approval prompts. Heuristic callbacks receive the underlying `function_call` content.
|
||||
- **`ToolApprovalRule`** / **`ToolApprovalState`** - Serializable state models for standing approvals and queued
|
||||
approval flow. `ToolApprovalRule.arguments is None` means a tool-wide rule; an empty dict `{}` means an exact
|
||||
no-argument call for `create_always_approve_tool_with_arguments_response`.
|
||||
- **`create_always_approve_tool_response`** / **`create_always_approve_tool_with_arguments_response`** - Helpers
|
||||
that return normal `function_approval_response` content with `additional_properties` metadata consumed by
|
||||
`ToolApprovalMiddleware`. Standing rules for hosted tools include the `server_label` boundary, so same-named tools
|
||||
on different hosted servers do not share approvals.
|
||||
- Mixed tool-call batches use a default .NET-style bypass in the function invocation loop: when a session is
|
||||
available, approval requests for known non-approval-required tools are treated as already approved, hidden, stored
|
||||
in session state keyed to the visible approval request ids from that batch, and reinjected only when that visible
|
||||
approval flow resumes.
|
||||
|
||||
### Workflows (`_workflows/`)
|
||||
|
||||
- **`Workflow`** - Graph-based workflow definition
|
||||
|
||||
@@ -125,6 +125,15 @@ from ._harness._todo import (
|
||||
TodoStore,
|
||||
)
|
||||
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback
|
||||
from ._harness._tool_approval import (
|
||||
DEFAULT_TOOL_APPROVAL_SOURCE_ID,
|
||||
ToolApprovalMiddleware,
|
||||
ToolApprovalRule,
|
||||
ToolApprovalRuleCallback,
|
||||
ToolApprovalState,
|
||||
create_always_approve_tool_response,
|
||||
create_always_approve_tool_with_arguments_response,
|
||||
)
|
||||
from ._middleware import (
|
||||
AgentContext,
|
||||
AgentMiddleware,
|
||||
@@ -330,6 +339,7 @@ __all__ = [
|
||||
"DEFAULT_MEMORY_SOURCE_ID",
|
||||
"DEFAULT_MODE_SOURCE_ID",
|
||||
"DEFAULT_TODO_SOURCE_ID",
|
||||
"DEFAULT_TOOL_APPROVAL_SOURCE_ID",
|
||||
"EXCLUDED_KEY",
|
||||
"EXCLUDE_REASON_KEY",
|
||||
"GROUP_ANNOTATION_KEY",
|
||||
@@ -509,6 +519,10 @@ __all__ = [
|
||||
"TodoStore",
|
||||
"TokenBudgetComposedStrategy",
|
||||
"TokenizerProtocol",
|
||||
"ToolApprovalMiddleware",
|
||||
"ToolApprovalRule",
|
||||
"ToolApprovalRuleCallback",
|
||||
"ToolApprovalState",
|
||||
"ToolMode",
|
||||
"ToolResultCompactionStrategy",
|
||||
"ToolTypes",
|
||||
@@ -543,6 +557,8 @@ __all__ = [
|
||||
"annotate_message_groups",
|
||||
"apply_compaction",
|
||||
"chat_middleware",
|
||||
"create_always_approve_tool_response",
|
||||
"create_always_approve_tool_with_arguments_response",
|
||||
"create_edge_runner",
|
||||
"create_harness_agent",
|
||||
"detect_media_type_from_base64",
|
||||
|
||||
@@ -0,0 +1,632 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
from asyncio import sleep
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Mapping, MutableMapping, Sequence
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from .._feature_stage import ExperimentalFeature, experimental
|
||||
from .._middleware import AgentContext, AgentMiddleware
|
||||
from .._serialization import SerializationMixin
|
||||
from .._sessions import AgentSession
|
||||
from .._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
Content,
|
||||
FinishReason,
|
||||
FinishReasonLiteral,
|
||||
Message,
|
||||
ResponseStream,
|
||||
)
|
||||
|
||||
# Default session-state key under which ToolApprovalMiddleware keeps its state.
DEFAULT_TOOL_APPROVAL_SOURCE_ID = "tool_approval"
# Key that ToolApprovalMiddleware.process seeds into ``context.client_kwargs``.
_FUNCTION_INVOCATION_BUDGET_STATE_KEY = "_function_invocation_budget_state"
# ``additional_properties`` key on an approval response holding the
# standing-approval metadata mapping.
ALWAYS_APPROVE_PROPERTY = "tool_approval"
# Key inside that metadata mapping whose value is the requested scope.
ALWAYS_APPROVE_SCOPE_PROPERTY = "always_approve"
# Scope value: record a rule for the whole tool (no argument constraint).
ALWAYS_APPROVE_TOOL: Literal["tool"] = "tool"
# Scope value: record a rule for the tool plus its exact serialized arguments.
ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS: Literal["tool_with_arguments"] = "tool_with_arguments"

# Keys of the serialized approval-state mapping stored in the session.
_RULES_KEY = "rules"
_QUEUED_APPROVAL_REQUESTS_KEY = "queued_approval_requests"
_COLLECTED_APPROVAL_RESPONSES_KEY = "collected_approval_responses"

# The two scopes an "always approve" response may request.
ToolApprovalScope = Literal["tool", "tool_with_arguments"]
# Heuristic auto-approval callback: takes the function-call content and returns
# a bool, either directly or as an awaitable.
ToolApprovalRuleCallback = Callable[[Content], bool | Awaitable[bool]]
|
||||
|
||||
|
||||
def _parse_function_arguments(function_call: Content) -> dict[str, Any]:
    """Return the parsed arguments of ``function_call`` as a new ``dict``.

    A falsy parse result (for example ``None``) yields an empty dict.
    """
    parsed = function_call.parse_arguments()
    if not parsed:
        return {}
    return dict(parsed)
|
||||
|
||||
|
||||
def _serialize_argument_value(value: Any) -> str:
    """Render ``value`` as compact JSON with sorted object keys.

    Objects the JSON encoder cannot handle natively are converted with ``str``.
    """
    compact_separators = (",", ":")
    return json.dumps(value, default=str, separators=compact_separators, sort_keys=True)
|
||||
|
||||
|
||||
def _serialize_arguments(function_call: Content) -> dict[str, str]:
    """Build the canonical argument mapping used for exact rule matching.

    Each parsed argument value is rendered to its serialized string form.
    A call without arguments produces ``{}``. ``None`` is never returned,
    because :class:`ToolApprovalRule` reserves ``None`` for tool-wide rules:
    an argument-scoped rule recorded for a no-argument call therefore only
    matches other no-argument calls and never turns into a wildcard.
    """
    canonical: dict[str, str] = {}
    for name, raw_value in _parse_function_arguments(function_call).items():
        canonical[name] = _serialize_argument_value(raw_value)
    return canonical
|
||||
|
||||
|
||||
def _server_label(function_call: Content) -> str | None:
    """Read the hosted-tool server boundary from a function call.

    Looks up ``"server_label"`` in the call's ``additional_properties`` and
    returns it only when it is a string; otherwise returns ``None``.
    """
    label = function_call.additional_properties.get("server_label")
    if not isinstance(label, str):
        return None
    return label
|
||||
|
||||
|
||||
def _content_from_state(value: Any) -> Content:
    """Coerce one stored state item into a ``Content`` instance.

    ``Content`` objects are returned unchanged; mappings are rebuilt through
    ``Content.from_dict``.

    Raises:
        TypeError: If ``value`` is neither a ``Content`` nor a mapping.
    """
    if isinstance(value, Content):
        return value
    if not isinstance(value, Mapping):
        raise TypeError(f"Expected Content or mapping state item, got {type(value).__name__}.")
    return Content.from_dict(cast(Mapping[str, Any], value))
|
||||
|
||||
|
||||
def _contents_from_state(values: Any) -> list[Content]:
    """Convert a stored list of state items into ``Content`` objects.

    Any input that is not a ``list`` is treated as "no items" and gives ``[]``.
    """
    if isinstance(values, list):
        snapshot = list(cast(Iterable[Any], values))
        return [_content_from_state(entry) for entry in snapshot]
    return []
|
||||
|
||||
|
||||
def _content_to_state(content: Content) -> dict[str, Any]:
    """Serialize ``content`` for session storage by delegating to ``to_dict``."""
    serialized = content.to_dict()
    return serialized
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.HARNESS)
class ToolApprovalRule(SerializationMixin):
    """Standing approval that lets future matching tool calls skip the prompt."""

    tool_name: str
    arguments: dict[str, str] | None
    server_label: str | None

    def __init__(
        self,
        tool_name: str,
        arguments: Mapping[str, str] | None = None,
        *,
        server_label: str | None = None,
    ) -> None:
        """Create a rule.

        Args:
            tool_name: Name of the function tool the rule covers. Surrounding
                whitespace is stripped before it is stored.
            arguments: Canonicalized argument values to match exactly.
                ``None`` (the default) makes the rule apply to every call of
                the tool; an empty mapping matches only no-argument calls.

        Keyword Args:
            server_label: Optional hosted-tool server boundary. Hosted
                approvals only match future approvals from the same label.

        Raises:
            ValueError: If ``tool_name`` is empty after stripping.
        """
        stripped_name = tool_name.strip()
        if stripped_name == "":
            raise ValueError("Tool approval rule tool_name must be a non-empty string.")
        self.tool_name = stripped_name
        # Stored as a fresh dict so the rule does not alias the caller's mapping.
        self.arguments = None if arguments is None else dict(arguments)
        self.server_label = server_label

    @classmethod
    def from_dict(
        cls,
        value: MutableMapping[str, Any],
        /,
        *,
        dependencies: MutableMapping[str, Any] | None = None,
    ) -> ToolApprovalRule:
        """Rebuild a rule from its serialized mapping.

        ``dependencies`` is accepted and discarded. Argument keys and values
        are coerced with ``str``.

        Raises:
            ValueError: If ``tool_name`` is not a string, ``arguments`` is
                neither a mapping nor ``None``, or ``server_label`` is
                neither a string nor ``None``.
        """
        del dependencies
        name = value.get("tool_name")
        if not isinstance(name, str):
            raise ValueError("Tool approval rule tool_name must be a string.")

        stored_arguments = value.get("arguments")
        if not (stored_arguments is None or isinstance(stored_arguments, Mapping)):
            raise ValueError("Tool approval rule arguments must be a mapping or None.")

        label = value.get("server_label")
        if not (label is None or isinstance(label, str)):
            raise ValueError("Tool approval rule server_label must be a string or None.")

        normalized: dict[str, str] | None = None
        if isinstance(stored_arguments, Mapping):
            normalized = {
                str(key): str(argument_value)
                for key, argument_value in cast(Mapping[str, Any], stored_arguments).items()
            }
        return cls(tool_name=name, arguments=normalized, server_label=label)

    def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
        """Serialize the rule to a plain dict.

        ``tool_name`` is always written. ``type`` is written unless it is in
        ``exclude``. ``arguments`` and ``server_label`` are written when they
        are not ``None``, or unconditionally when ``exclude_none`` is false.
        """
        skipped = exclude or set()
        keep_none = not exclude_none
        result: dict[str, Any] = {"tool_name": self.tool_name}
        if "type" not in skipped:
            result["type"] = self._get_type_identifier()
        for field_name in ("arguments", "server_label"):
            field_value = getattr(self, field_name)
            if field_value is not None or keep_none:
                result[field_name] = field_value
        return result
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.HARNESS)
class ToolApprovalState(SerializationMixin):
    """Approval state that :class:`ToolApprovalMiddleware` keeps in the session."""

    rules: list[ToolApprovalRule]
    queued_approval_requests: list[Content]
    collected_approval_responses: list[Content]

    def __init__(
        self,
        *,
        rules: Sequence[ToolApprovalRule | Mapping[str, Any]] | None = None,
        queued_approval_requests: Sequence[Content | Mapping[str, Any]] | None = None,
        collected_approval_responses: Sequence[Content | Mapping[str, Any]] | None = None,
    ) -> None:
        """Create approval state.

        Each argument may mix ready-made objects with their serialized
        mappings; mappings are converted with ``ToolApprovalRule.from_dict``
        or ``Content.from_dict``. A missing or empty argument gives an empty
        list.
        """

        def _as_content(entry: Content | Mapping[str, Any]) -> Content:
            # Already-built Content objects are kept as they are.
            if isinstance(entry, Content):
                return entry
            return Content.from_dict(entry)

        restored_rules: list[ToolApprovalRule] = []
        for entry in rules or []:
            if isinstance(entry, ToolApprovalRule):
                restored_rules.append(entry)
            else:
                restored_rules.append(ToolApprovalRule.from_dict(dict(entry)))
        self.rules = restored_rules
        self.queued_approval_requests = [_as_content(entry) for entry in (queued_approval_requests or [])]
        self.collected_approval_responses = [_as_content(entry) for entry in (collected_approval_responses or [])]

    @classmethod
    def from_dict(
        cls,
        value: MutableMapping[str, Any],
        /,
        *,
        dependencies: MutableMapping[str, Any] | None = None,
    ) -> ToolApprovalState:
        """Rebuild state from its serialized mapping.

        Missing keys default to empty lists; ``dependencies`` is accepted and
        discarded.
        """
        del dependencies

        def _stored(key: str) -> Sequence[Mapping[str, Any]]:
            return cast(Sequence[Mapping[str, Any]], value.get(key, []))

        return cls(
            rules=_stored("rules"),
            queued_approval_requests=_stored("queued_approval_requests"),
            collected_approval_responses=_stored("collected_approval_responses"),
        )

    def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
        """Serialize the state to a plain dict.

        The three lists are always written; ``type`` is added unless it is in
        ``exclude``. ``exclude_none`` is accepted and discarded.
        """
        del exclude_none
        skipped = exclude or set()
        result: dict[str, Any] = {}
        result["rules"] = [rule.to_dict() for rule in self.rules]
        result["queued_approval_requests"] = list(map(_content_to_state, self.queued_approval_requests))
        result["collected_approval_responses"] = list(map(_content_to_state, self.collected_approval_responses))
        if "type" in skipped:
            return result
        result["type"] = self._get_type_identifier()
        return result
|
||||
|
||||
|
||||
def create_always_approve_tool_response(request: Content, *, reason: str | None = None) -> Content:
    """Approve ``request`` and ask for a standing rule covering the whole tool.

    Args:
        request: The ``function_approval_request`` content being approved.

    Keyword Args:
        reason: Optional approval reason, stored in ``additional_properties``.

    Returns:
        A ``function_approval_response`` carrying metadata that
        :class:`ToolApprovalMiddleware` consumes.
    """
    tool_wide_response = _create_always_approve_response(request, ALWAYS_APPROVE_TOOL, reason=reason)
    return tool_wide_response
|
||||
|
||||
|
||||
def create_always_approve_tool_with_arguments_response(request: Content, *, reason: str | None = None) -> Content:
    """Approve ``request`` and ask for a standing rule for the tool and its exact arguments.

    Args:
        request: The ``function_approval_request`` content being approved.

    Keyword Args:
        reason: Optional approval reason, stored in ``additional_properties``.

    Returns:
        A ``function_approval_response`` whose metadata requests the
        ``tool_with_arguments`` scope.
    """
    argument_scoped_response = _create_always_approve_response(
        request,
        ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS,
        reason=reason,
    )
    return argument_scoped_response
|
||||
|
||||
|
||||
def _create_always_approve_response(request: Content, scope: ToolApprovalScope, *, reason: str | None) -> Content:
    """Build an approved response tagged with standing-approval metadata.

    The metadata mapping holds the requested ``scope`` and, when ``reason``
    is not ``None``, a ``"reason"`` entry. It is stored in the response's
    ``additional_properties`` under ``ALWAYS_APPROVE_PROPERTY``.
    """
    always_approve_info: dict[str, Any] = (
        {ALWAYS_APPROVE_SCOPE_PROPERTY: scope}
        if reason is None
        else {ALWAYS_APPROVE_SCOPE_PROPERTY: scope, "reason": reason}
    )
    approved = request.to_function_approval_response(approved=True)
    approved.additional_properties[ALWAYS_APPROVE_PROPERTY] = always_approve_info
    return approved
|
||||
|
||||
|
||||
def _get_state(session: AgentSession, *, source_id: str) -> ToolApprovalState:
    """Load the approval state stored under ``session.state[source_id]``.

    Args:
        session: Session whose ``state`` mapping holds the approval state.

    Keyword Args:
        source_id: Key of the approval state inside ``session.state``.

    Returns:
        The stored :class:`ToolApprovalState` instance when one is stored
        directly, a state rebuilt from a stored mapping, or a new empty
        state. In the last case the empty state is also written to the
        session in serialized form, without its ``type`` entry.

    Raises:
        TypeError: If the stored value is neither a ``ToolApprovalState``,
            a mapping, nor ``None``.
    """
    raw_state = session.state.get(source_id)
    if isinstance(raw_state, ToolApprovalState):
        return raw_state
    # The stored value is only read here, so any Mapping is acceptable. This
    # keeps the check in line with the "must be a mapping" error below instead
    # of rejecting read-only mappings.
    if isinstance(raw_state, Mapping):
        raw_state_mapping = cast(Mapping[str, Any], raw_state)
        return ToolApprovalState(
            rules=cast(Sequence[Mapping[str, Any]], raw_state_mapping.get(_RULES_KEY, [])),
            queued_approval_requests=_contents_from_state(raw_state_mapping.get(_QUEUED_APPROVAL_REQUESTS_KEY, [])),
            collected_approval_responses=_contents_from_state(
                raw_state_mapping.get(_COLLECTED_APPROVAL_RESPONSES_KEY, []),
            ),
        )
    if raw_state is not None:
        raise TypeError(f"Session state for {source_id!r} must be a mapping, got {type(raw_state).__name__}.")
    # Nothing stored yet: start empty and seed the session with the serialized form.
    state = ToolApprovalState()
    session.state[source_id] = state.to_dict(exclude={"type"})
    return state
|
||||
|
||||
|
||||
def _save_state(session: AgentSession, state: ToolApprovalState, *, source_id: str) -> None:
    """Write ``state`` to ``session.state[source_id]`` in serialized form.

    The state is serialized without its ``type`` entry. If a mutable mapping
    is already stored, its keys that the serialized state does not define are
    carried over, except ``type``.
    """
    snapshot = state.to_dict(exclude={"type"})
    previous = session.state.get(source_id)
    if isinstance(previous, MutableMapping):
        carried_over = {
            key: value
            for key, value in cast(MutableMapping[str, Any], previous).items()
            if key != "type" and key not in snapshot
        }
        snapshot.update(carried_over)
    session.state[source_id] = snapshot
|
||||
|
||||
|
||||
def _rule_exists(rules: Sequence[ToolApprovalRule], new_rule: ToolApprovalRule) -> bool:
    """Tell whether ``rules`` already contains a rule equal to ``new_rule``.

    Two rules are equal when their ``tool_name``, ``server_label`` and
    ``arguments`` all compare equal.
    """
    return any(
        existing.tool_name == new_rule.tool_name
        and existing.server_label == new_rule.server_label
        and existing.arguments == new_rule.arguments
        for existing in rules
    )
|
||||
|
||||
|
||||
def _add_rule_if_missing(state: ToolApprovalState, rule: ToolApprovalRule) -> None:
    """Append ``rule`` to ``state.rules`` unless an equal rule is already there."""
    already_present = _rule_exists(state.rules, rule)
    if already_present:
        return
    state.rules.append(rule)
|
||||
|
||||
|
||||
def _function_call_from_request(request: Content) -> Content | None:
    """Extract the wrapped function call from an approval request.

    Returns ``None`` unless the request has a ``function_call`` whose type is
    ``"function_call"`` and whose ``name`` is set.
    """
    candidate = request.function_call
    if candidate is None:
        return None
    is_named_call = candidate.type == "function_call" and candidate.name is not None
    return candidate if is_named_call else None
|
||||
|
||||
|
||||
def _arguments_match(rule_arguments: Mapping[str, str], function_call: Content) -> bool:
    """Check that the call's serialized arguments equal the rule's arguments.

    The match requires the same number of entries and, for every rule key,
    an identical serialized value on the call.
    """
    observed = _serialize_arguments(function_call) or {}
    if len(observed) != len(rule_arguments):
        return False
    for name, expected in rule_arguments.items():
        if observed.get(name) != expected:
            return False
    return True
|
||||
|
||||
|
||||
def _matches_rule(request: Content, rules: Sequence[ToolApprovalRule]) -> bool:
    """Tell whether any standing rule covers the approval request.

    A rule applies when its tool name and server label equal those of the
    requested call and it is either tool-wide (``arguments is None``) or its
    arguments match the call exactly. A request without a usable function
    call never matches.
    """
    call = _function_call_from_request(request)
    if call is None:
        return False
    for candidate in rules:
        same_tool = candidate.tool_name == call.name
        if not same_tool or candidate.server_label != _server_label(call):
            continue
        # A tool-wide rule matches regardless of the call's arguments.
        if candidate.arguments is None or _arguments_match(candidate.arguments, call):
            return True
    return False
|
||||
|
||||
|
||||
def _get_always_approve_scope(response: Content) -> ToolApprovalScope | None:
    """Read the standing-approval scope requested by an approval response.

    Returns one of the two known scope constants, or ``None`` when the
    metadata is missing, is not a mapping, or names an unknown scope.
    """
    info = response.additional_properties.get(ALWAYS_APPROVE_PROPERTY)
    if isinstance(info, Mapping):
        requested = cast(Mapping[str, Any], info).get(ALWAYS_APPROVE_SCOPE_PROPERTY)
        for known_scope in (ALWAYS_APPROVE_TOOL, ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS):
            if requested == known_scope:
                return known_scope
    return None
|
||||
|
||||
|
||||
def _clone_without_always_approve_metadata(response: Content) -> Content:
    """Return a deep copy of ``response`` with the standing-approval metadata removed.

    The original ``response`` is left unmodified.
    """
    duplicate = copy.deepcopy(response)
    properties = duplicate.additional_properties
    properties.pop(ALWAYS_APPROVE_PROPERTY, None)
    return duplicate
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.HARNESS)
class ToolApprovalMiddleware(AgentMiddleware):
    """Coordinate standing tool approvals and queued approval prompts for an agent.

    This middleware is opt-in and requires callers to run the agent with an
    :class:`AgentSession`, because approval rules and queued requests are stored
    in session state.
    """

    def __init__(
        self,
        *,
        source_id: str = DEFAULT_TOOL_APPROVAL_SOURCE_ID,
        auto_approval_rules: Sequence[ToolApprovalRuleCallback] | None = None,
    ) -> None:
        """Initialize the middleware.

        Keyword Args:
            source_id: Session-state key used by this middleware.
            auto_approval_rules: Optional callbacks that can auto-approve a
                ``function_call``. Each callback receives the function-call
                content and returns ``True`` to approve it.
        """
        self.source_id = source_id
        # Stored as a tuple; ``None`` becomes an empty tuple.
        self.auto_approval_rules = tuple(auto_approval_rules or ())

    async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
        """Process one agent invocation.

        Inbound approval responses are moved out of the messages into the
        session state. If a queued approval request is still pending it is
        returned as the result without invoking the agent. Otherwise the agent
        runs, and it is re-run for as long as every approval request it
        produces is auto-approved.

        Raises:
            RuntimeError: If ``context.session`` is ``None``.
        """
        if context.session is None:
            raise RuntimeError("ToolApprovalMiddleware requires an AgentSession.")

        state = _get_state(context.session, source_id=self.source_id)
        context.client_kwargs.setdefault(_FUNCTION_INVOCATION_BUDGET_STATE_KEY, {})
        # Strip approval responses from the inbound messages and collect them in state.
        context.messages = self._prepare_inbound_messages(context.messages, state)
        # Queued requests that now match a standing or heuristic rule are approved here.
        await self._drain_auto_approvable_queue(state)
        if next_queued := self._pop_next_queued_request(state):
            # A queued request is still pending: surface it instead of calling the agent.
            _save_state(context.session, state, source_id=self.source_id)
            context.result = self._response_for_queued_request(next_queued, stream=context.stream)
            return
        if context.stream:
            context.result = self._process_stream(context, call_next, state)
            return

        while True:
            # Hand the collected responses to the agent, then clear them from state.
            context.messages = self._inject_collected_responses(context.messages, state)
            state_changed = bool(state.collected_approval_responses)
            state.collected_approval_responses.clear()
            if state_changed:
                _save_state(context.session, state, source_id=self.source_id)

            await call_next()
            if isinstance(context.result, ResponseStream):
                return
            if context.result is None:
                _save_state(context.session, state, source_id=self.source_id)
                return

            all_auto_approved = await self._process_outbound_messages(context.result.messages, state)
            _save_state(context.session, state, source_id=self.source_id)
            if not all_auto_approved:
                return
            # Every request was auto-approved: loop again so the new approvals
            # are injected and the agent is invoked once more.
            context.messages = []
            context.result = None

    def _response_for_queued_request(
        self,
        request: Content,
        *,
        stream: bool,
    ) -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse]:
        """Wrap one queued approval request as the invocation result.

        Returns an ``AgentResponse`` with a single assistant message, or, when
        ``stream`` is true, a ``ResponseStream`` that emits a single update.
        """
        if not stream:
            return AgentResponse(messages=[Message(role="assistant", contents=[request])])

        async def _stream() -> AsyncIterable[AgentResponseUpdate]:
            # sleep(0) yields control to the event loop once before the update is emitted.
            await sleep(0)
            yield AgentResponseUpdate(role="assistant", contents=[request])

        return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)

    def _process_stream(
        self,
        context: AgentContext,
        call_next: Callable[[], Awaitable[None]],
        state: ToolApprovalState,
    ) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
        """Build the streaming counterpart of the loop in :meth:`process`.

        Approval requests are held back from the forwarded updates and handled
        after the inner stream ends. The agent is re-invoked while every
        request is auto-approved; otherwise the requests left visible are
        emitted in a final update.
        """

        async def _stream() -> AsyncIterable[AgentResponseUpdate]:
            if context.session is None:
                raise RuntimeError("ToolApprovalMiddleware requires an AgentSession.")
            while True:
                # Hand the collected responses to the agent, then clear them from state.
                context.messages = self._inject_collected_responses(context.messages, state)
                state_changed = bool(state.collected_approval_responses)
                state.collected_approval_responses.clear()
                if state_changed:
                    _save_state(context.session, state, source_id=self.source_id)

                await call_next()
                if not isinstance(context.result, ResponseStream):
                    raise ValueError("Streaming ToolApprovalMiddleware requires a ResponseStream result.")

                approval_requests: list[Content] = []
                async for update in context.result:
                    approval_contents = [
                        content for content in update.contents if content.type == "function_approval_request"
                    ]
                    if not approval_contents:
                        # No approval requests in this update: forward it untouched.
                        yield update
                        continue
                    approval_requests.extend(approval_contents)
                    remaining_contents = [
                        content for content in update.contents if content.type != "function_approval_request"
                    ]
                    if remaining_contents:
                        # Re-emit the update without its approval requests,
                        # copying the other fields across.
                        raw_finish_reason = update.finish_reason
                        finish_reason: FinishReasonLiteral | FinishReason | None
                        if isinstance(raw_finish_reason, str):
                            finish_reason = FinishReason(raw_finish_reason)
                        else:
                            finish_reason = cast(FinishReasonLiteral | FinishReason | None, raw_finish_reason)
                        yield AgentResponseUpdate(
                            contents=remaining_contents,
                            role=update.role,
                            author_name=update.author_name,
                            agent_id=update.agent_id,
                            response_id=update.response_id,
                            message_id=update.message_id,
                            created_at=update.created_at,
                            finish_reason=finish_reason,
                            continuation_token=update.continuation_token,
                            additional_properties=update.additional_properties,
                            raw_representation=update.raw_representation,
                        )
                await context.result.get_final_response()
                if not approval_requests:
                    return

                response_messages = [Message(role="assistant", contents=approval_requests)]
                all_auto_approved = await self._process_outbound_messages(response_messages, state)
                _save_state(context.session, state, source_id=self.source_id)
                if not all_auto_approved:
                    # Emit whatever approval requests were left visible.
                    for message in response_messages:
                        if message.contents:
                            yield AgentResponseUpdate(role=message.role, contents=message.contents)
                    return
                # Every request was auto-approved: loop and invoke the agent again.
                context.messages = []
                context.result = None

        return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)

    def _prepare_inbound_messages(self, messages: Sequence[Message], state: ToolApprovalState) -> list[Message]:
        """Move approval responses out of ``messages`` into ``state``.

        Each ``function_approval_response`` content is processed by
        :meth:`_handle_inbound_approval_response` and appended to
        ``state.collected_approval_responses``. Messages without such content
        are returned as-is; a message left with no content is dropped.
        """
        prepared: list[Message] = []
        for message in messages:
            replacement_contents: list[Content] = []
            changed = False
            for content in message.contents:
                if content.type == "function_approval_response":
                    replacement = self._handle_inbound_approval_response(content, state)
                    state.collected_approval_responses.append(replacement)
                    changed = True
                    continue
                replacement_contents.append(content)

            if not changed:
                prepared.append(message)
                continue
            if replacement_contents:
                # Shallow-copy so the caller's message keeps its original contents.
                cloned = copy.copy(message)
                cloned.contents = replacement_contents
                prepared.append(cloned)
        return prepared

    def _handle_inbound_approval_response(self, response: Content, state: ToolApprovalState) -> Content:
        """Record a standing rule for an "always approve" response.

        A response without a recognized scope, or one that is not approved, is
        returned unchanged. Otherwise a tool-wide or argument-scoped rule is
        added to ``state`` if the response carries a named function call, and
        a copy of the response without the always-approve metadata is returned.
        """
        scope = _get_always_approve_scope(response)
        if scope is None or not response.approved:
            return response

        function_call = response.function_call
        if function_call is not None and function_call.type == "function_call" and function_call.name is not None:
            if scope == ALWAYS_APPROVE_TOOL:
                _add_rule_if_missing(
                    state,
                    ToolApprovalRule(
                        tool_name=function_call.name,
                        server_label=_server_label(function_call),
                    ),
                )
            else:
                _add_rule_if_missing(
                    state,
                    ToolApprovalRule(
                        tool_name=function_call.name,
                        arguments=_serialize_arguments(function_call),
                        server_label=_server_label(function_call),
                    ),
                )
        return _clone_without_always_approve_metadata(response)

    def _inject_collected_responses(self, messages: Sequence[Message], state: ToolApprovalState) -> list[Message]:
        """Return ``messages`` preceded by one user message holding the collected responses.

        With nothing collected, a plain list copy of ``messages`` is returned.
        """
        if not state.collected_approval_responses:
            return list(messages)
        return [Message(role="user", contents=list(state.collected_approval_responses)), *messages]

    async def _drain_auto_approvable_queue(self, state: ToolApprovalState) -> None:
        """Approve queued requests that match a standing or heuristic rule.

        Matching requests are turned into approved responses and collected;
        the others stay queued in their original order.
        """
        remaining: list[Content] = []
        for request in state.queued_approval_requests:
            if _matches_rule(request, state.rules) or await self._matches_auto_rule(request):
                state.collected_approval_responses.append(request.to_function_approval_response(approved=True))
                continue
            remaining.append(request)
        state.queued_approval_requests = remaining

    def _pop_next_queued_request(self, state: ToolApprovalState) -> Content | None:
        """Remove and return the first queued request, or ``None`` if the queue is empty."""
        if not state.queued_approval_requests:
            return None
        return state.queued_approval_requests.pop(0)

    async def _process_outbound_messages(self, messages: list[Message], state: ToolApprovalState) -> bool:
        """Auto-approve, queue, or keep the approval requests found in ``messages``.

        Requests matching a rule are collected as approved responses and
        removed from ``messages``. Of the unresolved requests only the first
        stays in ``messages``; the rest are queued in ``state`` and removed.
        ``messages`` is modified in place.

        Returns:
            ``True`` only when ``messages`` contained approval requests and
            every one of them was auto-approved.
        """
        approval_requests = [
            content
            for message in messages
            for content in message.contents
            if content.type == "function_approval_request"
        ]
        if not approval_requests:
            return False

        # Requests are tracked by object identity via id().
        auto_approved: set[int] = set()
        unresolved: list[Content] = []
        for request in approval_requests:
            if _matches_rule(request, state.rules) or await self._matches_auto_rule(request):
                state.collected_approval_responses.append(request.to_function_approval_response(approved=True))
                auto_approved.add(id(request))
            else:
                unresolved.append(request)

        if not auto_approved and len(unresolved) <= 1:
            # Nothing to hide or queue: leave the messages untouched.
            return False

        queued_ids: set[int] = set()
        for request in unresolved[1:]:
            queued_ids.add(id(request))
            state.queued_approval_requests.append(request)

        remove_ids = auto_approved | queued_ids
        self._remove_approval_requests(messages, remove_ids)
        return not unresolved

    @staticmethod
    def _remove_approval_requests(messages: list[Message], remove_ids: set[int]) -> None:
        """Drop the approval requests whose ``id()`` is in ``remove_ids`` from ``messages`` in place.

        Messages left without any content are removed from the list.
        """
        # Walk backwards so popping a message does not shift indexes still to be visited.
        for message_index in range(len(messages) - 1, -1, -1):
            message = messages[message_index]
            filtered = [
                content
                for content in message.contents
                if content.type != "function_approval_request" or id(content) not in remove_ids
            ]
            if len(filtered) == len(message.contents):
                continue
            if filtered:
                message.contents = filtered
            else:
                messages.pop(message_index)

    async def _matches_auto_rule(self, request: Content) -> bool:
        """Tell whether any heuristic callback approves the request's function call.

        Callbacks run in order and may return a bool or an awaitable; the
        first truthy result wins. A request without a usable function call is
        never auto-approved.
        """
        function_call = _function_call_from_request(request)
        if function_call is None:
            return False
        for rule in self.auto_approval_rules:
            result = rule(function_call)
            if inspect.isawaitable(result):
                result = await result
            if result:
                return True
        return False
|
||||
|
||||
|
||||
# Names exported by this module for ``from ... import *``.
__all__ = [
    "ALWAYS_APPROVE_PROPERTY",
    "ALWAYS_APPROVE_SCOPE_PROPERTY",
    "ALWAYS_APPROVE_TOOL",
    "ALWAYS_APPROVE_TOOL_WITH_ARGUMENTS",
    "DEFAULT_TOOL_APPROVAL_SOURCE_ID",
    "ToolApprovalMiddleware",
    "ToolApprovalRule",
    "ToolApprovalRuleCallback",
    "ToolApprovalState",
    "create_always_approve_tool_response",
    "create_always_approve_tool_with_arguments_response",
]
|
||||
@@ -90,6 +90,9 @@ logger = logging.getLogger("agent_framework")
|
||||
DEFAULT_MAX_ITERATIONS: Final[int] = 40
|
||||
DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
|
||||
SHELL_TOOL_KIND_VALUE: Final[str] = "shell"
|
||||
_TOOL_APPROVAL_STATE_KEY: Final[str] = "tool_approval"
|
||||
_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY: Final[str] = "already_approved_approval_request_groups"
|
||||
_FUNCTION_INVOCATION_BUDGET_STATE_KEY: Final[str] = "_function_invocation_budget_state"
|
||||
ApprovalMode: TypeAlias = Literal["always_require", "never_require"]
|
||||
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
@@ -1685,15 +1688,15 @@ async def _try_execute_function_calls(
|
||||
# The live tools list (when tools is the run-local list) is exposed on the
|
||||
# FunctionInvocationContext so tools can add/remove tools during the run.
|
||||
live_tools: list[ToolTypes] | None = cast("list[ToolTypes]", tools) if isinstance(tools, list) else None
|
||||
approval_tools = [tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"]
|
||||
approval_tools = {tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"}
|
||||
logger.debug(
|
||||
"_try_execute_function_calls: tool_map keys=%s, approval_tools=%s",
|
||||
list(tool_map.keys()),
|
||||
approval_tools,
|
||||
)
|
||||
declaration_only = [tool_name for tool_name, tool in tool_map.items() if tool.declaration_only]
|
||||
declaration_only = {tool_name for tool_name, tool in tool_map.items() if tool.declaration_only}
|
||||
configured_additional_tools = config.get("additional_tools") or []
|
||||
additional_tool_names = [tool.name for tool in configured_additional_tools]
|
||||
additional_tool_names = {tool.name for tool in configured_additional_tools}
|
||||
# check if any are calling functions that need approval
|
||||
# if so, we return approval request for all
|
||||
approval_needed = False
|
||||
@@ -1719,15 +1722,39 @@ async def _try_execute_function_calls(
|
||||
raise KeyError(f'Error: Requested function "{fcc.name}" not found.') # type: ignore[attr-defined]
|
||||
if approval_needed:
|
||||
# approval can only be needed for Function Call Content, not Approval Responses.
|
||||
logger.debug("Returning function_approval_request contents")
|
||||
return (
|
||||
[
|
||||
Content.from_function_approval_request(id=fcc.call_id, function_call=fcc) # type: ignore[attr-defined, arg-type]
|
||||
for fcc in function_calls
|
||||
if fcc.type == "function_call"
|
||||
],
|
||||
False,
|
||||
logger.debug("Returning visible function_approval_request contents and storing already-approved requests")
|
||||
visible_requests: list[Content] = []
|
||||
already_approved_requests: list[Content] = []
|
||||
for fcc in function_calls:
|
||||
if fcc.type != "function_call":
|
||||
continue
|
||||
approval_request = Content.from_function_approval_request(
|
||||
id=fcc.call_id, # type: ignore[arg-type]
|
||||
function_call=fcc,
|
||||
)
|
||||
tool_name = fcc.name # type: ignore[attr-defined]
|
||||
if tool_name is None:
|
||||
visible_requests.append(approval_request)
|
||||
continue
|
||||
tool = tool_map.get(tool_name)
|
||||
if (
|
||||
tool_name in approval_tools
|
||||
or tool is None
|
||||
or tool_name in declaration_only
|
||||
or tool_name in additional_tool_names
|
||||
):
|
||||
visible_requests.append(approval_request)
|
||||
continue
|
||||
if invocation_session is None:
|
||||
visible_requests.append(approval_request)
|
||||
continue
|
||||
already_approved_requests.append(approval_request)
|
||||
_store_already_approved_approval_requests(
|
||||
invocation_session,
|
||||
visible_requests,
|
||||
already_approved_requests,
|
||||
)
|
||||
return (visible_requests, False)
|
||||
if declaration_only_flag:
|
||||
# return the declaration only tools to the user, since we cannot execute them.
|
||||
# Mark as user_input_request so AgentExecutor emits request_info events and pauses the workflow.
|
||||
@@ -1912,6 +1939,108 @@ def _is_hosted_tool_approval(content: Any) -> bool:
|
||||
return bool(ap and ap.get("server_label"))
|
||||
|
||||
|
||||
def _get_tool_approval_state(invocation_session: AgentSession | None) -> dict[str, Any] | None:
    """Look up the tool-approval state dict kept on the invocation session.

    The value stored under ``_TOOL_APPROVAL_STATE_KEY`` is normalized so callers
    always receive a plain dict: a ``ToolApprovalState`` instance is serialized
    (without its ``type`` field) and written back, and a missing value is
    replaced by a fresh empty dict.

    Args:
        invocation_session: Session whose ``state`` holds the approval data, or None.

    Returns:
        The dict stored on the session, or None when no session was given.

    Raises:
        TypeError: If the stored value is neither None, a dict, nor a ToolApprovalState.
    """
    if invocation_session is None:
        return None

    raw_state = invocation_session.state.get(_TOOL_APPROVAL_STATE_KEY)
    if isinstance(raw_state, dict):
        return cast(dict[str, Any], raw_state)

    # Function-scope import, needed only for the isinstance check below.
    from ._harness._tool_approval import ToolApprovalState

    if isinstance(raw_state, ToolApprovalState):
        # Replace the typed object with its dict form so later reads hit the fast path.
        as_dict = raw_state.to_dict(exclude={"type"})
        invocation_session.state[_TOOL_APPROVAL_STATE_KEY] = as_dict
        return as_dict

    if raw_state is None:
        fresh_state: dict[str, Any] = {}
        invocation_session.state[_TOOL_APPROVAL_STATE_KEY] = fresh_state
        return fresh_state

    raise TypeError(
        f"Session state for {_TOOL_APPROVAL_STATE_KEY!r} must be a dict or ToolApprovalState, "
        f"got {type(raw_state).__name__}."
    )
|
||||
|
||||
|
||||
def _content_from_state(value: Any) -> Content | None:
    """Turn a value read from session state back into a ``Content`` item.

    Args:
        value: Either a ``Content`` instance, a mapping accepted by
            ``Content.from_dict``, or anything else.

    Returns:
        ``value`` itself when it is already a ``Content``, the result of
        ``Content.from_dict`` when it is a mapping, otherwise None.
    """
    from ._types import Content

    if isinstance(value, Content):
        return value
    if not isinstance(value, Mapping):
        return None
    return Content.from_dict(cast(Mapping[str, Any], value))
|
||||
|
||||
|
||||
def _store_already_approved_approval_requests(
    invocation_session: AgentSession | None,
    visible_approval_requests: Sequence[Content],
    already_approved_requests: Sequence[Content],
) -> None:
    """Record hidden, already-approved requests under the visible approvals that resume them.

    A new group is appended to the list kept under
    ``_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY`` in the tool-approval state.
    Each group pairs the ids of the visible requests with the serialized hidden
    requests. Nothing is stored when there are no hidden requests, no state bag
    is available, or none of the visible requests carries an id.

    Args:
        invocation_session: Session holding the tool-approval state, or None.
        visible_approval_requests: Requests returned to the caller for approval.
        already_approved_requests: Requests withheld from the caller.
    """
    if not already_approved_requests:
        return

    approval_state = _get_tool_approval_state(invocation_session)
    if approval_state is None:
        return

    resume_ids = []
    for visible in visible_approval_requests:
        if visible.id:
            resume_ids.append(visible.id)
    if not resume_ids:
        return

    prior_groups = approval_state.get(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY)
    # Copy any existing list so the stored value is replaced rather than mutated in place.
    groups: list[Any] = [*cast(Iterable[Any], prior_groups)] if isinstance(prior_groups, list) else []
    new_group = {
        "approval_request_ids": resume_ids,
        "approval_requests": [hidden.to_dict() for hidden in already_approved_requests],
    }
    groups.append(new_group)
    approval_state[_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY] = groups
|
||||
|
||||
|
||||
def _pop_already_approved_approval_responses(
    invocation_session: AgentSession | None,
    approval_response_ids: set[str],
) -> list[Content]:
    """Remove and return approved responses for hidden requests tied to the answered approvals.

    Each stored group is consumed when at least one of its
    ``approval_request_ids`` appears in ``approval_response_ids``; its stored
    requests are restored and converted into ``approved=True`` responses.
    Groups with no matching id are kept. Entries that are not mappings are
    discarded. The state key is removed entirely once no groups remain.

    Args:
        invocation_session: Session holding the tool-approval state, or None.
        approval_response_ids: Ids of the visible approval requests being answered.

    Returns:
        Approval responses for the hidden requests of every matching group;
        empty when there is nothing to replay.
    """
    if not approval_response_ids:
        return []

    approval_state = _get_tool_approval_state(invocation_session)
    if approval_state is None:
        return []

    stored_groups = approval_state.get(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY, [])
    if not isinstance(stored_groups, list):
        return []

    replayed: list[Content] = []
    kept_groups: list[Any] = []
    for entry in list(cast(Iterable[Any], stored_groups)):
        if not isinstance(entry, Mapping):
            continue
        group = cast(Mapping[str, Any], entry)

        ids_field = group.get("approval_request_ids")
        id_items: Iterable[Any] = cast(Iterable[Any], ids_field) if isinstance(ids_field, list) else ()
        trigger_ids = {str(item) for item in id_items}
        if trigger_ids.isdisjoint(approval_response_ids):
            # Not answered in this turn: leave the group pending.
            kept_groups.append(entry)
            continue

        hidden_requests = group.get("approval_requests")
        if not isinstance(hidden_requests, list):
            continue
        for stored_request in list(cast(Iterable[Any], hidden_requests)):
            restored = _content_from_state(stored_request)
            if restored is not None and restored.type == "function_approval_request":
                replayed.append(restored.to_function_approval_response(approved=True))

    if kept_groups:
        approval_state[_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY] = kept_groups
    else:
        approval_state.pop(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY, None)
    return replayed
|
||||
|
||||
|
||||
def _collect_approval_responses(
|
||||
messages: list[Message],
|
||||
) -> dict[str, Content]:
|
||||
@@ -2157,8 +2286,24 @@ async def _process_function_requests(
|
||||
errors_in_a_row: int,
|
||||
max_errors: int,
|
||||
execute_function_calls: Callable[..., Awaitable[tuple[list[Content], bool, bool]]],
|
||||
invocation_session: AgentSession | None = None,
|
||||
) -> FunctionRequestResult:
|
||||
from ._types import Message
|
||||
|
||||
if prepped_messages is not None:
|
||||
explicit_approval_response_ids = {
|
||||
content.id
|
||||
for message in prepped_messages
|
||||
if isinstance(message, Message)
|
||||
for content in message.contents
|
||||
if content.type == "function_approval_response" and content.id
|
||||
}
|
||||
already_approved_responses = _pop_already_approved_approval_responses(
|
||||
invocation_session,
|
||||
explicit_approval_response_ids,
|
||||
)
|
||||
if already_approved_responses:
|
||||
prepped_messages.append(Message(role="user", contents=already_approved_responses))
|
||||
fcc_todo = _collect_approval_responses(prepped_messages)
|
||||
if not fcc_todo:
|
||||
fcc_todo = {}
|
||||
@@ -2362,6 +2507,10 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
function_middleware_pipeline = self._get_function_middleware_pipeline(runtime_middleware["function"])
|
||||
if runtime_middleware["chat"]:
|
||||
effective_client_kwargs["middleware"] = runtime_middleware["chat"]
|
||||
raw_budget_state = effective_client_kwargs.pop(_FUNCTION_INVOCATION_BUDGET_STATE_KEY, None)
|
||||
budget_state: dict[str, Any] = (
|
||||
cast(dict[str, Any], raw_budget_state) if isinstance(raw_budget_state, dict) else {}
|
||||
)
|
||||
max_errors = self.function_invocation_configuration.get(
|
||||
"max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST
|
||||
)
|
||||
@@ -2411,7 +2560,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
nonlocal mutable_options
|
||||
nonlocal filtered_kwargs
|
||||
errors_in_a_row: int = 0
|
||||
total_function_calls: int = 0
|
||||
total_function_calls = int(budget_state.get("total_function_calls", 0) or 0)
|
||||
max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls")
|
||||
prepped_messages = list(messages)
|
||||
fcc_messages: list[Message] = []
|
||||
@@ -2420,7 +2569,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
|
||||
loop_enabled = self.function_invocation_configuration.get("enabled", True)
|
||||
max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS)
|
||||
for attempt_idx in range(max_iterations if loop_enabled else 0):
|
||||
attempt_start = int(budget_state.get("attempt_count", 0) or 0)
|
||||
for attempt_idx in range(attempt_start, max_iterations if loop_enabled else 0):
|
||||
budget_state["attempt_count"] = attempt_idx + 1
|
||||
approval_result = await _process_function_requests(
|
||||
response=None,
|
||||
prepped_messages=prepped_messages,
|
||||
@@ -2430,12 +2581,21 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
errors_in_a_row=errors_in_a_row,
|
||||
max_errors=max_errors,
|
||||
execute_function_calls=execute_function_calls,
|
||||
invocation_session=invocation_session,
|
||||
)
|
||||
if approval_result.get("action") == "stop":
|
||||
response = ChatResponse(messages=prepped_messages)
|
||||
break
|
||||
errors_in_a_row = approval_result.get("errors_in_a_row", errors_in_a_row)
|
||||
total_function_calls += approval_result.get("function_call_count", 0)
|
||||
budget_state["total_function_calls"] = total_function_calls
|
||||
if max_function_calls is not None and total_function_calls >= max_function_calls:
|
||||
logger.info(
|
||||
"Maximum function calls reached (%d/%d). Stopping further function calls for this request.",
|
||||
total_function_calls,
|
||||
max_function_calls,
|
||||
)
|
||||
mutable_options["tool_choice"] = "none"
|
||||
|
||||
response = cast(
|
||||
ChatResponse[Any],
|
||||
@@ -2468,11 +2628,13 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
errors_in_a_row=errors_in_a_row,
|
||||
max_errors=max_errors,
|
||||
execute_function_calls=execute_function_calls,
|
||||
invocation_session=invocation_session,
|
||||
)
|
||||
if result.get("action") == "return":
|
||||
response.usage_details = aggregated_usage
|
||||
return _clear_internal_conversation_id(response)
|
||||
total_function_calls += result.get("function_call_count", 0)
|
||||
budget_state["total_function_calls"] = total_function_calls
|
||||
if result.get("action") == "stop":
|
||||
# Error threshold reached: force a final non-tool turn so
|
||||
# function_call_output items are submitted before exit.
|
||||
@@ -2549,7 +2711,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
nonlocal mutable_options
|
||||
nonlocal stream_result_hooks
|
||||
errors_in_a_row: int = 0
|
||||
total_function_calls: int = 0
|
||||
total_function_calls = int(budget_state.get("total_function_calls", 0) or 0)
|
||||
max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls")
|
||||
prepped_messages = list(messages)
|
||||
fcc_messages: list[Message] = []
|
||||
@@ -2557,7 +2719,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
|
||||
loop_enabled = self.function_invocation_configuration.get("enabled", True)
|
||||
max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS)
|
||||
for attempt_idx in range(max_iterations if loop_enabled else 0):
|
||||
attempt_start = int(budget_state.get("attempt_count", 0) or 0)
|
||||
for attempt_idx in range(attempt_start, max_iterations if loop_enabled else 0):
|
||||
budget_state["attempt_count"] = attempt_idx + 1
|
||||
approval_result = await _process_function_requests(
|
||||
response=None,
|
||||
prepped_messages=prepped_messages,
|
||||
@@ -2567,9 +2731,18 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
errors_in_a_row=errors_in_a_row,
|
||||
max_errors=max_errors,
|
||||
execute_function_calls=execute_function_calls,
|
||||
invocation_session=invocation_session,
|
||||
)
|
||||
errors_in_a_row = approval_result.get("errors_in_a_row", errors_in_a_row)
|
||||
total_function_calls += approval_result.get("function_call_count", 0)
|
||||
budget_state["total_function_calls"] = total_function_calls
|
||||
if max_function_calls is not None and total_function_calls >= max_function_calls:
|
||||
logger.info(
|
||||
"Maximum function calls reached (%d/%d). Stopping further function calls for this request.",
|
||||
total_function_calls,
|
||||
max_function_calls,
|
||||
)
|
||||
mutable_options["tool_choice"] = "none"
|
||||
if approval_result.get("action") == "stop":
|
||||
mutable_options["tool_choice"] = "none"
|
||||
return
|
||||
@@ -2622,9 +2795,11 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
errors_in_a_row=errors_in_a_row,
|
||||
max_errors=max_errors,
|
||||
execute_function_calls=execute_function_calls,
|
||||
invocation_session=invocation_session,
|
||||
)
|
||||
errors_in_a_row = result.get("errors_in_a_row", errors_in_a_row)
|
||||
total_function_calls += result.get("function_call_count", 0)
|
||||
budget_state["total_function_calls"] = total_function_calls
|
||||
if role := result.get("update_role"):
|
||||
yield ChatResponseUpdate(
|
||||
contents=result.get("function_call_results") or [],
|
||||
|
||||
@@ -0,0 +1,817 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from agent_framework import (
|
||||
DEFAULT_TOOL_APPROVAL_SOURCE_ID,
|
||||
Agent,
|
||||
AgentSession,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
Message,
|
||||
SupportsChatGetResponse,
|
||||
ToolApprovalMiddleware,
|
||||
ToolApprovalState,
|
||||
create_always_approve_tool_response,
|
||||
create_always_approve_tool_with_arguments_response,
|
||||
tool,
|
||||
)
|
||||
|
||||
|
||||
def _approval_requests(messages: list[Message]) -> list[Content]:
    """Gather every ``function_approval_request`` content item from the messages, in order."""
    found: list[Content] = []
    for message in messages:
        found.extend(item for item in message.contents if item.type == "function_approval_request")
    return found
|
||||
|
||||
|
||||
async def test_mixed_batch_hides_already_approved_request_until_approval_replay(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """With a session, a mixed batch shows only the approval-gated call and runs both on approval.

    Turn 1 scripts one ``never_require`` call and one ``always_require`` call;
    only the latter may surface as an approval request and no tool may run.
    Turn 2 approves the visible request; both tools must then have run once.
    """
    lookup_runs = 0
    comment_runs = 0

    @tool(name="lookup_work_items", approval_mode="never_require")
    def lookup_work_items(query: str) -> str:
        nonlocal lookup_runs
        lookup_runs += 1
        return f"found {query}"

    @tool(name="add_comment", approval_mode="always_require")
    def add_comment(comment: str) -> str:
        nonlocal comment_runs
        comment_runs += 1
        return f"added {comment}"

    agent = Agent(client=chat_client_base, tools=[lookup_work_items, add_comment])
    session = AgentSession(session_id="approval-session")
    scripted_calls = [
        Content.from_function_call(
            call_id="call_lookup",
            name="lookup_work_items",
            arguments='{"query": "mine"}',
        ),
        Content.from_function_call(
            call_id="call_comment",
            name="add_comment",
            arguments='{"comment": "done"}',
        ),
    ]
    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=scripted_calls))]

    first_response = await agent.run("update work item", session=session)

    pending = _approval_requests(first_response.messages)
    assert [item.function_call.name for item in pending] == ["add_comment"]
    assert lookup_runs == 0
    assert comment_runs == 0

    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["complete"]))]
    approval = pending[0].to_function_approval_response(approved=True)
    second_response = await agent.run(approval, session=session)

    assert second_response.text == "complete"
    assert lookup_runs == 1
    assert comment_runs == 1
|
||||
|
||||
|
||||
async def test_mixed_batch_accepts_restored_tool_approval_state(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """The mixed-batch flow must work when the session already holds a ``ToolApprovalState`` object."""
    safe_runs = 0
    risky_runs = 0

    @tool(name="safe_read", approval_mode="never_require")
    def safe_read() -> str:
        nonlocal safe_runs
        safe_runs += 1
        return "safe"

    @tool(name="risky_write", approval_mode="always_require")
    def risky_write() -> str:
        nonlocal risky_runs
        risky_runs += 1
        return "risky"

    agent = Agent(client=chat_client_base, tools=[safe_read, risky_write])
    session = AgentSession(session_id="restored-state-session")
    # Seed the session with the typed state object instead of a plain dict.
    session.state[DEFAULT_TOOL_APPROVAL_SOURCE_ID] = ToolApprovalState()
    scripted_calls = [
        Content.from_function_call(call_id="call_safe", name="safe_read", arguments="{}"),
        Content.from_function_call(call_id="call_risky", name="risky_write", arguments="{}"),
    ]
    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=scripted_calls))]

    first_response = await agent.run("read and write", session=session)
    pending = _approval_requests(first_response.messages)

    assert [item.function_call.name for item in pending] == ["risky_write"]
    assert safe_runs == 0
    assert risky_runs == 0

    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
    approval = pending[0].to_function_approval_response(approved=True)
    final_response = await agent.run(approval, session=session)

    assert final_response.text == "done"
    assert safe_runs == 1
    assert risky_runs == 1
|
||||
|
||||
|
||||
async def test_hidden_mixed_batch_requests_do_not_replay_on_unrelated_turn(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """Hidden requests stay dormant during an unrelated turn and replay only on the approval response."""
    safe_runs = 0
    risky_runs = 0

    @tool(name="safe_lookup", approval_mode="never_require")
    def safe_lookup() -> str:
        nonlocal safe_runs
        safe_runs += 1
        return "safe"

    @tool(name="risky_update", approval_mode="always_require")
    def risky_update() -> str:
        nonlocal risky_runs
        risky_runs += 1
        return "risky"

    agent = Agent(client=chat_client_base, tools=[safe_lookup, risky_update])
    session = AgentSession(session_id="stale-hidden-session")
    scripted_calls = [
        Content.from_function_call(call_id="call_safe", name="safe_lookup", arguments="{}"),
        Content.from_function_call(call_id="call_risky", name="risky_update", arguments="{}"),
    ]
    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=scripted_calls))]

    first_response = await agent.run("lookup and update", session=session)
    visible_request = _approval_requests(first_response.messages)[0]

    # A plain text turn carries no approval response, so nothing may execute.
    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["unrelated"]))]
    unrelated_response = await agent.run("never mind, answer something else", session=session)

    assert unrelated_response.text == "unrelated"
    assert safe_runs == 0
    assert risky_runs == 0

    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
    approval = visible_request.to_function_approval_response(approved=True)
    final_response = await agent.run(approval, session=session)

    assert final_response.text == "done"
    assert safe_runs == 1
    assert risky_runs == 1
|
||||
|
||||
|
||||
async def test_hidden_mixed_batch_requests_replay_only_for_matching_visible_approval(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """Approving batch B must run only batch B's tools; abandoned batch A stays unexecuted."""
    safe_a_runs = 0
    safe_b_runs = 0
    risky_a_runs = 0
    risky_b_runs = 0

    @tool(name="safe_a", approval_mode="never_require")
    def safe_a() -> str:
        nonlocal safe_a_runs
        safe_a_runs += 1
        return "safe-a"

    @tool(name="safe_b", approval_mode="never_require")
    def safe_b() -> str:
        nonlocal safe_b_runs
        safe_b_runs += 1
        return "safe-b"

    @tool(name="risky_a", approval_mode="always_require")
    def risky_a() -> str:
        nonlocal risky_a_runs
        risky_a_runs += 1
        return "risky-a"

    @tool(name="risky_b", approval_mode="always_require")
    def risky_b() -> str:
        nonlocal risky_b_runs
        risky_b_runs += 1
        return "risky-b"

    agent = Agent(client=chat_client_base, tools=[safe_a, safe_b, risky_a, risky_b])
    session = AgentSession(session_id="grouped-hidden-session")
    batch_a_calls = [
        Content.from_function_call(call_id="call_safe_a", name="safe_a", arguments="{}"),
        Content.from_function_call(call_id="call_risky_a", name="risky_a", arguments="{}"),
    ]
    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=batch_a_calls))]

    first_response = await agent.run("batch a", session=session)
    batch_a_names = [item.function_call.name for item in _approval_requests(first_response.messages)]
    assert batch_a_names == ["risky_a"]

    # Batch A is never answered; a second mixed batch is started on the same session.
    batch_b_calls = [
        Content.from_function_call(call_id="call_safe_b", name="safe_b", arguments="{}"),
        Content.from_function_call(call_id="call_risky_b", name="risky_b", arguments="{}"),
    ]
    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=batch_b_calls))]

    second_response = await agent.run("batch b", session=session)
    batch_b_request = _approval_requests(second_response.messages)[0]

    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
    approval = batch_b_request.to_function_approval_response(approved=True)
    final_response = await agent.run(approval, session=session)

    assert final_response.text == "done"
    assert safe_a_runs == 0
    assert risky_a_runs == 0
    assert safe_b_runs == 1
    assert risky_b_runs == 1
|
||||
|
||||
|
||||
async def test_tool_approval_middleware_queues_multiple_approval_requests(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """With the middleware enabled, two pending approvals are surfaced one per turn.

    Neither tool may run until both requests have been approved; after the
    second approval each tool must have run exactly once.
    """
    first_runs = 0
    second_runs = 0

    @tool(name="first_tool", approval_mode="always_require")
    def first_tool() -> str:
        nonlocal first_runs
        first_runs += 1
        return "first"

    @tool(name="second_tool", approval_mode="always_require")
    def second_tool() -> str:
        nonlocal second_runs
        second_runs += 1
        return "second"

    agent = Agent(
        client=chat_client_base,
        tools=[first_tool, second_tool],
        middleware=[ToolApprovalMiddleware()],
    )
    session = AgentSession(session_id="queue-session")
    scripted_calls = [
        Content.from_function_call(call_id="call_first", name="first_tool", arguments="{}"),
        Content.from_function_call(call_id="call_second", name="second_tool", arguments="{}"),
    ]
    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=scripted_calls))]

    first_response = await agent.run("call both", session=session)

    first_pending = _approval_requests(first_response.messages)
    assert [item.function_call.name for item in first_pending] == ["first_tool"]
    assert first_runs == 0
    assert second_runs == 0

    # No new scripted model response is queued for this turn.
    first_approval = first_pending[0].to_function_approval_response(approved=True)
    second_response = await agent.run(first_approval, session=session)

    second_pending = _approval_requests(second_response.messages)
    assert [item.function_call.name for item in second_pending] == ["second_tool"]
    assert first_runs == 0
    assert second_runs == 0

    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
    second_approval = second_pending[0].to_function_approval_response(approved=True)
    final_response = await agent.run(second_approval, session=session)

    assert final_response.text == "done"
    assert first_runs == 1
    assert second_runs == 1
|
||||
|
||||
|
||||
async def test_tool_approval_middleware_preserves_hidden_mixed_batch_requests(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """With the middleware enabled, the hidden ``never_require`` call still runs after approval."""
    lookup_runs = 0
    write_runs = 0

    @tool(name="lookup_records", approval_mode="never_require")
    def lookup_records() -> str:
        nonlocal lookup_runs
        lookup_runs += 1
        return "records"

    @tool(name="write_record", approval_mode="always_require")
    def write_record() -> str:
        nonlocal write_runs
        write_runs += 1
        return "written"

    agent = Agent(
        client=chat_client_base,
        tools=[lookup_records, write_record],
        middleware=[ToolApprovalMiddleware()],
    )
    session = AgentSession(session_id="mixed-middleware-session")
    scripted_calls = [
        Content.from_function_call(call_id="call_lookup", name="lookup_records", arguments="{}"),
        Content.from_function_call(call_id="call_write", name="write_record", arguments="{}"),
    ]
    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=scripted_calls))]

    first_response = await agent.run("lookup and write", session=session)
    visible_request = _approval_requests(first_response.messages)[0]

    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
    approval = visible_request.to_function_approval_response(approved=True)
    second_response = await agent.run(approval, session=session)

    assert second_response.text == "done"
    assert lookup_runs == 1
    assert write_runs == 1
|
||||
|
||||
|
||||
async def test_tool_approval_middleware_auto_approval_rule_receives_function_call(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """An async auto-approval rule is given function-call content and approves only matching calls.

    The rule records the ``(type, name)`` of every item it sees and approves
    ``auto_write``. Only ``manual_write`` may surface for manual approval; once
    that is approved, both tools must have run exactly once.
    """
    auto_runs = 0
    manual_runs = 0
    observed: list[tuple[str, str | None]] = []

    @tool(name="auto_write", approval_mode="always_require")
    def auto_write() -> str:
        nonlocal auto_runs
        auto_runs += 1
        return "auto"

    @tool(name="manual_write", approval_mode="always_require")
    def manual_write() -> str:
        nonlocal manual_runs
        manual_runs += 1
        return "manual"

    async def auto_approve_auto_write(function_call: Content) -> bool:
        observed.append((function_call.type, function_call.name))
        return function_call.name == "auto_write"

    agent = Agent(
        client=chat_client_base,
        tools=[auto_write, manual_write],
        middleware=[ToolApprovalMiddleware(auto_approval_rules=[auto_approve_auto_write])],
    )
    session = AgentSession(session_id="heuristic-session")
    scripted_calls = [
        Content.from_function_call(call_id="call_auto", name="auto_write", arguments="{}"),
        Content.from_function_call(call_id="call_manual", name="manual_write", arguments="{}"),
    ]
    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=scripted_calls))]

    first_response = await agent.run("write both", session=session)

    pending = _approval_requests(first_response.messages)
    assert [item.function_call.name for item in pending] == ["manual_write"]
    assert observed == [("function_call", "auto_write"), ("function_call", "manual_write")]
    assert auto_runs == 0
    assert manual_runs == 0

    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["done"]))]
    approval = pending[0].to_function_approval_response(approved=True)
    final_response = await agent.run(approval, session=session)

    assert final_response.text == "done"
    assert auto_runs == 1
    assert manual_runs == 1
|
||||
|
||||
|
||||
async def test_tool_approval_middleware_auto_approved_loops_share_function_call_budget(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """Re-entering the loop after an auto-approval must keep counting against ``max_function_calls``.

    The budget is set to one call and the model is scripted to request the same
    auto-approved tool twice; only the first call may execute.
    """
    executions = 0

    @tool(name="budgeted_tool", approval_mode="always_require")
    def budgeted_tool(value: str) -> str:
        nonlocal executions
        executions += 1
        return value

    def auto_approve_budgeted_tool(function_call: Content) -> bool:
        return function_call.name == "budgeted_tool"

    chat_client_base.function_invocation_configuration["max_function_calls"] = 1  # type: ignore[attr-defined]
    agent = Agent(
        client=chat_client_base,
        tools=[budgeted_tool],
        middleware=[ToolApprovalMiddleware(auto_approval_rules=[auto_approve_budgeted_tool])],
    )
    session = AgentSession(session_id="shared-budget-session")
    first_turn = Message(
        role="assistant",
        contents=[
            Content.from_function_call(
                call_id="call_first",
                name="budgeted_tool",
                arguments='{"value": "first"}',
            )
        ],
    )
    second_turn = Message(
        role="assistant",
        contents=[
            Content.from_function_call(
                call_id="call_second",
                name="budgeted_tool",
                arguments='{"value": "second"}',
            )
        ],
    )
    chat_client_base.run_responses = [ChatResponse(messages=first_turn), ChatResponse(messages=second_turn)]

    response = await agent.run("call repeatedly", session=session)

    assert response.text == "I broke out of the function invocation loop..."
    assert executions == 1
|
||||
|
||||
|
||||
async def test_tool_approval_middleware_queues_streamed_approval_requests(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """Check that streamed runs hand out queued approval requests one run at a time."""
    invocation_count = 0

    @tool(name="first_streamed_tool", approval_mode="always_require")
    def first_streamed_tool() -> str:
        nonlocal invocation_count
        invocation_count += 1
        return "first"

    @tool(name="second_streamed_tool", approval_mode="always_require")
    def second_streamed_tool() -> str:
        nonlocal invocation_count
        invocation_count += 1
        return "second"

    def call_update(call_id: str, tool_name: str) -> ChatResponseUpdate:
        # A streamed assistant update holding one argument-less function call.
        call = Content.from_function_call(call_id=call_id, name=tool_name, arguments="{}")
        return ChatResponseUpdate(contents=[call], role="assistant")

    agent = Agent(
        client=chat_client_base,
        tools=[first_streamed_tool, second_streamed_tool],
        middleware=[ToolApprovalMiddleware()],
    )
    session = AgentSession(session_id="stream-queue-session")
    # A single scripted stream requests both tools back to back.
    chat_client_base.streaming_responses = [
        [
            call_update("call_first", "first_streamed_tool"),
            call_update("call_second", "second_streamed_tool"),
        ]
    ]

    # Run 1: only the first tool's approval request is surfaced and nothing has executed.
    first_stream = agent.run("call both", stream=True, session=session)
    first_updates = [update async for update in first_stream]
    first_requests: list[Content] = []
    for update in first_updates:
        first_requests.extend(update.user_input_requests)
    assert [request.function_call.name for request in first_requests] == ["first_streamed_tool"]
    assert invocation_count == 0

    # Run 2: approving the first request surfaces the second; the counter is still zero.
    second_stream = agent.run(
        first_requests[0].to_function_approval_response(approved=True),
        stream=True,
        session=session,
    )
    second_updates = [update async for update in second_stream]
    second_requests: list[Content] = []
    for update in second_updates:
        second_requests.extend(update.user_input_requests)
    assert [request.function_call.name for request in second_requests] == ["second_streamed_tool"]
    assert invocation_count == 0

    # Run 3: with both approvals given, the scripted "done" stream completes the run.
    chat_client_base.streaming_responses = [
        [ChatResponseUpdate(contents=[Content.from_text("done")], role="assistant")]
    ]
    final_stream = agent.run(
        second_requests[0].to_function_approval_response(approved=True),
        stream=True,
        session=session,
    )
    final_updates = [update async for update in final_stream]
    final_response = await final_stream.get_final_response()

    assert final_updates[-1].text == "done"
    assert final_response.text == "done"
    assert invocation_count == 2
|
||||
|
||||
|
||||
async def test_tool_approval_middleware_always_approve_tool_rule(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """Check that an always-approve response installs a standing approval for the whole tool."""
    invocation_count = 0

    @tool(name="dangerous_tool", approval_mode="always_require")
    def dangerous_tool(value: str) -> str:
        nonlocal invocation_count
        invocation_count += 1
        return value

    def tool_call_response(call_id: str, arguments: str) -> ChatResponse:
        # Assistant turn that requests dangerous_tool with the given JSON arguments.
        call = Content.from_function_call(call_id=call_id, name="dangerous_tool", arguments=arguments)
        return ChatResponse(messages=Message(role="assistant", contents=[call]))

    def text_response(text: str) -> ChatResponse:
        # Plain assistant text turn.
        return ChatResponse(messages=Message(role="assistant", contents=[text]))

    agent = Agent(
        client=chat_client_base,
        tools=[dangerous_tool],
        middleware=[ToolApprovalMiddleware()],
    )
    session = AgentSession(session_id="standing-rule-session")

    # First call: an approval request comes back, answered with "always approve this tool".
    chat_client_base.run_responses = [tool_call_response("call_initial", '{"value": "one"}')]
    first_response = await agent.run("call once", session=session)
    first_request = _approval_requests(first_response.messages)[0]

    chat_client_base.run_responses = [text_response("first done")]
    await agent.run(create_always_approve_tool_response(first_request), session=session)

    assert invocation_count == 1

    # Second call uses different arguments and still runs to completion in a single run.
    chat_client_base.run_responses = [
        tool_call_response("call_auto", '{"value": "two"}'),
        text_response("second done"),
    ]
    second_response = await agent.run("call again", session=session)

    assert second_response.text == "second done"
    assert invocation_count == 2
|
||||
|
||||
|
||||
async def test_tool_approval_middleware_standing_rules_include_hosted_server_boundary(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """Check that a standing rule for a hosted tool applies only to calls with the same server_label."""
    invocation_count = 0

    @tool(name="hosted_tool", approval_mode="always_require")
    def hosted_tool() -> str:
        nonlocal invocation_count
        invocation_count += 1
        return "hosted"

    def hosted_call_response(call_id: str, server_label: str) -> ChatResponse:
        # Assistant turn with one hosted_tool call tagged with the given server_label.
        call = Content.from_function_call(
            call_id=call_id,
            name="hosted_tool",
            arguments="{}",
            additional_properties={"server_label": server_label},
        )
        return ChatResponse(messages=Message(role="assistant", contents=[call]))

    def text_response(text: str) -> ChatResponse:
        # Plain assistant text turn.
        return ChatResponse(messages=Message(role="assistant", contents=[text]))

    agent = Agent(
        client=chat_client_base,
        tools=[hosted_tool],
        middleware=[ToolApprovalMiddleware()],
    )
    session = AgentSession(session_id="hosted-boundary-session")

    # server-a: the first call needs approval and is answered with "always approve this tool".
    chat_client_base.run_responses = [hosted_call_response("call_initial", "server-a")]
    first_response = await agent.run("call hosted a", session=session)
    first_request = _approval_requests(first_response.messages)[0]

    chat_client_base.run_responses = [text_response("server a done")]
    await agent.run(create_always_approve_tool_response(first_request), session=session)

    # The local function body is asserted to stay un-invoked throughout this test.
    assert invocation_count == 0

    # server-a again: the run completes without a new approval request.
    chat_client_base.run_responses = [
        hosted_call_response("call_same_server", "server-a"),
        text_response("same server done"),
    ]
    same_server_response = await agent.run("call hosted a again", session=session)

    assert same_server_response.text == "same server done"
    assert _approval_requests(same_server_response.messages) == []
    assert invocation_count == 0

    # server-b: the same tool name under a different label must ask for approval again.
    chat_client_base.run_responses = [hosted_call_response("call_other_server", "server-b")]
    other_server_response = await agent.run("call hosted b", session=session)

    pending = _approval_requests(other_server_response.messages)
    requested_labels = [request.function_call.additional_properties["server_label"] for request in pending]
    assert requested_labels == ["server-b"]
    assert invocation_count == 0
|
||||
|
||||
|
||||
async def test_tool_approval_middleware_always_approve_tool_with_arguments_rule(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """Check that an argument-scoped always-approve rule only covers calls with identical arguments."""
    invocation_count = 0

    @tool(name="argument_scoped_tool", approval_mode="always_require")
    def argument_scoped_tool(value: str) -> str:
        nonlocal invocation_count
        invocation_count += 1
        return value

    def tool_call_response(call_id: str, arguments: str) -> ChatResponse:
        # Assistant turn that requests argument_scoped_tool with the given JSON arguments.
        call = Content.from_function_call(call_id=call_id, name="argument_scoped_tool", arguments=arguments)
        return ChatResponse(messages=Message(role="assistant", contents=[call]))

    def text_response(text: str) -> ChatResponse:
        # Plain assistant text turn.
        return ChatResponse(messages=Message(role="assistant", contents=[text]))

    agent = Agent(
        client=chat_client_base,
        tools=[argument_scoped_tool],
        middleware=[ToolApprovalMiddleware()],
    )
    session = AgentSession(session_id="argument-rule-session")

    # Initial call: approve it with the argument-scoped "always approve" response.
    chat_client_base.run_responses = [tool_call_response("call_initial", '{"value": "same"}')]
    first_response = await agent.run("call with same", session=session)
    first_request = _approval_requests(first_response.messages)[0]

    chat_client_base.run_responses = [text_response("first done")]
    await agent.run(create_always_approve_tool_with_arguments_response(first_request), session=session)

    assert invocation_count == 1

    # Same arguments again: the run finishes and the tool executes a second time.
    chat_client_base.run_responses = [
        tool_call_response("call_same", '{"value": "same"}'),
        text_response("same done"),
    ]
    second_response = await agent.run("call with same again", session=session)

    assert second_response.text == "same done"
    assert invocation_count == 2

    # Different arguments: a new approval request is returned and the tool does not run.
    chat_client_base.run_responses = [tool_call_response("call_different", '{"value": "different"}')]
    third_response = await agent.run("call with different args", session=session)

    pending = _approval_requests(third_response.messages)
    assert [request.function_call.arguments for request in pending] == ['{"value": "different"}']
    assert invocation_count == 2
|
||||
|
||||
|
||||
async def test_tool_approval_middleware_empty_arguments_rule_is_not_tool_wide(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """Check that an argument-scoped approval of a no-argument call does not cover calls with arguments."""
    invocation_count = 0

    @tool(name="optional_args_tool", approval_mode="always_require")
    def optional_args_tool(value: str = "default") -> str:
        nonlocal invocation_count
        invocation_count += 1
        return value

    def tool_call_response(call_id: str, arguments: str) -> ChatResponse:
        # Assistant turn that requests optional_args_tool with the given JSON arguments.
        call = Content.from_function_call(call_id=call_id, name="optional_args_tool", arguments=arguments)
        return ChatResponse(messages=Message(role="assistant", contents=[call]))

    agent = Agent(
        client=chat_client_base,
        tools=[optional_args_tool],
        middleware=[ToolApprovalMiddleware()],
    )
    session = AgentSession(session_id="empty-arguments-rule-session")

    # Call with "{}" and approve it through the argument-scoped always-approve response.
    chat_client_base.run_responses = [tool_call_response("call_empty", "{}")]
    first_response = await agent.run("call without args", session=session)
    first_request = _approval_requests(first_response.messages)[0]

    chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=["empty done"]))]
    await agent.run(create_always_approve_tool_with_arguments_response(first_request), session=session)

    assert invocation_count == 1

    # A call that does pass arguments must come back as a fresh approval request.
    chat_client_base.run_responses = [tool_call_response("call_non_empty", '{"value": "custom"}')]
    second_response = await agent.run("call with args", session=session)

    pending = _approval_requests(second_response.messages)
    assert [request.function_call.arguments for request in pending] == ['{"value": "custom"}']
    assert invocation_count == 1
|
||||
Reference in New Issue
Block a user