Python: Hyperlight: thread-confine sandbox, skip parsing on host callbacks, schema/tool cleanup (#5424)

* improved parsing of tool call results and tweaks

* Address PR review: skip_parsing flag, broader registry close, comment fix

- FunctionTool.invoke now takes a boolean skip_parsing flag instead of the
  SKIP_PARSING sentinel; the sentinel is still accepted as result_parser at
  construction time to opt out of parsing for every call. The two paths are
  equivalent.
- _SandboxRegistry.close now invokes any sandbox close/shutdown hook on the
  entry's own worker thread (PyO3 unsendable), then shuts the worker down,
  then cleans up the per-entry temporary directories.
- Clarified the _SandboxWorker.shutdown comment to describe the actual
  ThreadPoolExecutor.shutdown(wait=False, cancel_futures=False) semantics.
- Hyperlight host callback uses skip_parsing=True (the new flag).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Drop redundant 'is not SKIP_PARSING' guard that mypy 1.x flags

After callable(configured_parser) the sentinel is already excluded; the extra
identity check tripped mypy's non-overlapping identity warning.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fixed sandbox working on copy of tool

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-04-23 10:14:51 +02:00
committed by GitHub
Unverified
parent 66e02c10e3
commit 58ff4ad3a9
6 changed files with 576 additions and 92 deletions
@@ -125,6 +125,7 @@ from ._telemetry import (
prepend_agent_framework_to_user_agent,
)
from ._tools import (
SKIP_PARSING,
FunctionInvocationConfiguration,
FunctionInvocationLayer,
FunctionTool,
@@ -258,6 +259,7 @@ __all__ = [
"GROUP_INDEX_KEY",
"GROUP_KIND_KEY",
"GROUP_TOKEN_COUNT_KEY",
"SKIP_PARSING",
"SUMMARIZED_BY_SUMMARY_ID_KEY",
"SUMMARY_OF_GROUP_IDS_KEY",
"SUMMARY_OF_MESSAGE_IDS_KEY",
+88 -12
View File
@@ -94,6 +94,33 @@ ApprovalMode: TypeAlias = Literal["always_require", "never_require"]
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
class _SkipParsingSentinel:
"""Sentinel signaling that :meth:`FunctionTool.invoke` should return the raw value.
When passed as ``result_parser`` to :class:`FunctionTool` (or the ``@tool`` decorator),
the default :meth:`FunctionTool.parse_result` is bypassed and the wrapped function's
return value is returned unchanged from :meth:`FunctionTool.invoke`. Callers may also
request the raw value on a per-call basis by passing ``skip_parsing=True`` to
:meth:`FunctionTool.invoke`.
Use the module-level ``SKIP_PARSING`` singleton — do not instantiate this class.
"""
_instance: ClassVar[_SkipParsingSentinel | None] = None
def __new__(cls) -> _SkipParsingSentinel:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __repr__(self) -> str:
return "SKIP_PARSING"
SKIP_PARSING: Final[_SkipParsingSentinel] = _SkipParsingSentinel()
"""Sentinel for ``FunctionTool(result_parser=...)`` meaning "do not parse the result"."""
# region Helpers
@@ -279,7 +306,7 @@ class FunctionTool(SerializationMixin):
additional_properties: dict[str, Any] | None = None,
func: Callable[..., Any] | None = None,
input_model: type[BaseModel] | Mapping[str, Any] | None = None,
result_parser: Callable[[Any], str | list[Content]] | None = None,
result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None,
**kwargs: Any,
) -> None:
"""Initialize the FunctionTool.
@@ -327,9 +354,11 @@ class FunctionTool(SerializationMixin):
result_parser: An optional callable with signature ``Callable[[Any], str]`` that
overrides the default result parsing behavior. When provided, this callable
is used to convert the raw function return value to a string instead of the
built-in :meth:`parse_result` logic. Depending on your function, it may be
easiest to just do the serialization directly in the function body rather
than providing a custom ``result_parser``.
built-in :meth:`parse_result` logic. Pass the :data:`SKIP_PARSING` sentinel
instead of a callable to opt out of parsing entirely; in that case
:meth:`invoke` returns the wrapped function's raw return value. Depending
on your function, it may be easiest to just do the serialization directly
in the function body rather than providing a custom ``result_parser``.
**kwargs: Additional keyword arguments.
"""
# Core attributes (formerly from BaseTool)
@@ -508,31 +537,65 @@ class FunctionTool(SerializationMixin):
self.invocation_exception_count += 1
raise
@overload
async def invoke(
self,
*,
arguments: BaseModel | Mapping[str, Any] | None = None,
context: FunctionInvocationContext | None = None,
tool_call_id: str | None = None,
skip_parsing: Literal[True],
**kwargs: Any,
) -> list[Content]:
) -> Any: ...
@overload
async def invoke(
self,
*,
arguments: BaseModel | Mapping[str, Any] | None = None,
context: FunctionInvocationContext | None = None,
tool_call_id: str | None = None,
skip_parsing: Literal[False] = False,
**kwargs: Any,
) -> list[Content]: ...
async def invoke(
self,
*,
arguments: BaseModel | Mapping[str, Any] | None = None,
context: FunctionInvocationContext | None = None,
tool_call_id: str | None = None,
skip_parsing: bool = False,
**kwargs: Any,
) -> list[Content] | Any:
"""Run the AI function with the provided arguments as a Pydantic model.
The raw return value of the wrapped function is automatically parsed into a
``list[Content]`` using :meth:`parse_result` or the custom ``result_parser``
if one was provided. Every result — text, rich media, or serialized objects —
is represented uniformly as Content items.
configured on the tool. Every result — text, rich media, or serialized
objects — is represented uniformly as Content items.
Parsing can be skipped in two ways: configure the tool with
``result_parser=SKIP_PARSING`` to always skip parsing, or pass
``skip_parsing=True`` per call. Either way the wrapped function's raw value
is returned. This is intended for callers (e.g. sandboxed runtimes) that
consume the value from Python directly and would otherwise undo the
``Content`` wrapping.
Keyword Args:
arguments: A mapping or model instance containing the arguments for the function.
context: Explicit function invocation context carrying runtime kwargs.
tool_call_id: Optional tool call identifier used for telemetry and tracing.
skip_parsing: When ``True``, bypass parsing and return the wrapped function's
raw value instead of a ``list[Content]``. Defaults to ``False``.
kwargs: Direct function argument values. When provided, every keyword
must match a declared tool parameter. Runtime data must be passed
via ``context``.
Returns:
A list of Content items representing the tool output.
``list[Content]`` by default. The raw function return value (``Any``) when
``skip_parsing=True`` (or the tool was constructed with
``result_parser=SKIP_PARSING``).
Raises:
TypeError: If arguments is not mapping-like or fails schema checks.
@@ -544,7 +607,9 @@ class FunctionTool(SerializationMixin):
from ._types import Content
from .observability import OBSERVABILITY_SETTINGS
parser = self.result_parser or FunctionTool.parse_result
configured_parser = self.result_parser
skip_parsing = skip_parsing or configured_parser is SKIP_PARSING
parser = configured_parser if callable(configured_parser) else FunctionTool.parse_result
parameter_names = set(self.parameters().get("properties", {}).keys())
direct_argument_kwargs = (
@@ -616,6 +681,10 @@ class FunctionTool(SerializationMixin):
logger.debug(f"Function arguments: {observable_kwargs}")
res = self.__call__(**call_kwargs)
result = await res if inspect.isawaitable(res) else res
if skip_parsing:
logger.info(f"Function {self.name} succeeded.")
logger.debug(f"Function result: {type(result).__name__}")
return result
try:
parsed = parser(result)
except Exception:
@@ -671,6 +740,13 @@ class FunctionTool(SerializationMixin):
logger.error(f"Function failed. Error: {exception}")
raise
else:
if skip_parsing:
logger.info(f"Function {self.name} succeeded.")
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined]
result_str = str(result)
span.set_attribute(OtelAttr.TOOL_RESULT, result_str)
logger.debug(f"Function result: {result_str}")
return result
try:
parsed = parser(result)
except Exception:
@@ -1067,7 +1143,7 @@ def tool(
max_invocations: int | None = None,
max_invocation_exceptions: int | None = None,
additional_properties: dict[str, Any] | None = None,
result_parser: Callable[[Any], str | list[Content]] | None = None,
result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None,
) -> FunctionTool: ...
@@ -1083,7 +1159,7 @@ def tool(
max_invocations: int | None = None,
max_invocation_exceptions: int | None = None,
additional_properties: dict[str, Any] | None = None,
result_parser: Callable[[Any], str | list[Content]] | None = None,
result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None,
) -> Callable[[Callable[..., Any]], FunctionTool]: ...
@@ -1098,7 +1174,7 @@ def tool(
max_invocations: int | None = None,
max_invocation_exceptions: int | None = None,
additional_properties: dict[str, Any] | None = None,
result_parser: Callable[[Any], str | list[Content]] | None = None,
result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None,
) -> FunctionTool | Callable[[Callable[..., Any]], FunctionTool]:
"""Decorate a function to turn it into a FunctionTool that can be passed to models and executed automatically.
@@ -8,6 +8,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanE
from pydantic import BaseModel
from agent_framework import (
SKIP_PARSING,
Content,
FunctionTool,
tool,
@@ -1300,4 +1301,165 @@ def test_normalize_tools_flattens_mapping_like_toolbox_with_tools_attr() -> None
assert normalized[1] is standalone
# region SKIP_PARSING sentinel & skip_parsing
async def test_invoke_skip_parsing_returns_native_value() -> None:
"""invoke(skip_parsing=True) returns the wrapped function's raw value."""
@tool
def get_weather(city: str) -> dict[str, Any]:
"""Get the weather."""
return {"city": city, "temperature_c": 21.5, "conditions": "partly cloudy"}
raw = await get_weather.invoke(arguments={"city": "Seattle"}, skip_parsing=True)
assert isinstance(raw, dict)
assert raw == {"city": "Seattle", "temperature_c": 21.5, "conditions": "partly cloudy"}
async def test_invoke_skip_parsing_passes_through_custom_objects() -> None:
"""skip_parsing must not call str()/repr() on the result."""
class Custom: # noqa: B903
def __init__(self, value: int) -> None:
self.value = value
@tool
def make() -> Custom:
"""Make a custom object."""
return Custom(42)
raw = await make.invoke(skip_parsing=True)
assert isinstance(raw, Custom)
assert raw.value == 42
async def test_invoke_skip_parsing_awaits_async_functions() -> None:
@tool
async def slow(x: int) -> int:
"""Async tool."""
return x * 2
raw = await slow.invoke(arguments={"x": 21}, skip_parsing=True)
assert raw == 42
async def test_invoke_skip_parsing_bypasses_configured_result_parser() -> None:
"""The tool's own result_parser is bypassed when skip_parsing=True is requested."""
parser_calls: list[Any] = []
def parser(value: Any) -> str:
parser_calls.append(value)
return "PARSED"
@tool(result_parser=parser)
def make_dict() -> dict[str, int]:
"""Returns a dict."""
return {"a": 1}
raw = await make_dict.invoke(skip_parsing=True)
assert raw == {"a": 1}
assert parser_calls == []
# Sanity: omitting skip_parsing still applies the configured parser.
parsed = await make_dict.invoke()
assert parsed[0].type == "text"
assert parsed[0].text == "PARSED"
async def test_constructor_skip_parsing_sentinel_returns_raw_by_default() -> None:
"""Constructing a tool with result_parser=SKIP_PARSING makes invoke return the raw value."""
@tool(result_parser=SKIP_PARSING)
def make_dict() -> dict[str, int]:
"""Returns a dict."""
return {"a": 1}
raw = await make_dict.invoke()
assert raw == {"a": 1}
async def test_invoke_skip_parsing_validates_arguments() -> None:
"""Argument validation is shared with the default path."""
@tool
def adder(x: int, y: int) -> int:
"""Add."""
return x + y
with pytest.raises(TypeError):
await adder.invoke(arguments={"x": "not-an-int", "y": 1}, skip_parsing=True)
async def test_invoke_skip_parsing_rejects_unexpected_runtime_kwargs() -> None:
@tool
async def echo(message: str) -> str:
"""Echo."""
return message
with pytest.raises(TypeError, match="Unexpected keyword argument"):
await echo.invoke(arguments={"message": "hi"}, skip_parsing=True, api_token="secret")
async def test_invoke_skip_parsing_raises_for_declaration_only_tool() -> None:
declared = FunctionTool(name="dummy", description="declaration only")
from agent_framework.exceptions import ToolException
with pytest.raises(ToolException):
await declared.invoke(arguments={}, skip_parsing=True)
async def test_invoke_skip_parsing_records_telemetry(span_exporter: InMemorySpanExporter) -> None:
"""skip_parsing participates in OTEL spans and records str(raw) as TOOL_RESULT."""
@tool(name="raw_tool", description="raw tool")
def returns_dict(x: int) -> dict[str, int]:
"""Returns a dict."""
return {"value": x}
span_exporter.clear()
raw = await returns_dict.invoke(arguments={"x": 5}, tool_call_id="raw_call", skip_parsing=True)
assert raw == {"value": 5}
spans = span_exporter.get_finished_spans()
assert len(spans) == 1
span = spans[0]
assert span.attributes[OtelAttr.TOOL_NAME] == "raw_tool"
assert span.attributes[OtelAttr.TOOL_CALL_ID] == "raw_call"
assert span.attributes[OtelAttr.TOOL_RESULT] == "{'value': 5}"
async def test_invoke_default_path_records_parsed_telemetry(
span_exporter: InMemorySpanExporter,
) -> None:
"""Regression: omitting skip_parsing still records the parsed result in telemetry."""
def parser(value: Any) -> str:
return f"parsed:{value}"
@tool(name="parsed_tool", description="parsed", result_parser=parser)
def returns_int() -> int:
"""Returns an int."""
return 7
span_exporter.clear()
parsed = await returns_int.invoke(tool_call_id="parsed_call")
assert parsed[0].text == "parsed:7"
spans = span_exporter.get_finished_spans()
assert len(spans) == 1
assert spans[0].attributes[OtelAttr.TOOL_RESULT] == "parsed:7"
def test_skip_parsing_is_singleton() -> None:
"""SKIP_PARSING is a singleton; instantiation returns the same object."""
from agent_framework._tools import _SkipParsingSentinel
assert _SkipParsingSentinel() is SKIP_PARSING
assert repr(SKIP_PARSING) == "SKIP_PARSING"
# endregion
+6
View File
@@ -130,3 +130,9 @@ codeact = HyperlightCodeActProvider(
- `allowed_domains` accepts a single string target such as `"github.com"` to
allow all backend-supported methods, an explicit `(target, method_or_methods)`
tuple such as `("github.com", "GET")`, or an `AllowedDomain` named tuple.
- Tools registered with the sandbox return their native Python value
(`dict`, `list`, primitives, or custom objects) directly to the guest via the
Hyperlight FFI. Any `result_parser` configured on a `FunctionTool` is
intended for LLM-facing consumers and does not run on the sandbox path —
apply formatting inside the tool function itself if you need it for
in-sandbox consumers.
@@ -2,42 +2,45 @@
from __future__ import annotations
import ast
import asyncio
import copy
import mimetypes
import shutil
import threading
import time
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import suppress
from copy import copy
from dataclasses import dataclass, field
from pathlib import Path, PurePosixPath
from tempfile import TemporaryDirectory
from typing import Annotated, Any, Protocol, TypeGuard, cast
from typing import Any, Protocol, TypeGuard, TypeVar, cast
from urllib.parse import urlparse
from agent_framework import Content, FunctionTool
from agent_framework._tools import ApprovalMode, normalize_tools
from pydantic import BaseModel, Field
from ._instructions import build_codeact_instructions, build_execute_code_description
from ._types import AllowedDomain, AllowedDomainInput, FileMount, FileMountHostPath, FileMountInput
DEFAULT_HYPERLIGHT_BACKEND = "wasm"
DEFAULT_HYPERLIGHT_MODULE = "python_guest.path"
EXECUTE_CODE_INPUT_DESCRIPTION = "Python code to execute in an isolated Hyperlight sandbox."
EXECUTE_CODE_TOOL_DESCRIPTION = "Execute Python in an isolated Hyperlight sandbox."
OUTPUT_FILE_RETRY_ATTEMPTS = 10
OUTPUT_FILE_RETRY_DELAY_SECONDS = 0.1
class _ExecuteCodeInput(BaseModel):
code: Annotated[str, Field(description=EXECUTE_CODE_INPUT_DESCRIPTION)]
@dataclass(frozen=True, slots=True)
class _StoredFileMount:
host_path: Path
mount_path: str
EXECUTE_CODE_INPUT_SCHEMA: dict[str, Any] = {
"type": "object",
"title": "_ExecuteCodeInput",
"properties": {
"code": {
"type": "string",
"title": "Code",
"description": "Python code to execute in an isolated Hyperlight sandbox.",
},
},
"required": ["code"],
}
@dataclass(frozen=True, slots=True)
@@ -85,13 +88,43 @@ class SandboxRuntime(Protocol):
def execute(self, *, config: _RunConfig, code: str) -> list[Content]: ...
_T = TypeVar("_T")
class _SandboxWorker:
"""Single-threaded executor that confines all sandbox operations to one OS thread.
The Hyperlight ``WasmSandbox`` is declared ``unsendable`` in PyO3, meaning it can only be
accessed from the OS thread that created it; touching it from any other thread triggers a
Rust panic that cannot be caught from Python. Every cached :class:`_SandboxEntry` therefore
owns its own ``_SandboxWorker``, and *all* lifecycle and execution calls against the
underlying sandbox object must be routed through :meth:`submit`/:meth:`run`.
"""
__slots__ = ("_executor",)
def __init__(self, *, name: str = "hl-sandbox") -> None:
self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=name)
def submit(self, fn: Callable[..., _T], /, *args: Any, **kwargs: Any) -> Future[_T]:
return self._executor.submit(fn, *args, **kwargs)
def run(self, fn: Callable[..., _T], /, *args: Any, **kwargs: Any) -> _T:
return self._executor.submit(fn, *args, **kwargs).result()
def shutdown(self) -> None:
# Do not block on shutdown; stop accepting new tasks, but allow the currently running
# task and any already-queued tasks to finish before the worker thread exits.
self._executor.shutdown(wait=False, cancel_futures=False)
@dataclass
class _SandboxEntry:
sandbox: Any
snapshot: Any
input_dir: TemporaryDirectory[str] | None
output_dir: TemporaryDirectory[str] | None
lock: threading.RLock
worker: _SandboxWorker = field(default_factory=_SandboxWorker)
def _load_sandbox_class() -> type[Any]:
@@ -106,10 +139,6 @@ def _load_sandbox_class() -> type[Any]:
return Sandbox
def _passthrough_result_parser(result: Any) -> str:
return repr(result)
def _collect_tools(*tool_groups: Any) -> list[FunctionTool]:
tools_by_name: dict[str, FunctionTool] = {}
@@ -166,7 +195,7 @@ def _is_file_mount_pair(value: Any) -> TypeGuard[FileMount | tuple[FileMountHost
return isinstance(host_path, (str, Path)) and isinstance(mount_path, str)
def _normalize_file_mount_input(file_mount: FileMountInput) -> _StoredFileMount:
def _normalize_file_mount_input(file_mount: FileMountInput) -> FileMount:
host_path: FileMountHostPath
mount_path: str
if isinstance(file_mount, str):
@@ -176,7 +205,7 @@ def _normalize_file_mount_input(file_mount: FileMountInput) -> _StoredFileMount:
host_path = file_mount[0]
mount_path = file_mount[1]
return _StoredFileMount(
return FileMount(
host_path=_resolve_existing_path(host_path),
mount_path=_normalize_mount_path(mount_path),
)
@@ -445,18 +474,13 @@ def _build_execution_contents(
def _make_sandbox_callback(tool_obj: FunctionTool) -> Callable[..., Any]:
sandbox_tool = copy.copy(tool_obj)
# Auto-assign a passthrough parser so the raw return value round-trips through
# `ast.literal_eval` in the sandbox callback below. User-supplied parsers are
# left in place so callers can customize how results are exposed to the guest.
if sandbox_tool.result_parser is None:
sandbox_tool.result_parser = _passthrough_result_parser
sandbox_tool = copy(tool_obj)
def _callback(**kwargs: Any) -> Any:
async def _invoke() -> list[Content]:
return await sandbox_tool.invoke(arguments=kwargs)
async def _invoke() -> Any:
return await sandbox_tool.invoke(arguments=kwargs, skip_parsing=True)
# FunctionTool.invoke() is always async. The real Hyperlight backend invokes
# FunctionTool.invoke() is async. The real Hyperlight backend invokes
# registered callbacks synchronously via FFI, so this must be a sync function.
# We run the async call on a dedicated thread to avoid conflicts with any
# event loop that may be running on the current thread.
@@ -474,22 +498,11 @@ def _make_sandbox_callback(tool_obj: FunctionTool) -> Callable[..., Any]:
worker.join()
if error_box:
raise error_box[0]
contents: list[Content] = result_box[0]
values: list[Any] = []
for content in contents:
if content.type == "text" and content.text is not None:
try:
values.append(ast.literal_eval(content.text))
except (SyntaxError, ValueError):
values.append(content.text)
continue
values.append(content.to_dict())
if len(values) == 1:
return values[0]
return values
# Return the raw value. The Hyperlight FFI marshals primitives (dict, list,
# str, int, float, bool, None) natively into the guest, and falls back to
# repr()/str() for unsupported types — so the guest receives real Python
# objects without a lossy host-side serialization round-trip.
return result_box[0]
return _callback
@@ -509,7 +522,7 @@ def _clear_directory(output_dir: TemporaryDirectory[str] | None) -> None:
pass
class _SandboxRegistry:
class _SandboxRegistry(SandboxRuntime):
def __init__(self) -> None:
self._entries: dict[tuple[Any, ...], _SandboxEntry] = {}
self._entries_lock = threading.RLock()
@@ -517,28 +530,54 @@ class _SandboxRegistry:
def execute(self, *, config: _RunConfig, code: str) -> list[Content]:
"""Execute code in a cached sandbox matching the given config.
Entries are keyed by ``config.cache_key()``. Concurrent calls with the same
key are serialized by the entry lock so they never race, but they share the
same sandbox instance. For true parallel execution, use distinct provider
instances or configs that produce different cache keys.
Entries are keyed by ``config.cache_key()``. All operations against the underlying
sandbox object are routed through the entry's dedicated single-threaded worker, which
both serializes concurrent callers and satisfies the PyO3 ``unsendable`` invariant
that the sandbox can only be touched from the thread that created it.
"""
entry = self._get_or_create_entry(config)
return entry.worker.run(self._run_on_worker, entry, code)
@staticmethod
def _run_on_worker(entry: _SandboxEntry, code: str) -> list[Content]:
entry.sandbox.restore(entry.snapshot)
_clear_directory(entry.output_dir)
result = entry.sandbox.run(code=code)
return _build_execution_contents(
result=result,
sandbox=entry.sandbox,
output_dir=entry.output_dir,
code=code,
)
def _get_or_create_entry(self, config: _RunConfig) -> _SandboxEntry:
cache_key = config.cache_key()
with self._entries_lock:
entry = self._entries.get(cache_key)
if entry is None:
entry = self._create_entry(config)
self._entries[cache_key] = entry
return entry
with entry.lock:
entry.sandbox.restore(entry.snapshot)
_clear_directory(entry.output_dir)
result = entry.sandbox.run(code=code)
return _build_execution_contents(
result=result,
sandbox=entry.sandbox,
output_dir=entry.output_dir,
code=code,
)
def close(self) -> None:
"""Shut down all per-entry worker threads and release per-entry resources.
Safe to call multiple times. Runs any sandbox close hook on the entry's
own worker thread to honor the PyO3 ``unsendable`` invariant.
"""
with self._entries_lock:
entries = list(self._entries.values())
self._entries.clear()
for entry in entries:
close_hook = getattr(entry.sandbox, "close", None) or getattr(entry.sandbox, "shutdown", None)
if callable(close_hook):
with suppress(Exception):
entry.worker.run(close_hook)
entry.worker.shutdown()
for tmp_dir in (entry.input_dir, entry.output_dir):
if tmp_dir is not None:
with suppress(Exception):
tmp_dir.cleanup()
def _create_entry(self, config: _RunConfig) -> _SandboxEntry:
input_dir_handle = TemporaryDirectory() if config.filesystem_enabled else None
@@ -578,26 +617,37 @@ class _SandboxRegistry:
methods=list(allowed_domain.methods) if allowed_domain.methods is not None else None,
)
sandbox = _create_sandbox()
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=False)
worker = _SandboxWorker()
def _build_sandbox() -> tuple[Any, Any]:
sandbox = _create_sandbox()
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=False)
try:
sandbox.run("None")
except RuntimeError as exc:
if not _should_retry_allowed_domain_registration(error=exc, allowed_domains=config.allowed_domains):
raise
sandbox = _create_sandbox()
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=True)
sandbox.run("None")
snapshot = sandbox.snapshot()
return sandbox, snapshot
try:
sandbox.run("None")
except RuntimeError as exc:
if not _should_retry_allowed_domain_registration(error=exc, allowed_domains=config.allowed_domains):
raise
sandbox, snapshot = worker.run(_build_sandbox)
except BaseException:
worker.shutdown()
raise
sandbox = _create_sandbox()
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=True)
sandbox.run("None")
snapshot = sandbox.snapshot()
return _SandboxEntry(
sandbox=sandbox,
snapshot=snapshot,
input_dir=input_dir_handle,
output_dir=output_dir_handle,
lock=threading.RLock(),
worker=worker,
)
@@ -619,10 +669,10 @@ class HyperlightExecuteCodeTool(FunctionTool):
) -> None:
super().__init__(
name="execute_code",
description=EXECUTE_CODE_INPUT_DESCRIPTION,
description=EXECUTE_CODE_TOOL_DESCRIPTION,
approval_mode="never_require",
func=self._run_code,
input_model=_ExecuteCodeInput,
input_model=EXECUTE_CODE_INPUT_SCHEMA,
)
self._state_lock = threading.RLock()
self._registry = _registry or _SandboxRegistry()
@@ -632,7 +682,7 @@ class HyperlightExecuteCodeTool(FunctionTool):
self._module: str | None = module
self._module_path: str | None = module_path
self._managed_tools: list[FunctionTool] = []
self._file_mounts: dict[str, _StoredFileMount] = {}
self._file_mounts: dict[str, FileMount] = {}
self._allowed_domains: dict[str, AllowedDomain] = {}
if tools is not None:
@@ -648,7 +698,7 @@ class HyperlightExecuteCodeTool(FunctionTool):
def description(self) -> str:
state_lock = getattr(self, "_state_lock", None)
if state_lock is None:
return str(self.__dict__.get("description", EXECUTE_CODE_INPUT_DESCRIPTION))
return str(self.__dict__.get("description", EXECUTE_CODE_TOOL_DESCRIPTION))
with state_lock:
allowed_domains = sorted(self._allowed_domains.values(), key=lambda value: value.target)
@@ -841,9 +891,9 @@ class HyperlightExecuteCodeTool(FunctionTool):
workspace_signature = _path_tree_signature(workspace_root) if workspace_root is not None else ()
normalized_mounts = tuple(
_NormalizedFileMount(
host_path=mount.host_path,
host_path=Path(mount.host_path),
mount_path=mount.mount_path,
path_signature=_path_tree_signature(mount.host_path),
path_signature=_path_tree_signature(Path(mount.host_path)),
)
for mount in stored_mounts
)
@@ -937,3 +937,191 @@ async def test_run_code_does_not_block_event_loop() -> None:
assert concurrent_ran, "Event loop was blocked during sandbox execution"
assert result[0].type == "text"
class _ThreadAffinityFakeSandbox(_FakeSandbox):
"""Fake sandbox that records the OS thread of every method invocation.
Mirrors the PyO3 ``unsendable`` invariant of ``hyperlight_sandbox.WasmSandbox``:
if ``__init__``, ``register_tool``, ``allow_domain``, ``run``, ``snapshot`` or ``restore``
are ever called from more than one thread for a given instance, the test fails.
"""
affinity_failures: list[str] = []
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._owner_thread = threading.get_ident()
self.thread_ids: set[int] = {self._owner_thread}
def _record(self, method: str) -> None:
ident = threading.get_ident()
self.thread_ids.add(ident)
if ident != self._owner_thread:
_ThreadAffinityFakeSandbox.affinity_failures.append(
f"{method} called from thread {ident}, expected {self._owner_thread}"
)
def register_tool(self, name_or_tool: Any, callback: Any | None = None) -> None:
self._record("register_tool")
super().register_tool(name_or_tool, callback)
def allow_domain(self, target: str, methods: list[str] | None = None) -> None:
self._record("allow_domain")
super().allow_domain(target, methods)
def run(self, code: str) -> _FakeResult:
self._record("run")
return super().run(code)
def snapshot(self) -> str:
self._record("snapshot")
return super().snapshot()
def restore(self, snapshot: Any) -> None:
self._record("restore")
super().restore(snapshot)
async def test_sandbox_calls_are_pinned_to_owning_worker_thread(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Regression: WasmSandbox is unsendable; every sandbox call must run on its owner thread."""
_ThreadAffinityFakeSandbox.instances.clear()
_ThreadAffinityFakeSandbox.affinity_failures.clear()
monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ThreadAffinityFakeSandbox)
execute_code = HyperlightExecuteCodeTool()
# Invoke many times concurrently; asyncio.to_thread will spread these across the default
# executor's worker threads, which previously caused PyO3 to panic when a different thread
# touched the cached sandbox.
results = await asyncio.gather(*[execute_code.invoke(arguments={"code": "None"}) for _ in range(8)])
for result in results:
assert result[0].type == "text"
assert _ThreadAffinityFakeSandbox.affinity_failures == []
assert len(_ThreadAffinityFakeSandbox.instances) == 1
sandbox = _ThreadAffinityFakeSandbox.instances[0]
# All sandbox-touching calls must have stayed on a single owning thread, distinct from the
# caller thread that asyncio.to_thread used for dispatch.
assert sandbox.thread_ids == {sandbox._owner_thread}
assert sandbox._owner_thread != threading.get_ident()
async def test_sandbox_owner_thread_persists_across_dispatch_threads(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Sequential calls landing on different dispatch threads still share one sandbox thread."""
_ThreadAffinityFakeSandbox.instances.clear()
_ThreadAffinityFakeSandbox.affinity_failures.clear()
monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ThreadAffinityFakeSandbox)
execute_code = HyperlightExecuteCodeTool()
for _ in range(5):
result = await execute_code.invoke(arguments={"code": "None"})
assert result[0].type == "text"
assert _ThreadAffinityFakeSandbox.affinity_failures == []
assert len(_ThreadAffinityFakeSandbox.instances) == 1
def test_sandbox_registry_close_shuts_down_workers(monkeypatch: pytest.MonkeyPatch) -> None:
_FakeSandbox.instances.clear()
monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _FakeSandbox)
registry = execute_code_module._SandboxRegistry()
execute_code = HyperlightExecuteCodeTool(_registry=registry)
asyncio.run(execute_code.invoke(arguments={"code": "None"}))
entries = list(registry._entries.values())
assert len(entries) == 1
worker = entries[0].worker
registry.close()
assert registry._entries == {}
# Submitting after shutdown must fail; this proves the executor was actually torn down.
with pytest.raises(RuntimeError):
worker.submit(lambda: None)
def test_sandbox_registry_close_releases_per_entry_resources(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""close() must invoke any sandbox close hook and release temp directories."""
close_calls: list[int] = []
class _ClosableFakeSandbox(_FakeSandbox):
def close(self) -> None:
close_calls.append(1)
_FakeSandbox.instances.clear()
monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ClosableFakeSandbox)
workspace = tmp_path / "workspace"
workspace.mkdir()
registry = execute_code_module._SandboxRegistry()
execute_code = HyperlightExecuteCodeTool(workspace_root=workspace, _registry=registry)
asyncio.run(execute_code.invoke(arguments={"code": "None"}))
entries = list(registry._entries.values())
assert len(entries) == 1
entry = entries[0]
assert entry.input_dir is not None and entry.output_dir is not None
input_path = Path(entry.input_dir.name)
output_path = Path(entry.output_dir.name)
assert input_path.exists() and output_path.exists()
registry.close()
assert close_calls == [1]
assert not input_path.exists()
assert not output_path.exists()
async def test_make_sandbox_callback_returns_native_dict() -> None:
"""Host tool returning a dict must be forwarded as a native dict (no repr round-trip)."""
@tool
def get_weather(city: str) -> dict[str, Any]:
"""Get weather."""
return {"city": city, "temp_c": 21.5}
callback = execute_code_module._make_sandbox_callback(get_weather)
result = callback(city="Seattle")
assert isinstance(result, dict)
assert result == {"city": "Seattle", "temp_c": 21.5}
async def test_make_sandbox_callback_bypasses_user_result_parser() -> None:
"""Documented behavior change: result_parser is bypassed in the sandbox path."""
parser_calls: list[Any] = []
def parser(value: Any) -> str:
parser_calls.append(value)
return "PARSED"
@tool(result_parser=parser)
def make_payload() -> dict[str, int]:
"""Returns a dict."""
return {"a": 1, "b": 2}
callback = execute_code_module._make_sandbox_callback(make_payload)
result = callback()
assert result == {"a": 1, "b": 2}
assert parser_calls == [], "result_parser must not run on the sandbox path"
async def test_make_sandbox_callback_propagates_exceptions() -> None:
@tool
def boom(x: int) -> int:
"""Always fails."""
raise RuntimeError("nope")
callback = execute_code_module._make_sandbox_callback(boom)
with pytest.raises(RuntimeError, match="nope"):
callback(x=1)