mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Hyperlight: thread-confine sandbox, skip parsing on host callbacks, schema/tool cleanup (#5424)
* improved parsing of tool call results and tweaks * Address PR review: skip_parsing flag, broader registry close, comment fix - FunctionTool.invoke now takes a boolean skip_parsing flag instead of the SKIP_PARSING sentinel; the sentinel is still accepted as result_parser at construction time to opt out of parsing for every call. The two paths are equivalent. - _SandboxRegistry.close now invokes any sandbox close/shutdown hook on the entry's own worker thread (PyO3 unsendable), then shuts the worker down, then cleans up the per-entry temporary directories. - Clarified the _SandboxWorker.shutdown comment to describe the actual ThreadPoolExecutor.shutdown(wait=False, cancel_futures=False) semantics. - Hyperlight host callback uses skip_parsing=True (the new flag). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Drop redundant 'is not SKIP_PARSING' guard that mypy 1.x flags After callable(configured_parser) the sentinel is already excluded; the extra identity check tripped mypy's non-overlapping identity warning. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fixed sandbox working on copy of tool --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
66e02c10e3
commit
58ff4ad3a9
@@ -130,3 +130,9 @@ codeact = HyperlightCodeActProvider(
|
||||
- `allowed_domains` accepts a single string target such as `"github.com"` to
|
||||
allow all backend-supported methods, an explicit `(target, method_or_methods)`
|
||||
tuple such as `("github.com", "GET")`, or an `AllowedDomain` named tuple.
|
||||
- Tools registered with the sandbox return their native Python value
|
||||
(`dict`, `list`, primitives, or custom objects) directly to the guest via the
|
||||
Hyperlight FFI. Any `result_parser` configured on a `FunctionTool` is
|
||||
intended for LLM-facing consumers and does not run on the sandbox path —
|
||||
apply formatting inside the tool function itself if you need it for
|
||||
in-sandbox consumers.
|
||||
|
||||
@@ -2,42 +2,45 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import copy
|
||||
import mimetypes
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import suppress
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path, PurePosixPath
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Annotated, Any, Protocol, TypeGuard, cast
|
||||
from typing import Any, Protocol, TypeGuard, TypeVar, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from agent_framework import Content, FunctionTool
|
||||
from agent_framework._tools import ApprovalMode, normalize_tools
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ._instructions import build_codeact_instructions, build_execute_code_description
|
||||
from ._types import AllowedDomain, AllowedDomainInput, FileMount, FileMountHostPath, FileMountInput
|
||||
|
||||
DEFAULT_HYPERLIGHT_BACKEND = "wasm"
|
||||
DEFAULT_HYPERLIGHT_MODULE = "python_guest.path"
|
||||
EXECUTE_CODE_INPUT_DESCRIPTION = "Python code to execute in an isolated Hyperlight sandbox."
|
||||
EXECUTE_CODE_TOOL_DESCRIPTION = "Execute Python in an isolated Hyperlight sandbox."
|
||||
OUTPUT_FILE_RETRY_ATTEMPTS = 10
|
||||
OUTPUT_FILE_RETRY_DELAY_SECONDS = 0.1
|
||||
|
||||
|
||||
class _ExecuteCodeInput(BaseModel):
|
||||
code: Annotated[str, Field(description=EXECUTE_CODE_INPUT_DESCRIPTION)]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _StoredFileMount:
|
||||
host_path: Path
|
||||
mount_path: str
|
||||
EXECUTE_CODE_INPUT_SCHEMA: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"title": "_ExecuteCodeInput",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"title": "Code",
|
||||
"description": "Python code to execute in an isolated Hyperlight sandbox.",
|
||||
},
|
||||
},
|
||||
"required": ["code"],
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@@ -85,13 +88,43 @@ class SandboxRuntime(Protocol):
|
||||
def execute(self, *, config: _RunConfig, code: str) -> list[Content]: ...
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class _SandboxWorker:
|
||||
"""Single-threaded executor that confines all sandbox operations to one OS thread.
|
||||
|
||||
The Hyperlight ``WasmSandbox`` is declared ``unsendable`` in PyO3, meaning it can only be
|
||||
accessed from the OS thread that created it; touching it from any other thread triggers a
|
||||
Rust panic that cannot be caught from Python. Every cached :class:`_SandboxEntry` therefore
|
||||
owns its own ``_SandboxWorker``, and *all* lifecycle and execution calls against the
|
||||
underlying sandbox object must be routed through :meth:`submit`/:meth:`run`.
|
||||
"""
|
||||
|
||||
__slots__ = ("_executor",)
|
||||
|
||||
def __init__(self, *, name: str = "hl-sandbox") -> None:
|
||||
self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=name)
|
||||
|
||||
def submit(self, fn: Callable[..., _T], /, *args: Any, **kwargs: Any) -> Future[_T]:
|
||||
return self._executor.submit(fn, *args, **kwargs)
|
||||
|
||||
def run(self, fn: Callable[..., _T], /, *args: Any, **kwargs: Any) -> _T:
|
||||
return self._executor.submit(fn, *args, **kwargs).result()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
# Do not block on shutdown; stop accepting new tasks, but allow the currently running
|
||||
# task and any already-queued tasks to finish before the worker thread exits.
|
||||
self._executor.shutdown(wait=False, cancel_futures=False)
|
||||
|
||||
|
||||
@dataclass
class _SandboxEntry:
    """A cached sandbox together with the resources that share its lifetime."""

    # Sandbox object. The registry creates it on ``worker`` and only touches it
    # through ``worker`` afterwards.
    sandbox: Any
    # Value returned by ``sandbox.snapshot()``; passed to ``sandbox.restore`` before each run.
    snapshot: Any
    # Temporary input directory handle, or ``None`` when no directory was allocated.
    input_dir: TemporaryDirectory[str] | None
    # Temporary output directory handle, or ``None`` when no directory was allocated.
    output_dir: TemporaryDirectory[str] | None
    # Per-entry re-entrant lock.
    # NOTE(review): the worker-based execution path does not appear to acquire it — confirm
    # whether it is still needed.
    lock: threading.RLock
    # Dedicated single-threaded worker that owns every call made against ``sandbox``.
    worker: _SandboxWorker = field(default_factory=_SandboxWorker)
