mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Hyperlight: thread-confine sandbox, skip parsing on host callbacks, schema/tool cleanup (#5424)
* improved parsing of tool call results and tweaks * Address PR review: skip_parsing flag, broader registry close, comment fix - FunctionTool.invoke now takes a boolean skip_parsing flag instead of the SKIP_PARSING sentinel; the sentinel is still accepted as result_parser at construction time to opt out of parsing for every call. The two paths are equivalent. - _SandboxRegistry.close now invokes any sandbox close/shutdown hook on the entry's own worker thread (PyO3 unsendable), then shuts the worker down, then cleans up the per-entry temporary directories. - Clarified the _SandboxWorker.shutdown comment to describe the actual ThreadPoolExecutor.shutdown(wait=False, cancel_futures=False) semantics. - Hyperlight host callback uses skip_parsing=True (the new flag). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Drop redundant 'is not SKIP_PARSING' guard that mypy 1.x flags After callable(configured_parser) the sentinel is already excluded; the extra identity check tripped mypy's non-overlapping identity warning. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fixed sandbox working on copy of tool --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
66e02c10e3
commit
58ff4ad3a9
@@ -125,6 +125,7 @@ from ._telemetry import (
|
||||
prepend_agent_framework_to_user_agent,
|
||||
)
|
||||
from ._tools import (
|
||||
SKIP_PARSING,
|
||||
FunctionInvocationConfiguration,
|
||||
FunctionInvocationLayer,
|
||||
FunctionTool,
|
||||
@@ -258,6 +259,7 @@ __all__ = [
|
||||
"GROUP_INDEX_KEY",
|
||||
"GROUP_KIND_KEY",
|
||||
"GROUP_TOKEN_COUNT_KEY",
|
||||
"SKIP_PARSING",
|
||||
"SUMMARIZED_BY_SUMMARY_ID_KEY",
|
||||
"SUMMARY_OF_GROUP_IDS_KEY",
|
||||
"SUMMARY_OF_MESSAGE_IDS_KEY",
|
||||
|
||||
@@ -94,6 +94,33 @@ ApprovalMode: TypeAlias = Literal["always_require", "never_require"]
|
||||
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
|
||||
|
||||
class _SkipParsingSentinel:
|
||||
"""Sentinel signaling that :meth:`FunctionTool.invoke` should return the raw value.
|
||||
|
||||
When passed as ``result_parser`` to :class:`FunctionTool` (or the ``@tool`` decorator),
|
||||
the default :meth:`FunctionTool.parse_result` is bypassed and the wrapped function's
|
||||
return value is returned unchanged from :meth:`FunctionTool.invoke`. Callers may also
|
||||
request the raw value on a per-call basis by passing ``skip_parsing=True`` to
|
||||
:meth:`FunctionTool.invoke`.
|
||||
|
||||
Use the module-level ``SKIP_PARSING`` singleton — do not instantiate this class.
|
||||
"""
|
||||
|
||||
_instance: ClassVar[_SkipParsingSentinel | None] = None
|
||||
|
||||
def __new__(cls) -> _SkipParsingSentinel:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "SKIP_PARSING"
|
||||
|
||||
|
||||
SKIP_PARSING: Final[_SkipParsingSentinel] = _SkipParsingSentinel()
|
||||
"""Sentinel for ``FunctionTool(result_parser=...)`` meaning "do not parse the result"."""
|
||||
|
||||
# region Helpers
|
||||
|
||||
|
||||
@@ -279,7 +306,7 @@ class FunctionTool(SerializationMixin):
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
func: Callable[..., Any] | None = None,
|
||||
input_model: type[BaseModel] | Mapping[str, Any] | None = None,
|
||||
result_parser: Callable[[Any], str | list[Content]] | None = None,
|
||||
result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the FunctionTool.
|
||||
@@ -327,9 +354,11 @@ class FunctionTool(SerializationMixin):
|
||||
result_parser: An optional callable with signature ``Callable[[Any], str]`` that
|
||||
overrides the default result parsing behavior. When provided, this callable
|
||||
is used to convert the raw function return value to a string instead of the
|
||||
built-in :meth:`parse_result` logic. Depending on your function, it may be
|
||||
easiest to just do the serialization directly in the function body rather
|
||||
than providing a custom ``result_parser``.
|
||||
built-in :meth:`parse_result` logic. Pass the :data:`SKIP_PARSING` sentinel
|
||||
instead of a callable to opt out of parsing entirely; in that case
|
||||
:meth:`invoke` returns the wrapped function's raw return value. Depending
|
||||
on your function, it may be easiest to just do the serialization directly
|
||||
in the function body rather than providing a custom ``result_parser``.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
# Core attributes (formerly from BaseTool)
|
||||
@@ -508,31 +537,65 @@ class FunctionTool(SerializationMixin):
|
||||
self.invocation_exception_count += 1
|
||||
raise
|
||||
|
||||
@overload
|
||||
async def invoke(
|
||||
self,
|
||||
*,
|
||||
arguments: BaseModel | Mapping[str, Any] | None = None,
|
||||
context: FunctionInvocationContext | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
skip_parsing: Literal[True],
|
||||
**kwargs: Any,
|
||||
) -> list[Content]:
|
||||
) -> Any: ...
|
||||
|
||||
@overload
|
||||
async def invoke(
|
||||
self,
|
||||
*,
|
||||
arguments: BaseModel | Mapping[str, Any] | None = None,
|
||||
context: FunctionInvocationContext | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
skip_parsing: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> list[Content]: ...
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
*,
|
||||
arguments: BaseModel | Mapping[str, Any] | None = None,
|
||||
context: FunctionInvocationContext | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
skip_parsing: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> list[Content] | Any:
|
||||
"""Run the AI function with the provided arguments as a Pydantic model.
|
||||
|
||||
The raw return value of the wrapped function is automatically parsed into a
|
||||
``list[Content]`` using :meth:`parse_result` or the custom ``result_parser``
|
||||
if one was provided. Every result — text, rich media, or serialized objects —
|
||||
is represented uniformly as Content items.
|
||||
configured on the tool. Every result — text, rich media, or serialized
|
||||
objects — is represented uniformly as Content items.
|
||||
|
||||
Parsing can be skipped in two ways: configure the tool with
|
||||
``result_parser=SKIP_PARSING`` to always skip parsing, or pass
|
||||
``skip_parsing=True`` per call. Either way the wrapped function's raw value
|
||||
is returned. This is intended for callers (e.g. sandboxed runtimes) that
|
||||
consume the value from Python directly and would otherwise undo the
|
||||
``Content`` wrapping.
|
||||
|
||||
Keyword Args:
|
||||
arguments: A mapping or model instance containing the arguments for the function.
|
||||
context: Explicit function invocation context carrying runtime kwargs.
|
||||
tool_call_id: Optional tool call identifier used for telemetry and tracing.
|
||||
skip_parsing: When ``True``, bypass parsing and return the wrapped function's
|
||||
raw value instead of a ``list[Content]``. Defaults to ``False``.
|
||||
kwargs: Direct function argument values. When provided, every keyword
|
||||
must match a declared tool parameter. Runtime data must be passed
|
||||
via ``context``.
|
||||
|
||||
Returns:
|
||||
A list of Content items representing the tool output.
|
||||
``list[Content]`` by default. The raw function return value (``Any``) when
|
||||
``skip_parsing=True`` (or the tool was constructed with
|
||||
``result_parser=SKIP_PARSING``).
|
||||
|
||||
Raises:
|
||||
TypeError: If arguments is not mapping-like or fails schema checks.
|
||||
@@ -544,7 +607,9 @@ class FunctionTool(SerializationMixin):
|
||||
from ._types import Content
|
||||
from .observability import OBSERVABILITY_SETTINGS
|
||||
|
||||
parser = self.result_parser or FunctionTool.parse_result
|
||||
configured_parser = self.result_parser
|
||||
skip_parsing = skip_parsing or configured_parser is SKIP_PARSING
|
||||
parser = configured_parser if callable(configured_parser) else FunctionTool.parse_result
|
||||
|
||||
parameter_names = set(self.parameters().get("properties", {}).keys())
|
||||
direct_argument_kwargs = (
|
||||
@@ -616,6 +681,10 @@ class FunctionTool(SerializationMixin):
|
||||
logger.debug(f"Function arguments: {observable_kwargs}")
|
||||
res = self.__call__(**call_kwargs)
|
||||
result = await res if inspect.isawaitable(res) else res
|
||||
if skip_parsing:
|
||||
logger.info(f"Function {self.name} succeeded.")
|
||||
logger.debug(f"Function result: {type(result).__name__}")
|
||||
return result
|
||||
try:
|
||||
parsed = parser(result)
|
||||
except Exception:
|
||||
@@ -671,6 +740,13 @@ class FunctionTool(SerializationMixin):
|
||||
logger.error(f"Function failed. Error: {exception}")
|
||||
raise
|
||||
else:
|
||||
if skip_parsing:
|
||||
logger.info(f"Function {self.name} succeeded.")
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined]
|
||||
result_str = str(result)
|
||||
span.set_attribute(OtelAttr.TOOL_RESULT, result_str)
|
||||
logger.debug(f"Function result: {result_str}")
|
||||
return result
|
||||
try:
|
||||
parsed = parser(result)
|
||||
except Exception:
|
||||
@@ -1067,7 +1143,7 @@ def tool(
|
||||
max_invocations: int | None = None,
|
||||
max_invocation_exceptions: int | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
result_parser: Callable[[Any], str | list[Content]] | None = None,
|
||||
result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None,
|
||||
) -> FunctionTool: ...
|
||||
|
||||
|
||||
@@ -1083,7 +1159,7 @@ def tool(
|
||||
max_invocations: int | None = None,
|
||||
max_invocation_exceptions: int | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
result_parser: Callable[[Any], str | list[Content]] | None = None,
|
||||
result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None,
|
||||
) -> Callable[[Callable[..., Any]], FunctionTool]: ...
|
||||
|
||||
|
||||
@@ -1098,7 +1174,7 @@ def tool(
|
||||
max_invocations: int | None = None,
|
||||
max_invocation_exceptions: int | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
result_parser: Callable[[Any], str | list[Content]] | None = None,
|
||||
result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None,
|
||||
) -> FunctionTool | Callable[[Callable[..., Any]], FunctionTool]:
|
||||
"""Decorate a function to turn it into a FunctionTool that can be passed to models and executed automatically.
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanE
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework import (
|
||||
SKIP_PARSING,
|
||||
Content,
|
||||
FunctionTool,
|
||||
tool,
|
||||
@@ -1300,4 +1301,165 @@ def test_normalize_tools_flattens_mapping_like_toolbox_with_tools_attr() -> None
|
||||
assert normalized[1] is standalone
|
||||
|
||||
|
||||
# region SKIP_PARSING sentinel & skip_parsing
|
||||
|
||||
|
||||
async def test_invoke_skip_parsing_returns_native_value() -> None:
|
||||
"""invoke(skip_parsing=True) returns the wrapped function's raw value."""
|
||||
|
||||
@tool
|
||||
def get_weather(city: str) -> dict[str, Any]:
|
||||
"""Get the weather."""
|
||||
return {"city": city, "temperature_c": 21.5, "conditions": "partly cloudy"}
|
||||
|
||||
raw = await get_weather.invoke(arguments={"city": "Seattle"}, skip_parsing=True)
|
||||
|
||||
assert isinstance(raw, dict)
|
||||
assert raw == {"city": "Seattle", "temperature_c": 21.5, "conditions": "partly cloudy"}
|
||||
|
||||
|
||||
async def test_invoke_skip_parsing_passes_through_custom_objects() -> None:
|
||||
"""skip_parsing must not call str()/repr() on the result."""
|
||||
|
||||
class Custom: # noqa: B903
|
||||
def __init__(self, value: int) -> None:
|
||||
self.value = value
|
||||
|
||||
@tool
|
||||
def make() -> Custom:
|
||||
"""Make a custom object."""
|
||||
return Custom(42)
|
||||
|
||||
raw = await make.invoke(skip_parsing=True)
|
||||
|
||||
assert isinstance(raw, Custom)
|
||||
assert raw.value == 42
|
||||
|
||||
|
||||
async def test_invoke_skip_parsing_awaits_async_functions() -> None:
|
||||
@tool
|
||||
async def slow(x: int) -> int:
|
||||
"""Async tool."""
|
||||
return x * 2
|
||||
|
||||
raw = await slow.invoke(arguments={"x": 21}, skip_parsing=True)
|
||||
assert raw == 42
|
||||
|
||||
|
||||
async def test_invoke_skip_parsing_bypasses_configured_result_parser() -> None:
|
||||
"""The tool's own result_parser is bypassed when skip_parsing=True is requested."""
|
||||
parser_calls: list[Any] = []
|
||||
|
||||
def parser(value: Any) -> str:
|
||||
parser_calls.append(value)
|
||||
return "PARSED"
|
||||
|
||||
@tool(result_parser=parser)
|
||||
def make_dict() -> dict[str, int]:
|
||||
"""Returns a dict."""
|
||||
return {"a": 1}
|
||||
|
||||
raw = await make_dict.invoke(skip_parsing=True)
|
||||
assert raw == {"a": 1}
|
||||
assert parser_calls == []
|
||||
|
||||
# Sanity: omitting skip_parsing still applies the configured parser.
|
||||
parsed = await make_dict.invoke()
|
||||
assert parsed[0].type == "text"
|
||||
assert parsed[0].text == "PARSED"
|
||||
|
||||
|
||||
async def test_constructor_skip_parsing_sentinel_returns_raw_by_default() -> None:
|
||||
"""Constructing a tool with result_parser=SKIP_PARSING makes invoke return the raw value."""
|
||||
|
||||
@tool(result_parser=SKIP_PARSING)
|
||||
def make_dict() -> dict[str, int]:
|
||||
"""Returns a dict."""
|
||||
return {"a": 1}
|
||||
|
||||
raw = await make_dict.invoke()
|
||||
assert raw == {"a": 1}
|
||||
|
||||
|
||||
async def test_invoke_skip_parsing_validates_arguments() -> None:
|
||||
"""Argument validation is shared with the default path."""
|
||||
|
||||
@tool
|
||||
def adder(x: int, y: int) -> int:
|
||||
"""Add."""
|
||||
return x + y
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
await adder.invoke(arguments={"x": "not-an-int", "y": 1}, skip_parsing=True)
|
||||
|
||||
|
||||
async def test_invoke_skip_parsing_rejects_unexpected_runtime_kwargs() -> None:
|
||||
@tool
|
||||
async def echo(message: str) -> str:
|
||||
"""Echo."""
|
||||
return message
|
||||
|
||||
with pytest.raises(TypeError, match="Unexpected keyword argument"):
|
||||
await echo.invoke(arguments={"message": "hi"}, skip_parsing=True, api_token="secret")
|
||||
|
||||
|
||||
async def test_invoke_skip_parsing_raises_for_declaration_only_tool() -> None:
|
||||
declared = FunctionTool(name="dummy", description="declaration only")
|
||||
|
||||
from agent_framework.exceptions import ToolException
|
||||
|
||||
with pytest.raises(ToolException):
|
||||
await declared.invoke(arguments={}, skip_parsing=True)
|
||||
|
||||
|
||||
async def test_invoke_skip_parsing_records_telemetry(span_exporter: InMemorySpanExporter) -> None:
|
||||
"""skip_parsing participates in OTEL spans and records str(raw) as TOOL_RESULT."""
|
||||
|
||||
@tool(name="raw_tool", description="raw tool")
|
||||
def returns_dict(x: int) -> dict[str, int]:
|
||||
"""Returns a dict."""
|
||||
return {"value": x}
|
||||
|
||||
span_exporter.clear()
|
||||
raw = await returns_dict.invoke(arguments={"x": 5}, tool_call_id="raw_call", skip_parsing=True)
|
||||
|
||||
assert raw == {"value": 5}
|
||||
spans = span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.attributes[OtelAttr.TOOL_NAME] == "raw_tool"
|
||||
assert span.attributes[OtelAttr.TOOL_CALL_ID] == "raw_call"
|
||||
assert span.attributes[OtelAttr.TOOL_RESULT] == "{'value': 5}"
|
||||
|
||||
|
||||
async def test_invoke_default_path_records_parsed_telemetry(
|
||||
span_exporter: InMemorySpanExporter,
|
||||
) -> None:
|
||||
"""Regression: omitting skip_parsing still records the parsed result in telemetry."""
|
||||
|
||||
def parser(value: Any) -> str:
|
||||
return f"parsed:{value}"
|
||||
|
||||
@tool(name="parsed_tool", description="parsed", result_parser=parser)
|
||||
def returns_int() -> int:
|
||||
"""Returns an int."""
|
||||
return 7
|
||||
|
||||
span_exporter.clear()
|
||||
parsed = await returns_int.invoke(tool_call_id="parsed_call")
|
||||
|
||||
assert parsed[0].text == "parsed:7"
|
||||
spans = span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
assert spans[0].attributes[OtelAttr.TOOL_RESULT] == "parsed:7"
|
||||
|
||||
|
||||
def test_skip_parsing_is_singleton() -> None:
|
||||
"""SKIP_PARSING is a singleton; instantiation returns the same object."""
|
||||
from agent_framework._tools import _SkipParsingSentinel
|
||||
|
||||
assert _SkipParsingSentinel() is SKIP_PARSING
|
||||
assert repr(SKIP_PARSING) == "SKIP_PARSING"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -130,3 +130,9 @@ codeact = HyperlightCodeActProvider(
|
||||
- `allowed_domains` accepts a single string target such as `"github.com"` to
|
||||
allow all backend-supported methods, an explicit `(target, method_or_methods)`
|
||||
tuple such as `("github.com", "GET")`, or an `AllowedDomain` named tuple.
|
||||
- Tools registered with the sandbox return their native Python value
|
||||
(`dict`, `list`, primitives, or custom objects) directly to the guest via the
|
||||
Hyperlight FFI. Any `result_parser` configured on a `FunctionTool` is
|
||||
intended for LLM-facing consumers and does not run on the sandbox path —
|
||||
apply formatting inside the tool function itself if you need it for
|
||||
in-sandbox consumers.
|
||||
|
||||
@@ -2,42 +2,45 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import copy
|
||||
import mimetypes
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import suppress
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path, PurePosixPath
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Annotated, Any, Protocol, TypeGuard, cast
|
||||
from typing import Any, Protocol, TypeGuard, TypeVar, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from agent_framework import Content, FunctionTool
|
||||
from agent_framework._tools import ApprovalMode, normalize_tools
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ._instructions import build_codeact_instructions, build_execute_code_description
|
||||
from ._types import AllowedDomain, AllowedDomainInput, FileMount, FileMountHostPath, FileMountInput
|
||||
|
||||
DEFAULT_HYPERLIGHT_BACKEND = "wasm"
|
||||
DEFAULT_HYPERLIGHT_MODULE = "python_guest.path"
|
||||
EXECUTE_CODE_INPUT_DESCRIPTION = "Python code to execute in an isolated Hyperlight sandbox."
|
||||
EXECUTE_CODE_TOOL_DESCRIPTION = "Execute Python in an isolated Hyperlight sandbox."
|
||||
OUTPUT_FILE_RETRY_ATTEMPTS = 10
|
||||
OUTPUT_FILE_RETRY_DELAY_SECONDS = 0.1
|
||||
|
||||
|
||||
class _ExecuteCodeInput(BaseModel):
|
||||
code: Annotated[str, Field(description=EXECUTE_CODE_INPUT_DESCRIPTION)]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _StoredFileMount:
|
||||
host_path: Path
|
||||
mount_path: str
|
||||
EXECUTE_CODE_INPUT_SCHEMA: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"title": "_ExecuteCodeInput",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"title": "Code",
|
||||
"description": "Python code to execute in an isolated Hyperlight sandbox.",
|
||||
},
|
||||
},
|
||||
"required": ["code"],
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@@ -85,13 +88,43 @@ class SandboxRuntime(Protocol):
|
||||
def execute(self, *, config: _RunConfig, code: str) -> list[Content]: ...
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class _SandboxWorker:
|
||||
"""Single-threaded executor that confines all sandbox operations to one OS thread.
|
||||
|
||||
The Hyperlight ``WasmSandbox`` is declared ``unsendable`` in PyO3, meaning it can only be
|
||||
accessed from the OS thread that created it; touching it from any other thread triggers a
|
||||
Rust panic that cannot be caught from Python. Every cached :class:`_SandboxEntry` therefore
|
||||
owns its own ``_SandboxWorker``, and *all* lifecycle and execution calls against the
|
||||
underlying sandbox object must be routed through :meth:`submit`/:meth:`run`.
|
||||
"""
|
||||
|
||||
__slots__ = ("_executor",)
|
||||
|
||||
def __init__(self, *, name: str = "hl-sandbox") -> None:
|
||||
self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=name)
|
||||
|
||||
def submit(self, fn: Callable[..., _T], /, *args: Any, **kwargs: Any) -> Future[_T]:
|
||||
return self._executor.submit(fn, *args, **kwargs)
|
||||
|
||||
def run(self, fn: Callable[..., _T], /, *args: Any, **kwargs: Any) -> _T:
|
||||
return self._executor.submit(fn, *args, **kwargs).result()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
# Do not block on shutdown; stop accepting new tasks, but allow the currently running
|
||||
# task and any already-queued tasks to finish before the worker thread exits.
|
||||
self._executor.shutdown(wait=False, cancel_futures=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SandboxEntry:
|
||||
sandbox: Any
|
||||
snapshot: Any
|
||||
input_dir: TemporaryDirectory[str] | None
|
||||
output_dir: TemporaryDirectory[str] | None
|
||||
lock: threading.RLock
|
||||
worker: _SandboxWorker = field(default_factory=_SandboxWorker)
|
||||
|
||||
|
||||
def _load_sandbox_class() -> type[Any]:
|
||||
@@ -106,10 +139,6 @@ def _load_sandbox_class() -> type[Any]:
|
||||
return Sandbox
|
||||
|
||||
|
||||
def _passthrough_result_parser(result: Any) -> str:
|
||||
return repr(result)
|
||||
|
||||
|
||||
def _collect_tools(*tool_groups: Any) -> list[FunctionTool]:
|
||||
tools_by_name: dict[str, FunctionTool] = {}
|
||||
|
||||
@@ -166,7 +195,7 @@ def _is_file_mount_pair(value: Any) -> TypeGuard[FileMount | tuple[FileMountHost
|
||||
return isinstance(host_path, (str, Path)) and isinstance(mount_path, str)
|
||||
|
||||
|
||||
def _normalize_file_mount_input(file_mount: FileMountInput) -> _StoredFileMount:
|
||||
def _normalize_file_mount_input(file_mount: FileMountInput) -> FileMount:
|
||||
host_path: FileMountHostPath
|
||||
mount_path: str
|
||||
if isinstance(file_mount, str):
|
||||
@@ -176,7 +205,7 @@ def _normalize_file_mount_input(file_mount: FileMountInput) -> _StoredFileMount:
|
||||
host_path = file_mount[0]
|
||||
mount_path = file_mount[1]
|
||||
|
||||
return _StoredFileMount(
|
||||
return FileMount(
|
||||
host_path=_resolve_existing_path(host_path),
|
||||
mount_path=_normalize_mount_path(mount_path),
|
||||
)
|
||||
@@ -445,18 +474,13 @@ def _build_execution_contents(
|
||||
|
||||
|
||||
def _make_sandbox_callback(tool_obj: FunctionTool) -> Callable[..., Any]:
|
||||
sandbox_tool = copy.copy(tool_obj)
|
||||
# Auto-assign a passthrough parser so the raw return value round-trips through
|
||||
# `ast.literal_eval` in the sandbox callback below. User-supplied parsers are
|
||||
# left in place so callers can customize how results are exposed to the guest.
|
||||
if sandbox_tool.result_parser is None:
|
||||
sandbox_tool.result_parser = _passthrough_result_parser
|
||||
sandbox_tool = copy(tool_obj)
|
||||
|
||||
def _callback(**kwargs: Any) -> Any:
|
||||
async def _invoke() -> list[Content]:
|
||||
return await sandbox_tool.invoke(arguments=kwargs)
|
||||
async def _invoke() -> Any:
|
||||
return await sandbox_tool.invoke(arguments=kwargs, skip_parsing=True)
|
||||
|
||||
# FunctionTool.invoke() is always async. The real Hyperlight backend invokes
|
||||
# FunctionTool.invoke() is async. The real Hyperlight backend invokes
|
||||
# registered callbacks synchronously via FFI, so this must be a sync function.
|
||||
# We run the async call on a dedicated thread to avoid conflicts with any
|
||||
# event loop that may be running on the current thread.
|
||||
@@ -474,22 +498,11 @@ def _make_sandbox_callback(tool_obj: FunctionTool) -> Callable[..., Any]:
|
||||
worker.join()
|
||||
if error_box:
|
||||
raise error_box[0]
|
||||
contents: list[Content] = result_box[0]
|
||||
|
||||
values: list[Any] = []
|
||||
for content in contents:
|
||||
if content.type == "text" and content.text is not None:
|
||||
try:
|
||||
values.append(ast.literal_eval(content.text))
|
||||
except (SyntaxError, ValueError):
|
||||
values.append(content.text)
|
||||
continue
|
||||
|
||||
values.append(content.to_dict())
|
||||
|
||||
if len(values) == 1:
|
||||
return values[0]
|
||||
return values
|
||||
# Return the raw value. The Hyperlight FFI marshals primitives (dict, list,
|
||||
# str, int, float, bool, None) natively into the guest, and falls back to
|
||||
# repr()/str() for unsupported types — so the guest receives real Python
|
||||
# objects without a lossy host-side serialization round-trip.
|
||||
return result_box[0]
|
||||
|
||||
return _callback
|
||||
|
||||
@@ -509,7 +522,7 @@ def _clear_directory(output_dir: TemporaryDirectory[str] | None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _SandboxRegistry:
|
||||
class _SandboxRegistry(SandboxRuntime):
|
||||
def __init__(self) -> None:
|
||||
self._entries: dict[tuple[Any, ...], _SandboxEntry] = {}
|
||||
self._entries_lock = threading.RLock()
|
||||
@@ -517,28 +530,54 @@ class _SandboxRegistry:
|
||||
def execute(self, *, config: _RunConfig, code: str) -> list[Content]:
|
||||
"""Execute code in a cached sandbox matching the given config.
|
||||
|
||||
Entries are keyed by ``config.cache_key()``. Concurrent calls with the same
|
||||
key are serialized by the entry lock so they never race, but they share the
|
||||
same sandbox instance. For true parallel execution, use distinct provider
|
||||
instances or configs that produce different cache keys.
|
||||
Entries are keyed by ``config.cache_key()``. All operations against the underlying
|
||||
sandbox object are routed through the entry's dedicated single-threaded worker, which
|
||||
both serializes concurrent callers and satisfies the PyO3 ``unsendable`` invariant
|
||||
that the sandbox can only be touched from the thread that created it.
|
||||
"""
|
||||
entry = self._get_or_create_entry(config)
|
||||
return entry.worker.run(self._run_on_worker, entry, code)
|
||||
|
||||
@staticmethod
|
||||
def _run_on_worker(entry: _SandboxEntry, code: str) -> list[Content]:
|
||||
entry.sandbox.restore(entry.snapshot)
|
||||
_clear_directory(entry.output_dir)
|
||||
result = entry.sandbox.run(code=code)
|
||||
return _build_execution_contents(
|
||||
result=result,
|
||||
sandbox=entry.sandbox,
|
||||
output_dir=entry.output_dir,
|
||||
code=code,
|
||||
)
|
||||
|
||||
def _get_or_create_entry(self, config: _RunConfig) -> _SandboxEntry:
|
||||
cache_key = config.cache_key()
|
||||
with self._entries_lock:
|
||||
entry = self._entries.get(cache_key)
|
||||
if entry is None:
|
||||
entry = self._create_entry(config)
|
||||
self._entries[cache_key] = entry
|
||||
return entry
|
||||
|
||||
with entry.lock:
|
||||
entry.sandbox.restore(entry.snapshot)
|
||||
_clear_directory(entry.output_dir)
|
||||
result = entry.sandbox.run(code=code)
|
||||
return _build_execution_contents(
|
||||
result=result,
|
||||
sandbox=entry.sandbox,
|
||||
output_dir=entry.output_dir,
|
||||
code=code,
|
||||
)
|
||||
def close(self) -> None:
|
||||
"""Shut down all per-entry worker threads and release per-entry resources.
|
||||
|
||||
Safe to call multiple times. Runs any sandbox close hook on the entry's
|
||||
own worker thread to honor the PyO3 ``unsendable`` invariant.
|
||||
"""
|
||||
with self._entries_lock:
|
||||
entries = list(self._entries.values())
|
||||
self._entries.clear()
|
||||
for entry in entries:
|
||||
close_hook = getattr(entry.sandbox, "close", None) or getattr(entry.sandbox, "shutdown", None)
|
||||
if callable(close_hook):
|
||||
with suppress(Exception):
|
||||
entry.worker.run(close_hook)
|
||||
entry.worker.shutdown()
|
||||
for tmp_dir in (entry.input_dir, entry.output_dir):
|
||||
if tmp_dir is not None:
|
||||
with suppress(Exception):
|
||||
tmp_dir.cleanup()
|
||||
|
||||
def _create_entry(self, config: _RunConfig) -> _SandboxEntry:
|
||||
input_dir_handle = TemporaryDirectory() if config.filesystem_enabled else None
|
||||
@@ -578,26 +617,37 @@ class _SandboxRegistry:
|
||||
methods=list(allowed_domain.methods) if allowed_domain.methods is not None else None,
|
||||
)
|
||||
|
||||
sandbox = _create_sandbox()
|
||||
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=False)
|
||||
worker = _SandboxWorker()
|
||||
|
||||
def _build_sandbox() -> tuple[Any, Any]:
|
||||
sandbox = _create_sandbox()
|
||||
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=False)
|
||||
|
||||
try:
|
||||
sandbox.run("None")
|
||||
except RuntimeError as exc:
|
||||
if not _should_retry_allowed_domain_registration(error=exc, allowed_domains=config.allowed_domains):
|
||||
raise
|
||||
|
||||
sandbox = _create_sandbox()
|
||||
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=True)
|
||||
sandbox.run("None")
|
||||
|
||||
snapshot = sandbox.snapshot()
|
||||
return sandbox, snapshot
|
||||
|
||||
try:
|
||||
sandbox.run("None")
|
||||
except RuntimeError as exc:
|
||||
if not _should_retry_allowed_domain_registration(error=exc, allowed_domains=config.allowed_domains):
|
||||
raise
|
||||
sandbox, snapshot = worker.run(_build_sandbox)
|
||||
except BaseException:
|
||||
worker.shutdown()
|
||||
raise
|
||||
|
||||
sandbox = _create_sandbox()
|
||||
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=True)
|
||||
sandbox.run("None")
|
||||
|
||||
snapshot = sandbox.snapshot()
|
||||
return _SandboxEntry(
|
||||
sandbox=sandbox,
|
||||
snapshot=snapshot,
|
||||
input_dir=input_dir_handle,
|
||||
output_dir=output_dir_handle,
|
||||
lock=threading.RLock(),
|
||||
worker=worker,
|
||||
)
|
||||
|
||||
|
||||
@@ -619,10 +669,10 @@ class HyperlightExecuteCodeTool(FunctionTool):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name="execute_code",
|
||||
description=EXECUTE_CODE_INPUT_DESCRIPTION,
|
||||
description=EXECUTE_CODE_TOOL_DESCRIPTION,
|
||||
approval_mode="never_require",
|
||||
func=self._run_code,
|
||||
input_model=_ExecuteCodeInput,
|
||||
input_model=EXECUTE_CODE_INPUT_SCHEMA,
|
||||
)
|
||||
self._state_lock = threading.RLock()
|
||||
self._registry = _registry or _SandboxRegistry()
|
||||
@@ -632,7 +682,7 @@ class HyperlightExecuteCodeTool(FunctionTool):
|
||||
self._module: str | None = module
|
||||
self._module_path: str | None = module_path
|
||||
self._managed_tools: list[FunctionTool] = []
|
||||
self._file_mounts: dict[str, _StoredFileMount] = {}
|
||||
self._file_mounts: dict[str, FileMount] = {}
|
||||
self._allowed_domains: dict[str, AllowedDomain] = {}
|
||||
|
||||
if tools is not None:
|
||||
@@ -648,7 +698,7 @@ class HyperlightExecuteCodeTool(FunctionTool):
|
||||
def description(self) -> str:
|
||||
state_lock = getattr(self, "_state_lock", None)
|
||||
if state_lock is None:
|
||||
return str(self.__dict__.get("description", EXECUTE_CODE_INPUT_DESCRIPTION))
|
||||
return str(self.__dict__.get("description", EXECUTE_CODE_TOOL_DESCRIPTION))
|
||||
|
||||
with state_lock:
|
||||
allowed_domains = sorted(self._allowed_domains.values(), key=lambda value: value.target)
|
||||
@@ -841,9 +891,9 @@ class HyperlightExecuteCodeTool(FunctionTool):
|
||||
workspace_signature = _path_tree_signature(workspace_root) if workspace_root is not None else ()
|
||||
normalized_mounts = tuple(
|
||||
_NormalizedFileMount(
|
||||
host_path=mount.host_path,
|
||||
host_path=Path(mount.host_path),
|
||||
mount_path=mount.mount_path,
|
||||
path_signature=_path_tree_signature(mount.host_path),
|
||||
path_signature=_path_tree_signature(Path(mount.host_path)),
|
||||
)
|
||||
for mount in stored_mounts
|
||||
)
|
||||
|
||||
@@ -937,3 +937,191 @@ async def test_run_code_does_not_block_event_loop() -> None:
|
||||
|
||||
assert concurrent_ran, "Event loop was blocked during sandbox execution"
|
||||
assert result[0].type == "text"
|
||||
|
||||
|
||||
class _ThreadAffinityFakeSandbox(_FakeSandbox):
|
||||
"""Fake sandbox that records the OS thread of every method invocation.
|
||||
|
||||
Mirrors the PyO3 ``unsendable`` invariant of ``hyperlight_sandbox.WasmSandbox``:
|
||||
if ``__init__``, ``register_tool``, ``allow_domain``, ``run``, ``snapshot`` or ``restore``
|
||||
are ever called from more than one thread for a given instance, the test fails.
|
||||
"""
|
||||
|
||||
affinity_failures: list[str] = []
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._owner_thread = threading.get_ident()
|
||||
self.thread_ids: set[int] = {self._owner_thread}
|
||||
|
||||
def _record(self, method: str) -> None:
|
||||
ident = threading.get_ident()
|
||||
self.thread_ids.add(ident)
|
||||
if ident != self._owner_thread:
|
||||
_ThreadAffinityFakeSandbox.affinity_failures.append(
|
||||
f"{method} called from thread {ident}, expected {self._owner_thread}"
|
||||
)
|
||||
|
||||
def register_tool(self, name_or_tool: Any, callback: Any | None = None) -> None:
|
||||
self._record("register_tool")
|
||||
super().register_tool(name_or_tool, callback)
|
||||
|
||||
def allow_domain(self, target: str, methods: list[str] | None = None) -> None:
|
||||
self._record("allow_domain")
|
||||
super().allow_domain(target, methods)
|
||||
|
||||
def run(self, code: str) -> _FakeResult:
|
||||
self._record("run")
|
||||
return super().run(code)
|
||||
|
||||
def snapshot(self) -> str:
|
||||
self._record("snapshot")
|
||||
return super().snapshot()
|
||||
|
||||
def restore(self, snapshot: Any) -> None:
|
||||
self._record("restore")
|
||||
super().restore(snapshot)
|
||||
|
||||
|
||||
async def test_sandbox_calls_are_pinned_to_owning_worker_thread(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Regression: WasmSandbox is unsendable; every sandbox call must run on its owner thread."""
|
||||
_ThreadAffinityFakeSandbox.instances.clear()
|
||||
_ThreadAffinityFakeSandbox.affinity_failures.clear()
|
||||
monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ThreadAffinityFakeSandbox)
|
||||
|
||||
execute_code = HyperlightExecuteCodeTool()
|
||||
|
||||
# Invoke many times concurrently; asyncio.to_thread will spread these across the default
|
||||
# executor's worker threads, which previously caused PyO3 to panic when a different thread
|
||||
# touched the cached sandbox.
|
||||
results = await asyncio.gather(*[execute_code.invoke(arguments={"code": "None"}) for _ in range(8)])
|
||||
for result in results:
|
||||
assert result[0].type == "text"
|
||||
|
||||
assert _ThreadAffinityFakeSandbox.affinity_failures == []
|
||||
assert len(_ThreadAffinityFakeSandbox.instances) == 1
|
||||
sandbox = _ThreadAffinityFakeSandbox.instances[0]
|
||||
# All sandbox-touching calls must have stayed on a single owning thread, distinct from the
|
||||
# caller thread that asyncio.to_thread used for dispatch.
|
||||
assert sandbox.thread_ids == {sandbox._owner_thread}
|
||||
assert sandbox._owner_thread != threading.get_ident()
|
||||
|
||||
|
||||
async def test_sandbox_owner_thread_persists_across_dispatch_threads(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Sequential calls landing on different dispatch threads still share one sandbox thread."""
|
||||
_ThreadAffinityFakeSandbox.instances.clear()
|
||||
_ThreadAffinityFakeSandbox.affinity_failures.clear()
|
||||
monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ThreadAffinityFakeSandbox)
|
||||
|
||||
execute_code = HyperlightExecuteCodeTool()
|
||||
|
||||
for _ in range(5):
|
||||
result = await execute_code.invoke(arguments={"code": "None"})
|
||||
assert result[0].type == "text"
|
||||
|
||||
assert _ThreadAffinityFakeSandbox.affinity_failures == []
|
||||
assert len(_ThreadAffinityFakeSandbox.instances) == 1
|
||||
|
||||
|
||||
def test_sandbox_registry_close_shuts_down_workers(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_FakeSandbox.instances.clear()
|
||||
monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _FakeSandbox)
|
||||
|
||||
registry = execute_code_module._SandboxRegistry()
|
||||
execute_code = HyperlightExecuteCodeTool(_registry=registry)
|
||||
asyncio.run(execute_code.invoke(arguments={"code": "None"}))
|
||||
|
||||
entries = list(registry._entries.values())
|
||||
assert len(entries) == 1
|
||||
worker = entries[0].worker
|
||||
|
||||
registry.close()
|
||||
|
||||
assert registry._entries == {}
|
||||
# Submitting after shutdown must fail; this proves the executor was actually torn down.
|
||||
with pytest.raises(RuntimeError):
|
||||
worker.submit(lambda: None)
|
||||
|
||||
|
||||
def test_sandbox_registry_close_releases_per_entry_resources(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
"""close() must invoke any sandbox close hook and release temp directories."""
|
||||
|
||||
close_calls: list[int] = []
|
||||
|
||||
class _ClosableFakeSandbox(_FakeSandbox):
|
||||
def close(self) -> None:
|
||||
close_calls.append(1)
|
||||
|
||||
_FakeSandbox.instances.clear()
|
||||
monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ClosableFakeSandbox)
|
||||
|
||||
workspace = tmp_path / "workspace"
|
||||
workspace.mkdir()
|
||||
registry = execute_code_module._SandboxRegistry()
|
||||
execute_code = HyperlightExecuteCodeTool(workspace_root=workspace, _registry=registry)
|
||||
asyncio.run(execute_code.invoke(arguments={"code": "None"}))
|
||||
|
||||
entries = list(registry._entries.values())
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.input_dir is not None and entry.output_dir is not None
|
||||
input_path = Path(entry.input_dir.name)
|
||||
output_path = Path(entry.output_dir.name)
|
||||
assert input_path.exists() and output_path.exists()
|
||||
|
||||
registry.close()
|
||||
|
||||
assert close_calls == [1]
|
||||
assert not input_path.exists()
|
||||
assert not output_path.exists()
|
||||
|
||||
|
||||
async def test_make_sandbox_callback_returns_native_dict() -> None:
|
||||
"""Host tool returning a dict must be forwarded as a native dict (no repr round-trip)."""
|
||||
|
||||
@tool
|
||||
def get_weather(city: str) -> dict[str, Any]:
|
||||
"""Get weather."""
|
||||
return {"city": city, "temp_c": 21.5}
|
||||
|
||||
callback = execute_code_module._make_sandbox_callback(get_weather)
|
||||
result = callback(city="Seattle")
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result == {"city": "Seattle", "temp_c": 21.5}
|
||||
|
||||
|
||||
async def test_make_sandbox_callback_bypasses_user_result_parser() -> None:
|
||||
"""Documented behavior change: result_parser is bypassed in the sandbox path."""
|
||||
|
||||
parser_calls: list[Any] = []
|
||||
|
||||
def parser(value: Any) -> str:
|
||||
parser_calls.append(value)
|
||||
return "PARSED"
|
||||
|
||||
@tool(result_parser=parser)
|
||||
def make_payload() -> dict[str, int]:
|
||||
"""Returns a dict."""
|
||||
return {"a": 1, "b": 2}
|
||||
|
||||
callback = execute_code_module._make_sandbox_callback(make_payload)
|
||||
result = callback()
|
||||
|
||||
assert result == {"a": 1, "b": 2}
|
||||
assert parser_calls == [], "result_parser must not run on the sandbox path"
|
||||
|
||||
|
||||
async def test_make_sandbox_callback_propagates_exceptions() -> None:
|
||||
@tool
|
||||
def boom(x: int) -> int:
|
||||
"""Always fails."""
|
||||
raise RuntimeError("nope")
|
||||
|
||||
callback = execute_code_module._make_sandbox_callback(boom)
|
||||
with pytest.raises(RuntimeError, match="nope"):
|
||||
callback(x=1)
|
||||
|
||||
Reference in New Issue
Block a user