|
||||
|
||||
|
||||
def _load_sandbox_class() -> type[Any]:
|
||||
@@ -106,10 +139,6 @@ def _load_sandbox_class() -> type[Any]:
|
||||
return Sandbox
|
||||
|
||||
|
||||
def _passthrough_result_parser(result: Any) -> str:
|
||||
return repr(result)
|
||||
|
||||
|
||||
def _collect_tools(*tool_groups: Any) -> list[FunctionTool]:
|
||||
tools_by_name: dict[str, FunctionTool] = {}
|
||||
|
||||
@@ -166,7 +195,7 @@ def _is_file_mount_pair(value: Any) -> TypeGuard[FileMount | tuple[FileMountHost
|
||||
return isinstance(host_path, (str, Path)) and isinstance(mount_path, str)
|
||||
|
||||
|
||||
def _normalize_file_mount_input(file_mount: FileMountInput) -> _StoredFileMount:
|
||||
def _normalize_file_mount_input(file_mount: FileMountInput) -> FileMount:
|
||||
host_path: FileMountHostPath
|
||||
mount_path: str
|
||||
if isinstance(file_mount, str):
|
||||
@@ -176,7 +205,7 @@ def _normalize_file_mount_input(file_mount: FileMountInput) -> _StoredFileMount:
|
||||
host_path = file_mount[0]
|
||||
mount_path = file_mount[1]
|
||||
|
||||
return _StoredFileMount(
|
||||
return FileMount(
|
||||
host_path=_resolve_existing_path(host_path),
|
||||
mount_path=_normalize_mount_path(mount_path),
|
||||
)
|
||||
@@ -445,18 +474,13 @@ def _build_execution_contents(
|
||||
|
||||
|
||||
def _make_sandbox_callback(tool_obj: FunctionTool) -> Callable[..., Any]:
|
||||
sandbox_tool = copy.copy(tool_obj)
|
||||
# Auto-assign a passthrough parser so the raw return value round-trips through
|
||||
# `ast.literal_eval` in the sandbox callback below. User-supplied parsers are
|
||||
# left in place so callers can customize how results are exposed to the guest.
|
||||
if sandbox_tool.result_parser is None:
|
||||
sandbox_tool.result_parser = _passthrough_result_parser
|
||||
sandbox_tool = copy(tool_obj)
|
||||
|
||||
def _callback(**kwargs: Any) -> Any:
|
||||
async def _invoke() -> list[Content]:
|
||||
return await sandbox_tool.invoke(arguments=kwargs)
|
||||
async def _invoke() -> Any:
|
||||
return await sandbox_tool.invoke(arguments=kwargs, skip_parsing=True)
|
||||
|
||||
# FunctionTool.invoke() is always async. The real Hyperlight backend invokes
|
||||
# FunctionTool.invoke() is async. The real Hyperlight backend invokes
|
||||
# registered callbacks synchronously via FFI, so this must be a sync function.
|
||||
# We run the async call on a dedicated thread to avoid conflicts with any
|
||||
# event loop that may be running on the current thread.
|
||||
@@ -474,22 +498,11 @@ def _make_sandbox_callback(tool_obj: FunctionTool) -> Callable[..., Any]:
|
||||
worker.join()
|
||||
if error_box:
|
||||
raise error_box[0]
|
||||
contents: list[Content] = result_box[0]
|
||||
|
||||
values: list[Any] = []
|
||||
for content in contents:
|
||||
if content.type == "text" and content.text is not None:
|
||||
try:
|
||||
values.append(ast.literal_eval(content.text))
|
||||
except (SyntaxError, ValueError):
|
||||
values.append(content.text)
|
||||
continue
|
||||
|
||||
values.append(content.to_dict())
|
||||
|
||||
if len(values) == 1:
|
||||
return values[0]
|
||||
return values
|
||||
# Return the raw value. The Hyperlight FFI marshals primitives (dict, list,
|
||||
# str, int, float, bool, None) natively into the guest, and falls back to
|
||||
# repr()/str() for unsupported types — so the guest receives real Python
|
||||
# objects without a lossy host-side serialization round-trip.
|
||||
return result_box[0]
|
||||
|
||||
return _callback
|
||||
|
||||
@@ -509,7 +522,7 @@ def _clear_directory(output_dir: TemporaryDirectory[str] | None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _SandboxRegistry:
|
||||
class _SandboxRegistry(SandboxRuntime):
|
||||
def __init__(self) -> None:
|
||||
self._entries: dict[tuple[Any, ...], _SandboxEntry] = {}
|
||||
self._entries_lock = threading.RLock()
|
||||
@@ -517,28 +530,54 @@ class _SandboxRegistry:
|
||||
def execute(self, *, config: _RunConfig, code: str) -> list[Content]:
    """Execute code in a cached sandbox matching the given config.

    Entries are keyed by ``config.cache_key()``. All operations against the underlying
    sandbox object are routed through the entry's dedicated single-threaded worker, which
    both serializes concurrent callers and satisfies the PyO3 ``unsendable`` invariant
    that the sandbox can only be touched from the thread that created it.

    Concurrent calls with the same key therefore never race, but they share the same
    sandbox instance. For true parallel execution, use distinct provider instances or
    configs that produce different cache keys.
    """
    entry = self._get_or_create_entry(config)
    # Blocks until the worker thread has finished running ``code``.
    return entry.worker.run(self._run_on_worker, entry, code)
|
||||
|
||||
@staticmethod
def _run_on_worker(entry: _SandboxEntry, code: str) -> list[Content]:
    """Run ``code`` in ``entry``'s sandbox; intended to execute on the entry's worker thread.

    The sandbox is first restored to the entry's snapshot and the output directory is
    cleared, so each run starts from the same state. The raw run result is then converted
    into ``Content`` items.
    """
    sandbox = entry.sandbox
    output_dir = entry.output_dir

    # Reset guest state and drop output files left behind by the previous run.
    sandbox.restore(entry.snapshot)
    _clear_directory(output_dir)

    run_result = sandbox.run(code=code)
    return _build_execution_contents(result=run_result, sandbox=sandbox, output_dir=output_dir, code=code)
|
||||
|
||||
def _get_or_create_entry(self, config: _RunConfig) -> _SandboxEntry:
|
||||
cache_key = config.cache_key()
|
||||
with self._entries_lock:
|
||||
entry = self._entries.get(cache_key)
|
||||
if entry is None:
|
||||
entry = self._create_entry(config)
|
||||
self._entries[cache_key] = entry
|
||||
return entry
|
||||
|
||||
with entry.lock:
|
||||
entry.sandbox.restore(entry.snapshot)
|
||||
_clear_directory(entry.output_dir)
|
||||
result = entry.sandbox.run(code=code)
|
||||
return _build_execution_contents(
|
||||
result=result,
|
||||
sandbox=entry.sandbox,
|
||||
output_dir=entry.output_dir,
|
||||
code=code,
|
||||
)
|
||||
def close(self) -> None:
    """Tear down every cached entry: close its sandbox, stop its worker, remove its temp dirs.

    Calling this more than once is harmless because the entry table is emptied on the first
    call. A sandbox ``close`` (or, failing that, ``shutdown``) hook is executed through the
    entry's own worker so the PyO3 ``unsendable`` invariant is respected.
    """
    # Detach all entries under the lock; the actual teardown happens afterwards.
    with self._entries_lock:
        detached = list(self._entries.values())
        self._entries.clear()

    for entry in detached:
        sandbox = entry.sandbox
        hook = getattr(sandbox, "close", None) or getattr(sandbox, "shutdown", None)
        if callable(hook):
            # Best effort: a failing hook must not prevent the rest of the cleanup.
            with suppress(Exception):
                entry.worker.run(hook)

        entry.worker.shutdown()

        for directory in (entry.input_dir, entry.output_dir):
            if directory is None:
                continue
            with suppress(Exception):
                directory.cleanup()
|
||||
|
||||
def _create_entry(self, config: _RunConfig) -> _SandboxEntry:
|
||||
input_dir_handle = TemporaryDirectory() if config.filesystem_enabled else None
|
||||
@@ -578,26 +617,37 @@ class _SandboxRegistry:
|
||||
methods=list(allowed_domain.methods) if allowed_domain.methods is not None else None,
|
||||
)
|
||||
|
||||
sandbox = _create_sandbox()
|
||||
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=False)
|
||||
worker = _SandboxWorker()
|
||||
|
||||
def _build_sandbox() -> tuple[Any, Any]:
|
||||
sandbox = _create_sandbox()
|
||||
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=False)
|
||||
|
||||
try:
|
||||
sandbox.run("None")
|
||||
except RuntimeError as exc:
|
||||
if not _should_retry_allowed_domain_registration(error=exc, allowed_domains=config.allowed_domains):
|
||||
raise
|
||||
|
||||
sandbox = _create_sandbox()
|
||||
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=True)
|
||||
sandbox.run("None")
|
||||
|
||||
snapshot = sandbox.snapshot()
|
||||
return sandbox, snapshot
|
||||
|
||||
try:
|
||||
sandbox.run("None")
|
||||
except RuntimeError as exc:
|
||||
if not _should_retry_allowed_domain_registration(error=exc, allowed_domains=config.allowed_domains):
|
||||
raise
|
||||
sandbox, snapshot = worker.run(_build_sandbox)
|
||||
except BaseException:
|
||||
worker.shutdown()
|
||||
raise
|
||||
|
||||
sandbox = _create_sandbox()
|
||||
_configure_sandbox(sandbox=sandbox, expand_missing_scheme=True)
|
||||
sandbox.run("None")
|
||||
|
||||
snapshot = sandbox.snapshot()
|
||||
return _SandboxEntry(
|
||||
sandbox=sandbox,
|
||||
snapshot=snapshot,
|
||||
input_dir=input_dir_handle,
|
||||
output_dir=output_dir_handle,
|
||||
lock=threading.RLock(),
|
||||
worker=worker,
|
||||
)
|
||||
|
||||
|
||||
@@ -619,10 +669,10 @@ class HyperlightExecuteCodeTool(FunctionTool):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name="execute_code",
|
||||
description=EXECUTE_CODE_INPUT_DESCRIPTION,
|
||||
description=EXECUTE_CODE_TOOL_DESCRIPTION,
|
||||
approval_mode="never_require",
|
||||
func=self._run_code,
|
||||
input_model=_ExecuteCodeInput,
|
||||
input_model=EXECUTE_CODE_INPUT_SCHEMA,
|
||||
)
|
||||
self._state_lock = threading.RLock()
|
||||
self._registry = _registry or _SandboxRegistry()
|
||||
@@ -632,7 +682,7 @@ class HyperlightExecuteCodeTool(FunctionTool):
|
||||
self._module: str | None = module
|
||||
self._module_path: str | None = module_path
|
||||
self._managed_tools: list[FunctionTool] = []
|
||||
self._file_mounts: dict[str, _StoredFileMount] = {}
|
||||
self._file_mounts: dict[str, FileMount] = {}
|
||||
self._allowed_domains: dict[str, AllowedDomain] = {}
|
||||
|
||||
if tools is not None:
|
||||
@@ -648,7 +698,7 @@ class HyperlightExecuteCodeTool(FunctionTool):
|
||||
def description(self) -> str:
|
||||
state_lock = getattr(self, "_state_lock", None)
|
||||
if state_lock is None:
|
||||
return str(self.__dict__.get("description", EXECUTE_CODE_INPUT_DESCRIPTION))
|
||||
return str(self.__dict__.get("description", EXECUTE_CODE_TOOL_DESCRIPTION))
|
||||
|
||||
with state_lock:
|
||||
allowed_domains = sorted(self._allowed_domains.values(), key=lambda value: value.target)
|
||||
@@ -841,9 +891,9 @@ class HyperlightExecuteCodeTool(FunctionTool):
|
||||
workspace_signature = _path_tree_signature(workspace_root) if workspace_root is not None else ()
|
||||
normalized_mounts = tuple(
|
||||
_NormalizedFileMount(
|
||||
host_path=mount.host_path,
|
||||
host_path=Path(mount.host_path),
|
||||
mount_path=mount.mount_path,
|
||||
path_signature=_path_tree_signature(mount.host_path),
|
||||
path_signature=_path_tree_signature(Path(mount.host_path)),
|
||||
)
|
||||
for mount in stored_mounts
|
||||
)
|
||||
|
||||
@@ -937,3 +937,191 @@ async def test_run_code_does_not_block_event_loop() -> None:
|
||||
|
||||
assert concurrent_ran, "Event loop was blocked during sandbox execution"
|
||||
assert result[0].type == "text"
|
||||
|
||||
|
||||
class _ThreadAffinityFakeSandbox(_FakeSandbox):
    """Fake sandbox that tracks which OS thread each sandbox method is invoked from.

    It models the PyO3 ``unsendable`` invariant of ``hyperlight_sandbox.WasmSandbox``: the
    thread that constructs an instance becomes its owner, and any later ``register_tool``,
    ``allow_domain``, ``run``, ``snapshot`` or ``restore`` call made from a different thread
    is appended to ``affinity_failures`` so tests can assert that the list stays empty.
    """

    # Shared across all instances; tests clear it before use.
    affinity_failures: list[str] = []

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._owner_thread = threading.get_ident()
        self.thread_ids: set[int] = {self._owner_thread}

    def _record(self, method: str) -> None:
        """Note the calling thread and log a failure if it is not the owner thread."""
        ident = threading.get_ident()
        self.thread_ids.add(ident)
        if ident == self._owner_thread:
            return
        _ThreadAffinityFakeSandbox.affinity_failures.append(
            f"{method} called from thread {ident}, expected {self._owner_thread}"
        )

    def register_tool(self, name_or_tool: Any, callback: Any | None = None) -> None:
        self._record("register_tool")
        super().register_tool(name_or_tool, callback)

    def allow_domain(self, target: str, methods: list[str] | None = None) -> None:
        self._record("allow_domain")
        super().allow_domain(target, methods)

    def run(self, code: str) -> _FakeResult:
        self._record("run")
        return super().run(code)

    def snapshot(self) -> str:
        self._record("snapshot")
        return super().snapshot()

    def restore(self, snapshot: Any) -> None:
        self._record("restore")
        super().restore(snapshot)
|
||||
|
||||
|
||||
async def test_sandbox_calls_are_pinned_to_owning_worker_thread(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Regression: WasmSandbox is unsendable, so concurrent invocations must all reach it on one thread."""
    _ThreadAffinityFakeSandbox.instances.clear()
    _ThreadAffinityFakeSandbox.affinity_failures.clear()
    monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ThreadAffinityFakeSandbox)

    execute_code = HyperlightExecuteCodeTool()

    # Fire several invocations at once. They are dispatched from whatever threads the default
    # executor hands out, which used to make PyO3 panic when a thread other than the creator
    # touched the cached sandbox.
    pending = [execute_code.invoke(arguments={"code": "None"}) for _ in range(8)]
    outcomes = await asyncio.gather(*pending)
    for outcome in outcomes:
        assert outcome[0].type == "text"

    assert _ThreadAffinityFakeSandbox.affinity_failures == []
    assert len(_ThreadAffinityFakeSandbox.instances) == 1

    # Every sandbox-touching call must have run on the single owning thread, and that thread
    # must not be the one running this test.
    sandbox = _ThreadAffinityFakeSandbox.instances[0]
    owner = sandbox._owner_thread
    assert sandbox.thread_ids == {owner}
    assert owner != threading.get_ident()
|
||||
|
||||
|
||||
async def test_sandbox_owner_thread_persists_across_dispatch_threads(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Back-to-back calls reuse one sandbox and never touch it from a non-owner thread."""
    _ThreadAffinityFakeSandbox.instances.clear()
    _ThreadAffinityFakeSandbox.affinity_failures.clear()
    monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ThreadAffinityFakeSandbox)

    execute_code = HyperlightExecuteCodeTool()

    remaining = 5
    while remaining:
        outcome = await execute_code.invoke(arguments={"code": "None"})
        assert outcome[0].type == "text"
        remaining -= 1

    assert _ThreadAffinityFakeSandbox.affinity_failures == []
    assert len(_ThreadAffinityFakeSandbox.instances) == 1
|
||||
|
||||
|
||||
def test_sandbox_registry_close_shuts_down_workers(monkeypatch: pytest.MonkeyPatch) -> None:
    """close() must empty the registry and shut down each entry's worker executor."""
    _FakeSandbox.instances.clear()
    monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _FakeSandbox)

    registry = execute_code_module._SandboxRegistry()
    execute_code = HyperlightExecuteCodeTool(_registry=registry)
    asyncio.run(execute_code.invoke(arguments={"code": "None"}))

    cached = list(registry._entries.values())
    assert len(cached) == 1
    worker = cached[0].worker

    registry.close()
    assert registry._entries == {}

    # A shut-down executor rejects new work, which proves the worker was really torn down.
    with pytest.raises(RuntimeError):
        worker.submit(lambda: None)
|
||||
|
||||
|
||||
def test_sandbox_registry_close_releases_per_entry_resources(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """close() must call the sandbox's close hook once and delete both temp directories."""
    close_calls: list[int] = []

    class _ClosableFakeSandbox(_FakeSandbox):
        def close(self) -> None:
            close_calls.append(1)

    _FakeSandbox.instances.clear()
    monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ClosableFakeSandbox)

    workspace = tmp_path / "workspace"
    workspace.mkdir()
    registry = execute_code_module._SandboxRegistry()
    execute_code = HyperlightExecuteCodeTool(workspace_root=workspace, _registry=registry)
    asyncio.run(execute_code.invoke(arguments={"code": "None"}))

    cached = list(registry._entries.values())
    assert len(cached) == 1
    entry = cached[0]
    assert entry.input_dir is not None and entry.output_dir is not None

    # Capture the paths up front; the directory handles are cleaned up by close().
    input_path = Path(entry.input_dir.name)
    output_path = Path(entry.output_dir.name)
    assert input_path.exists() and output_path.exists()

    registry.close()

    assert close_calls == [1]
    assert not input_path.exists()
    assert not output_path.exists()
|
||||
|
||||
|
||||
async def test_make_sandbox_callback_returns_native_dict() -> None:
    """A dict returned by a host tool reaches the caller as a real dict, not a repr string."""

    @tool
    def get_weather(city: str) -> dict[str, Any]:
        """Get weather."""
        return {"city": city, "temp_c": 21.5}

    callback = execute_code_module._make_sandbox_callback(get_weather)
    payload = callback(city="Seattle")

    assert isinstance(payload, dict)
    assert payload == {"city": "Seattle", "temp_c": 21.5}
|
||||
|
||||
|
||||
async def test_make_sandbox_callback_bypasses_user_result_parser() -> None:
    """Documented behavior change: a tool's result_parser is not applied on the sandbox path."""
    parser_calls: list[Any] = []

    def parser(value: Any) -> str:
        parser_calls.append(value)
        return "PARSED"

    @tool(result_parser=parser)
    def make_payload() -> dict[str, int]:
        """Returns a dict."""
        return {"a": 1, "b": 2}

    callback = execute_code_module._make_sandbox_callback(make_payload)
    payload = callback()

    # The raw dict comes back and the parser was never consulted.
    assert payload == {"a": 1, "b": 2}
    assert parser_calls == [], "result_parser must not run on the sandbox path"
|
||||
|
||||
|
||||
async def test_make_sandbox_callback_propagates_exceptions() -> None:
    """An exception raised inside the host tool surfaces from the sandbox callback."""

    @tool
    def boom(x: int) -> int:
        """Always fails."""
        raise RuntimeError("nope")

    callback = execute_code_module._make_sandbox_callback(boom)

    with pytest.raises(RuntimeError, match="nope"):
        callback(x=1)
|
||||
|
||||
Reference in New Issue
Block a user