mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Restrict persisted checkpoint deserialization by default (#4941)
* Harden Python checkpoint persistence defaults

  Add RestrictedUnpickler to _checkpoint_encoding.py that limits which types may be instantiated during pickle deserialization. By default FileCheckpointStorage now uses the restricted unpickler, allowing only:

  - Built-in Python value types (primitives, datetime, uuid, decimal, collections, etc.)
  - All agent_framework.* internal types
  - Additional types specified via the new allowed_checkpoint_types parameter on FileCheckpointStorage

  This narrows the default type surface area for persisted checkpoints while keeping framework-owned scenarios working without extra configuration. Developers can extend the allowed set by passing "module:qualname" strings to allowed_checkpoint_types. The decode_checkpoint_value function retains backward-compatible unrestricted behavior when called without the new allowed_types kwarg.

  Fixes #4894

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: resolve mypy no-any-return error in checkpoint encoding

  Add explicit type annotation for super().find_class() return value to satisfy mypy's no-any-return check.

  Fixes #4894

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Simplify find_class return in _RestrictedUnpickler (#4894)

  Remove unnecessary intermediate variable and apply # noqa: S301 # nosec directly on the super().find_class() call, matching the established pattern used on the pickle.loads() call in the same file.

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #4894: Python: Harden Python checkpoint persistence defaults

* Restore # noqa: S301 on line 102 of _checkpoint_encoding.py (#4894)

  The review feedback correctly identified that removing the # noqa: S301 suppression from the find_class return statement would cause a ruff S301 lint failure, since the project enables bandit ("S") rules. This restores consistency with lines 82 and 246 in the same file.

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #4894: Python: Harden Python checkpoint persistence defaults

* Address PR review comments on checkpoint encoding (#4894)

  - Move module docstring to proper position after __future__ import
  - Fix find_class return type annotation to type[Any]
  - Add missing # noqa: S301 pragma on find_class return
  - Improve error message to reference both allowed_types param and FileCheckpointStorage.allowed_checkpoint_types
  - Add -> None return annotation to FileCheckpointStorage.__init__
  - Replace tempfile.mktemp with TemporaryDirectory in test
  - Replace contextlib.suppress with pytest.raises for precise assertion
  - Remove unused contextlib import

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address PR #4941 review comments: fix docstring position and return type

  - Move module docstring before 'from __future__' import so it populates __doc__ (comment #4)
  - Change find_class return annotation from type[Any] to type to avoid misleading callers about non-type returns like copyreg._reconstructor (comment #2)

  Comments #1, #3, #5, #6, #7, #8 were already addressed in the current code.

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #4894: review comment fixes

* fix: use pickle.UnpicklingError in RestrictedUnpickler and improve docstring (#4894)

  - Change _RestrictedUnpickler.find_class to raise pickle.UnpicklingError instead of WorkflowCheckpointException, since it is a pickle-level concern that gets wrapped by the caller in _base64_to_unpickle.
  - Remove now-unnecessary WorkflowCheckpointException re-raise in _base64_to_unpickle (pickle.UnpicklingError is caught by the generic except Exception handler and wrapped).
  - Expand decode_checkpoint_value docstring to show a concrete example of the module:qualname format with a user-defined class.
  - Add regression test verifying find_class raises pickle.UnpicklingError.

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: address PR #4941 review comments for checkpoint encoding

  - Comment 1 (line 103): Already resolved in prior commit; _RestrictedUnpickler now raises pickle.UnpicklingError instead of WorkflowCheckpointException.
  - Comment 2 (line 140): Add concrete usage examples to decode_checkpoint_value docstring showing both direct allowed_types usage and FileCheckpointStorage allowed_checkpoint_types usage. Rename 'SafeState' to 'MyState' across all docstrings for consistency, making it clear this is a user-defined class name.

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: replace deprecated 'builtin' repo with pre-commit-hooks in pre-commit config

  pre-commit 4.x no longer supports 'repo: builtin'. Merge those hooks into the existing pre-commit-hooks repo entry.

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* style: apply pyupgrade formatting to docstring example

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: resolve pre-commit hook paths for monorepo git root

  The poe-check and bandit hooks referenced paths relative to python/ but pre-commit runs hooks from the git root (monorepo root). Fix poe-check entry to cd into python/ first, and update bandit config path to python/pyproject.toml.

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix pre-commit config paths for prek --cd python execution

  Revert bandit config path from 'python/pyproject.toml' to 'pyproject.toml' and poe-check entry from explicit 'cd python' wrapper to direct invocation, since prek --cd python already sets the working directory to python/. Also apply ruff formatting fixes to cosmos checkpoint storage files.

  Fixes #4894

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: add builtins:getattr to checkpoint deserialization allowlist

  Pickle uses builtins:getattr to reconstruct enum members (e.g., WorkflowMessage.type which is a MessageType enum). Without it in the allowlist, checkpoint roundtrip tests fail with WorkflowCheckpointException.

  Fixes #4894

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #4894: review comment fixes

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
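The allow-list idea described in the commit message can be illustrated with the standard library alone. The block below is a minimal sketch, not the framework's implementation: `AllowListUnpickler`, `ALLOWED`, and `restricted_loads` are hypothetical names, and the allow-list is deliberately tiny.

```python
import io
import os
import pickle
from datetime import datetime

# Hypothetical allow-list in "module:qualname" format (illustration only).
ALLOWED = frozenset({"datetime:datetime"})


class AllowListUnpickler(pickle.Unpickler):
    """Refuse to resolve any global that is not on the allow-list."""

    def find_class(self, module, name):
        key = f"{module}:{name}"
        if key in ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked: {key}")


def restricted_loads(data: bytes):
    """Unpickle `data`, resolving only allow-listed globals."""
    return AllowListUnpickler(io.BytesIO(data)).load()


# An allowed value round-trips unchanged.
stamp = datetime(2025, 1, 1)
assert restricted_loads(pickle.dumps(stamp)) == stamp

# A reference to a callable outside the allow-list is rejected when the
# unpickler tries to look it up, before anything is called.
try:
    restricted_loads(pickle.dumps(os.getpid))
except pickle.UnpicklingError as exc:
    print(exc)
```

Overriding `find_class` is the hook the `pickle` documentation describes for restricting globals; values that need no global lookup (ints, strings, lists, dicts) are unaffected by the allow-list.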
Committed by GitHub (Unverified)
Parent: 5e8fe0be1f
Commit: 4dbe696e0e
@@ -1,7 +1,8 @@
 fail_fast: true
 exclude: ^scripts/
 repos:
-  - repo: builtin
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v6.0.0
     hooks:
       - id: check-toml
         name: Check TOML files
@@ -34,9 +35,6 @@ repos:
       - id: no-commit-to-branch
         name: Protect main branch
         args: [--branch, main]
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v6.0.0
-    hooks:
       - id: check-ast
         name: Check Valid Python Samples
         types: ["python"]
@@ -244,14 +244,39 @@ class FileCheckpointStorage:
     is serialized using pickle and embedded as base64-encoded strings within the JSON. This allows
     for human-readable checkpoint files while preserving the ability to store complex Python objects.
 
-    SECURITY WARNING: Checkpoints use pickle for data serialization. Only load checkpoints
-    from trusted sources. Loading a malicious checkpoint file can execute arbitrary code.
+    By default, checkpoint deserialization is restricted to a built-in set of safe
+    Python types (primitives, datetime, uuid, ...) and all ``agent_framework``
+    internal types. To allow additional application-specific types, pass them via
+    the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
+
+    Example::
+
+        storage = FileCheckpointStorage(
+            "/tmp/checkpoints",
+            allowed_checkpoint_types=[
+                "my_app.models:MyState",
+            ],
+        )
     """
 
-    def __init__(self, storage_path: str | Path):
-        """Initialize the file storage."""
+    def __init__(
+        self,
+        storage_path: str | Path,
+        *,
+        allowed_checkpoint_types: list[str] | None = None,
+    ) -> None:
+        """Initialize the file storage.
+
+        Args:
+            storage_path: Directory path where checkpoint files will be stored.
+            allowed_checkpoint_types: Additional types (beyond the built-in safe set
+                and framework types) that are permitted during checkpoint
+                deserialization. Each entry should be a ``"module:qualname"``
+                string (e.g., ``"my_app.models:MyState"``).
+        """
         self.storage_path = Path(storage_path)
         self.storage_path.mkdir(parents=True, exist_ok=True)
+        self._allowed_types: frozenset[str] = frozenset(allowed_checkpoint_types or [])
         logger.info(f"Initialized file checkpoint storage at {self.storage_path}")
 
     def _validate_file_path(self, checkpoint_id: CheckpointID) -> Path:
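The `"module:qualname"` strings expected by `allowed_checkpoint_types` can be derived from a class rather than typed by hand. This is a small sketch under that assumption; `checkpoint_type_key` is a hypothetical helper, not part of the framework's public API.

```python
from datetime import datetime


def checkpoint_type_key(cls: type) -> str:
    """Build a "module:qualname" key for a class (hypothetical helper)."""
    return f"{cls.__module__}:{cls.__qualname__}"


class Outer:
    class Inner:
        pass


print(checkpoint_type_key(datetime))     # datetime:datetime
# For nested classes the qualified name keeps the dotted path ("Outer.Inner");
# the module part depends on where this snippet runs.
print(checkpoint_type_key(Outer.Inner))
```

Using `__qualname__` rather than `__name__` matters for nested classes, which is why the docstring above specifies the qualified name.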
@@ -327,7 +352,7 @@ class FileCheckpointStorage:
         from ._checkpoint_encoding import decode_checkpoint_value
 
         try:
-            decoded_checkpoint_dict = decode_checkpoint_value(encoded_checkpoint)
+            decoded_checkpoint_dict = decode_checkpoint_value(encoded_checkpoint, allowed_types=self._allowed_types)
         except WorkflowCheckpointException:
             raise
         checkpoint = WorkflowCheckpoint.from_dict(decoded_checkpoint_dict)
@@ -352,7 +377,9 @@ class FileCheckpointStorage:
                         encoded_checkpoint = json.load(f)
                     from ._checkpoint_encoding import decode_checkpoint_value
 
-                    decoded_checkpoint_dict = decode_checkpoint_value(encoded_checkpoint)
+                    decoded_checkpoint_dict = decode_checkpoint_value(
+                        encoded_checkpoint, allowed_types=self._allowed_types
+                    )
                     checkpoint = WorkflowCheckpoint.from_dict(decoded_checkpoint_dict)
                     if checkpoint.workflow_name == workflow_name:
                         checkpoints.append(checkpoint)
@@ -1,14 +1,5 @@
 # Copyright (c) Microsoft. All rights reserved.
 
-from __future__ import annotations
-
-import base64
-import logging
-import pickle  # nosec # noqa: S403
-from typing import Any
-
-from ..exceptions import WorkflowCheckpointException
-
 """Checkpoint encoding using JSON structure with pickle+base64 for arbitrary data.
 
 This hybrid approach provides:
@@ -16,10 +7,23 @@ This hybrid approach provides:
 - Full Python object fidelity via pickle for data values (non-JSON-native types)
 - Base64 encoding to embed binary pickle data in JSON strings
 
-SECURITY WARNING: Checkpoints use pickle for data serialization. Only load checkpoints
-from trusted sources. Loading a malicious checkpoint file can execute arbitrary code.
+When ``allowed_types`` is supplied to :func:`decode_checkpoint_value`, a
+``RestrictedUnpickler`` is used that limits which classes may be instantiated
+during deserialization. The default built-in safe set covers common Python
+value types (primitives, datetime, uuid, ...) and all ``agent_framework``
+internal types. Callers can extend the set by passing additional
+``"module:qualname"`` strings.
 """
 
+from __future__ import annotations
+
+import base64
+import io
+import logging
+import pickle  # nosec # noqa: S403
+from typing import Any
+
+from ..exceptions import WorkflowCheckpointException
 
 logger = logging.getLogger("agent_framework")
 
@@ -30,6 +34,82 @@ _TYPE_MARKER = "__type__"
 # Types that are natively JSON-serializable and don't need pickling
 _JSON_NATIVE_TYPES = (str, int, float, bool, type(None))
 
+# Module prefix for framework-internal types that are always allowed
+_FRAMEWORK_MODULE_PREFIX = "agent_framework."
+
+# Built-in types considered safe for checkpoint deserialization.
+# Each entry is a ``module:qualname`` string matching the format produced by
+# :func:`_type_to_key`. These are the classes for which pickle's
+# ``find_class`` will be called when unpickling common Python value types.
+_BUILTIN_ALLOWED_TYPE_KEYS: frozenset[str] = frozenset({
+    # builtins
+    "builtins:object",
+    "builtins:complex",
+    "builtins:range",
+    "builtins:slice",
+    "builtins:int",
+    "builtins:float",
+    "builtins:str",
+    "builtins:bytes",
+    "builtins:bytearray",
+    "builtins:bool",
+    "builtins:set",
+    "builtins:frozenset",
+    "builtins:list",
+    "builtins:dict",
+    "builtins:tuple",
+    "builtins:type",
+    # getattr is used by pickle to reconstruct enum members
+    "builtins:getattr",
+    # copyreg helpers used by pickle for object reconstruction
+    "copyreg:_reconstructor",
+    # datetime
+    "datetime:datetime",
+    "datetime:date",
+    "datetime:time",
+    "datetime:timedelta",
+    "datetime:timezone",
+    # uuid
+    "uuid:UUID",
+    # decimal
+    "decimal:Decimal",
+    # collections
+    "collections:OrderedDict",
+    "collections:defaultdict",
+    "collections:deque",
+})
+
+
+class _RestrictedUnpickler(pickle.Unpickler):  # noqa: S301
+    """Unpickler that restricts which classes may be instantiated.
+
+    Only classes whose ``module:qualname`` key appears in the combined allow
+    set (built-in safe types + framework types + caller-specified extras) are
+    permitted. All other classes raise :class:`pickle.UnpicklingError`.
+    """
+
+    def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None:
+        super().__init__(io.BytesIO(data))
+        self._allowed_types = allowed_types
+
+    def find_class(self, module: str, name: str) -> type:
+        type_key = f"{module}:{name}"
+
+        if (
+            type_key in _BUILTIN_ALLOWED_TYPE_KEYS
+            or type_key in self._allowed_types
+            or module.startswith(_FRAMEWORK_MODULE_PREFIX)
+        ):
+            return super().find_class(module, name)  # type: ignore[no-any-return] # nosec
+
+        raise pickle.UnpicklingError(
+            f"Checkpoint deserialization blocked for type '{type_key}'. "
+            f"To allow this type, either include its 'module:qualname' key in the "
+            f"'allowed_types' set passed to 'decode_checkpoint_value', or add it to "
+            f"'allowed_checkpoint_types' on your checkpoint storage "
+            f"(for example, 'FileCheckpointStorage.allowed_checkpoint_types')."
+        )
+
 
 def encode_checkpoint_value(value: Any) -> Any:
     """Encode a Python value for checkpoint storage.
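To work out which keys a given value needs on an allow-list like the one above, it can help to record what `find_class` is asked for. The following is a hedged sketch using only the standard library; `RecordingUnpickler` and `referenced_type_keys` are hypothetical names, and because the sketch performs a normal, unrestricted load it should only be run on values you created yourself.

```python
import io
import pickle
from datetime import datetime


class RecordingUnpickler(pickle.Unpickler):
    """Record every global the pickle stream asks for, then resolve it normally.

    This is an unrestricted load, so use it only on data you produced.
    """

    def __init__(self, data: bytes) -> None:
        super().__init__(io.BytesIO(data))
        self.seen = set()

    def find_class(self, module, name):
        self.seen.add(f"{module}:{name}")
        return super().find_class(module, name)


def referenced_type_keys(value) -> set:
    """Pickle a trusted value and return the "module:name" keys it references."""
    unpickler = RecordingUnpickler(pickle.dumps(value))
    unpickler.load()
    return unpickler.seen


print(sorted(referenced_type_keys({"when": datetime(2025, 1, 1), "n": 3})))
```

Containers and primitives such as `dict`, `list`, `int`, and `str` have dedicated pickle opcodes and never reach `find_class`, so only the remaining types show up in the recorded set. The exact keys can vary between Python versions and pickle protocols.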
@@ -48,29 +128,51 @@ def encode_checkpoint_value(value: Any) -> Any:
     return _encode(value)
 
 
-def decode_checkpoint_value(value: Any) -> Any:
+def decode_checkpoint_value(value: Any, *, allowed_types: frozenset[str] | None = None) -> Any:
     """Decode a value from checkpoint storage.
 
     Reverses the encoding performed by encode_checkpoint_value.
     Pickled values (identified by _PICKLE_MARKER) are decoded and unpickled.
 
-    WARNING: Only call this with trusted data. Pickle can execute
-    arbitrary code during deserialization. The post-unpickle type verification
-    detects accidental corruption or type mismatches, but cannot prevent
-    arbitrary code execution from malicious pickle payloads.
-
     Args:
         value: A JSON-deserialized value from checkpoint storage.
+        allowed_types: If not ``None``, restrict pickle deserialization to the
+            built-in safe set, framework types, and the types listed here.
+            Each entry should use ``"module:qualname"`` format — that is, the
+            dotted module path followed by a colon and the class
+            ``__qualname__``. For example, given a user-defined class::
+
+                # my_app/models.py
+                class MyState: ...
+
+            the corresponding entry would be ``"my_app.models:MyState"``::
+
+                decode_checkpoint_value(
+                    data,
+                    allowed_types=frozenset({"my_app.models:MyState"}),
+                )
+
+            When using :class:`FileCheckpointStorage`, pass the same strings
+            via ``allowed_checkpoint_types``::
+
+                storage = FileCheckpointStorage(
+                    "/tmp/checkpoints",
+                    allowed_checkpoint_types=["my_app.models:MyState"],
+                )
+
+            If ``None``, no restriction is applied (backward-compatible
+            behavior).
 
     Returns:
         The original Python value.
 
     Raises:
         WorkflowCheckpointException: If the unpickled object's type doesn't match
-            the recorded type, indicating corruption, or if the base64/pickle
-            data is malformed.
+            the recorded type, indicating corruption, if the base64/pickle
+            data is malformed, or if a disallowed type is encountered during
+            restricted deserialization.
     """
-    return _decode(value)
+    return _decode(value, allowed_types=allowed_types)
 
 
 def _encode(value: Any) -> Any:
@@ -94,7 +196,7 @@ def _encode(value: Any) -> Any:
     }
 
 
-def _decode(value: Any) -> Any:
+def _decode(value: Any, *, allowed_types: frozenset[str] | None = None) -> Any:
     """Recursively decode a value from JSON storage."""
     # JSON-native types pass through
     if isinstance(value, _JSON_NATIVE_TYPES):
@@ -104,16 +206,16 @@ def _decode(value: Any) -> Any:
     if isinstance(value, dict):
         # Pickled value: decode, unpickle, and verify type
         if _PICKLE_MARKER in value and _TYPE_MARKER in value:
-            obj = _base64_to_unpickle(value[_PICKLE_MARKER])  # type: ignore
+            obj = _base64_to_unpickle(value[_PICKLE_MARKER], allowed_types=allowed_types)  # type: ignore
             _verify_type(obj, value.get(_TYPE_MARKER))  # type: ignore
             return obj
 
         # Regular dict: decode values recursively
-        return {k: _decode(v) for k, v in value.items()}  # type: ignore
+        return {k: _decode(v, allowed_types=allowed_types) for k, v in value.items()}  # type: ignore
 
     # Handle encoded lists
     if isinstance(value, list):
-        return [_decode(item) for item in value]  # type: ignore
+        return [_decode(item, allowed_types=allowed_types) for item in value]  # type: ignore
 
     return value
 
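The recursive decode above pairs each pickled payload with a recorded type key and checks the two after unpickling. The sketch below reproduces that envelope idea with the standard library only. The marker names `PICKLE_KEY` and `TYPE_KEY` and the `encode`/`decode` functions are assumptions made for illustration (the encoding side is not shown in this diff), and the decode path uses plain `pickle.loads`, so it is for trusted input only.

```python
import base64
import json
import pickle
from decimal import Decimal

# Hypothetical marker names for this sketch; the framework defines its own.
PICKLE_KEY = "pickled_b64"
TYPE_KEY = "type_key"
JSON_NATIVE = (str, int, float, bool, type(None))


def encode(value):
    """Keep JSON-native data as-is; wrap everything else in a pickle envelope."""
    if isinstance(value, JSON_NATIVE):
        return value
    if isinstance(value, list):
        return [encode(v) for v in value]
    if isinstance(value, dict) and all(isinstance(k, str) for k in value):
        return {k: encode(v) for k, v in value.items()}
    cls = type(value)
    return {
        PICKLE_KEY: base64.b64encode(pickle.dumps(value)).decode("ascii"),
        TYPE_KEY: f"{cls.__module__}:{cls.__qualname__}",
    }


def decode(value):
    """Reverse encode(); trusted input only, since this uses plain pickle.loads."""
    if isinstance(value, list):
        return [decode(v) for v in value]
    if isinstance(value, dict):
        if PICKLE_KEY in value and TYPE_KEY in value:
            obj = pickle.loads(base64.b64decode(value[PICKLE_KEY]))
            cls = type(obj)
            # The check runs after unpickling: it catches corruption or
            # mismatches, but it is not a defence against hostile payloads.
            if f"{cls.__module__}:{cls.__qualname__}" != value[TYPE_KEY]:
                raise ValueError("type mismatch")
            return obj
        return {k: decode(v) for k, v in value.items()}
    return value


state = {"total": Decimal("9.99"), "tags": ["a", "b"], "count": 2}
text = json.dumps(encode(state))  # the envelope is ordinary JSON
assert decode(json.loads(text)) == state
```

Because the type check happens only after the object has been rebuilt, restricting which globals may be resolved during unpickling is a separate concern from this verification step.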
@@ -148,15 +250,23 @@ def _pickle_to_base64(value: Any) -> str:
     return base64.b64encode(pickled).decode("ascii")
 
 
-def _base64_to_unpickle(encoded: str) -> Any:
+def _base64_to_unpickle(encoded: str, *, allowed_types: frozenset[str] | None = None) -> Any:
     """Decode base64 string and unpickle.
 
+    Args:
+        encoded: Base64-encoded pickle data.
+        allowed_types: If not ``None``, use restricted unpickling that only
+            permits built-in safe types, framework types, and the specified
+            extra types.
+
     Raises:
-        WorkflowCheckpointException: If the base64 data is corrupted or the pickle
-            format is incompatible.
+        WorkflowCheckpointException: If the base64 data is corrupted, the pickle
+            format is incompatible, or a disallowed type is encountered.
     """
     try:
         pickled = base64.b64decode(encoded.encode("ascii"))
+        if allowed_types is not None:
+            return _RestrictedUnpickler(pickled, allowed_types).load()
         return pickle.loads(pickled)  # nosec # noqa: S301
     except Exception as exc:
         raise WorkflowCheckpointException(f"Failed to decode pickled checkpoint data: {exc}") from exc
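Two details of the hunk above are easy to miss: `allowed_types=None` means unrestricted, while an empty `frozenset()` still selects the restricted path, and the `pickle.UnpicklingError` raised inside `find_class` reaches callers wrapped in the framework exception. The sketch below demonstrates both with standard-library code. `CheckpointDecodeError`, `_Restricted`, and `b64_unpickle` are hypothetical stand-ins, and unlike the framework this sketch has no built-in safe set.

```python
import base64
import io
import pickle
from fractions import Fraction


class CheckpointDecodeError(Exception):
    """Hypothetical stand-in for the framework's checkpoint exception."""


class _Restricted(pickle.Unpickler):
    """Resolve only globals whose "module:name" key is in `allowed`."""

    def __init__(self, data, allowed):
        super().__init__(io.BytesIO(data))
        self._allowed = allowed

    def find_class(self, module, name):
        if f"{module}:{name}" in self._allowed:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked: {module}:{name}")


def b64_unpickle(encoded, *, allowed=None):
    try:
        raw = base64.b64decode(encoded.encode("ascii"))
        if allowed is not None:  # an empty frozenset still means "restricted"
            return _Restricted(raw, allowed).load()
        return pickle.loads(raw)  # unrestricted: trusted data only
    except Exception as exc:
        # "from exc" keeps the original error reachable via __cause__.
        raise CheckpointDecodeError(f"decode failed: {exc}") from exc


blob = base64.b64encode(pickle.dumps(Fraction(1, 3))).decode("ascii")

# None: unrestricted load.
assert b64_unpickle(blob) == Fraction(1, 3)
# Explicitly allowed type: restricted load succeeds.
assert b64_unpickle(blob, allowed=frozenset({"fractions:Fraction"})) == Fraction(1, 3)

try:
    b64_unpickle(blob, allowed=frozenset())  # restricted, nothing allowed in this sketch
except CheckpointDecodeError as exc:
    assert isinstance(exc.__cause__, pickle.UnpicklingError)
```

Callers that need to tell a blocked type apart from corrupt data can inspect `__cause__` on the wrapped exception.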
@@ -1048,7 +1048,10 @@ async def test_file_checkpoint_storage_roundtrip_datetime():
 async def test_file_checkpoint_storage_roundtrip_dataclass():
     """Test that dataclass objects roundtrip correctly via pickle encoding."""
     with tempfile.TemporaryDirectory() as temp_dir:
-        storage = FileCheckpointStorage(temp_dir)
+        storage = FileCheckpointStorage(
+            temp_dir,
+            allowed_checkpoint_types=["tests.workflow.test_checkpoint:_TestCustomData"],
+        )
 
         custom_obj = _TestCustomData(name="test", value=42, tags=["a", "b", "c"])
 
@@ -1238,7 +1241,10 @@ async def test_file_checkpoint_storage_roundtrip_messages_with_complex_data():
 async def test_file_checkpoint_storage_roundtrip_pending_request_info_events():
     """Test that pending_request_info_events with WorkflowEvent objects roundtrip correctly."""
     with tempfile.TemporaryDirectory() as temp_dir:
-        storage = FileCheckpointStorage(temp_dir)
+        storage = FileCheckpointStorage(
+            temp_dir,
+            allowed_checkpoint_types=["tests.workflow.test_checkpoint:_TestToolApprovalRequest"],
+        )
 
         # Create request_info events using the proper WorkflowEvent factory
         event1 = WorkflowEvent.request_info(
@@ -1300,7 +1306,13 @@ async def test_file_checkpoint_storage_roundtrip_pending_request_info_events():
 async def test_file_checkpoint_storage_roundtrip_full_checkpoint():
     """Test complete WorkflowCheckpoint roundtrip with all fields populated using proper types."""
     with tempfile.TemporaryDirectory() as temp_dir:
-        storage = FileCheckpointStorage(temp_dir)
+        storage = FileCheckpointStorage(
+            temp_dir,
+            allowed_checkpoint_types=[
+                "tests.workflow.test_checkpoint:_TestApprovalRequest",
+                "tests.workflow.test_checkpoint:_TestExecutorState",
+            ],
+        )
 
         # Create proper WorkflowMessage objects
         msg1 = WorkflowMessage(data="msg1", source_id="s", target_id="t")
@@ -0,0 +1,218 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+"""Tests for restricted checkpoint deserialization.
+
+These tests verify that persisted checkpoint loading uses a restricted
+unpickler by default:
+- Arbitrary callables are blocked during deserialization
+- __reduce__ payloads cannot execute code during deserialization
+- FileCheckpointStorage accepts allowed_checkpoint_types for extension
+- User-defined types are blocked unless explicitly allowed
+- Built-in safe types and framework types are always allowed
+"""
+
+import base64
+import os
+import pickle
+import tempfile
+from dataclasses import dataclass
+from datetime import datetime, timezone
+
+import pytest
+
+from agent_framework import WorkflowCheckpointException
+from agent_framework._workflows._checkpoint import FileCheckpointStorage
+from agent_framework._workflows._checkpoint_encoding import (
+    _PICKLE_MARKER,
+    _TYPE_MARKER,
+    decode_checkpoint_value,
+    encode_checkpoint_value,
+)
+
+
+class MaliciousPayload:
+    """A class whose __reduce__ executes code during unpickling."""
+
+    def __reduce__(self):
+        return (os.getpid, ())
+
+
+def test_restricted_decode_blocks_arbitrary_callable():
+    """Restricted decoding blocks arbitrary module-level callables."""
+    pickled = pickle.dumps(os.getpid, protocol=pickle.HIGHEST_PROTOCOL)
+    encoded_b64 = base64.b64encode(pickled).decode("ascii")
+
+    checkpoint_value = {
+        _PICKLE_MARKER: encoded_b64,
+        _TYPE_MARKER: "builtins:builtin_function_or_method",
+    }
+
+    with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
+        decode_checkpoint_value(checkpoint_value, allowed_types=frozenset())
+
+
+def test_restricted_decode_blocks_reduce_payload():
+    """__reduce__-based payloads are blocked before code can execute."""
+    payload = MaliciousPayload()
+    pickled = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL)
+    encoded_b64 = base64.b64encode(pickled).decode("ascii")
+
+    checkpoint_value = {
+        _PICKLE_MARKER: encoded_b64,
+        _TYPE_MARKER: f"{MaliciousPayload.__module__}:{MaliciousPayload.__qualname__}",
+    }
+
+    with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
+        decode_checkpoint_value(checkpoint_value, allowed_types=frozenset())
+
+
+def test_restricted_decode_prevents_code_execution():
+    """Restricted deserialization prevents __reduce__ code from running."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        marker_file = os.path.join(tmpdir, "checkpoint_test_marker")
+
+        payload_bytes = pickle.dumps(
+            type(
+                "Exploit",
+                (),
+                {
+                    "__reduce__": lambda self: (
+                        eval,
+                        (f"open({marker_file!r}, 'w').write('pwned')",),
+                    )
+                },
+            )(),
+            protocol=pickle.HIGHEST_PROTOCOL,
+        )
+        encoded_b64 = base64.b64encode(payload_bytes).decode("ascii")
+
+        checkpoint_value = {
+            _PICKLE_MARKER: encoded_b64,
+            _TYPE_MARKER: "builtins:int",
+        }
+        with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
+            decode_checkpoint_value(checkpoint_value, allowed_types=frozenset())
+
+        assert not os.path.exists(marker_file), (
+            "Restricted unpickler should have prevented code execution, but the marker file was created."
+        )
+
+
+def test_file_checkpoint_storage_accepts_allowed_types():
+    """FileCheckpointStorage.__init__ accepts allowed_checkpoint_types."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        storage = FileCheckpointStorage(
+            tmpdir,
+            allowed_checkpoint_types=["some.module:SomeType"],
+        )
+        assert storage is not None
+
+
+@dataclass
+class _AllowedTestState:
+    """Test dataclass that will be explicitly allowed."""
+
+    name: str
+    value: int
+
+
+def test_restricted_decode_blocks_unlisted_user_type():
+    """User-defined types are blocked when not in allowed_checkpoint_types."""
+    original = _AllowedTestState(name="test", value=42)
+    encoded = encode_checkpoint_value(original)
+
+    with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
+        decode_checkpoint_value(encoded, allowed_types=frozenset())
+
+
+def test_restricted_decode_allows_listed_user_type():
+    """User-defined types are allowed when listed in allowed_types."""
+    original = _AllowedTestState(name="test", value=42)
+    encoded = encode_checkpoint_value(original)
+
+    type_key = f"{_AllowedTestState.__module__}:{_AllowedTestState.__qualname__}"
+    decoded = decode_checkpoint_value(encoded, allowed_types=frozenset({type_key}))
+
+    assert isinstance(decoded, _AllowedTestState)
+    assert decoded.name == "test"
+    assert decoded.value == 42
+
+
+def test_restricted_decode_allows_builtin_safe_types():
+    """Built-in safe types (datetime, set, etc.) are always allowed."""
+    test_values = [
+        datetime(2025, 1, 1, tzinfo=timezone.utc),
+        {1, 2, 3},
+        frozenset({4, 5, 6}),
+        (1, "two", 3.0),
+        complex(1, 2),
+    ]
+    for original in test_values:
+        encoded = encode_checkpoint_value(original)
+        decoded = decode_checkpoint_value(encoded, allowed_types=frozenset())
+        assert decoded == original
+
+
+def test_unrestricted_decode_allows_arbitrary_types():
+    """Without allowed_types, decode_checkpoint_value remains unrestricted."""
+    original = _AllowedTestState(name="test", value=42)
+    encoded = encode_checkpoint_value(original)
+
+    decoded = decode_checkpoint_value(encoded)
+
+    assert isinstance(decoded, _AllowedTestState)
+    assert decoded.name == "test"
+
+
+async def test_file_storage_blocks_unlisted_user_type():
+    """FileCheckpointStorage blocks user types not in allowed_checkpoint_types."""
+    from agent_framework import WorkflowCheckpoint
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # Save with a storage that allows the type
+        type_key = f"{_AllowedTestState.__module__}:{_AllowedTestState.__qualname__}"
+        save_storage = FileCheckpointStorage(tmpdir, allowed_checkpoint_types=[type_key])
+
+        checkpoint = WorkflowCheckpoint(
+            workflow_name="test",
+            graph_signature_hash="hash",
+            state={"data": _AllowedTestState(name="test", value=1)},
+        )
+        await save_storage.save(checkpoint)
+
+        # Load with a storage that does NOT allow the type
+        load_storage = FileCheckpointStorage(tmpdir)
+        with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
+            await load_storage.load(checkpoint.checkpoint_id)
+
+
+async def test_file_storage_allows_listed_user_type():
+    """FileCheckpointStorage allows user types listed in allowed_checkpoint_types."""
+    from agent_framework import WorkflowCheckpoint
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        type_key = f"{_AllowedTestState.__module__}:{_AllowedTestState.__qualname__}"
+        storage = FileCheckpointStorage(tmpdir, allowed_checkpoint_types=[type_key])
+
+        checkpoint = WorkflowCheckpoint(
+            workflow_name="test",
+            graph_signature_hash="hash",
+            state={"data": _AllowedTestState(name="allowed", value=99)},
+        )
+        await storage.save(checkpoint)
+        loaded = await storage.load(checkpoint.checkpoint_id)
+
+        assert isinstance(loaded.state["data"], _AllowedTestState)
+        assert loaded.state["data"].name == "allowed"
+        assert loaded.state["data"].value == 99
+
+
+def test_restricted_unpickler_raises_pickle_error():
+    """_RestrictedUnpickler.find_class raises pickle.UnpicklingError, not a framework exception."""
+    from agent_framework._workflows._checkpoint_encoding import _RestrictedUnpickler
+
+    pickled = pickle.dumps(os.getpid, protocol=pickle.HIGHEST_PROTOCOL)
+
+    unpickler = _RestrictedUnpickler(pickled, frozenset())
+    with pytest.raises(pickle.UnpicklingError, match="deserialization blocked"):
+        unpickler.load()
@@ -130,7 +130,17 @@ async def test_checkpoint_with_pending_request_info_events():
 
     with tempfile.TemporaryDirectory() as temp_dir:
         # Use file-based storage to test full serialization
-        storage = FileCheckpointStorage(temp_dir)
+        storage = FileCheckpointStorage(
+            temp_dir,
+            allowed_checkpoint_types=[
+                "tests.workflow.test_request_info_and_response:UserApprovalRequest",
+                "tests.workflow.test_request_info_and_response:CalculationRequest",
+                "tests.workflow.test_request_info_event_rehydrate:MockRequest",
+                "tests.workflow.test_request_info_event_rehydrate:SimpleApproval",
+                "tests.workflow.test_request_info_event_rehydrate:SlottedApproval",
+                "tests.workflow.test_request_info_event_rehydrate:TimedApproval",
+            ],
+        )
 
         # Create workflow with checkpointing enabled
         executor = ApprovalRequiredExecutor(id="approval_executor")
@@ -225,7 +235,17 @@ async def test_checkpoint_restore_with_responses_does_not_reemit_handled_request
 
     with tempfile.TemporaryDirectory() as temp_dir:
         # Use file-based storage to test full serialization
-        storage = FileCheckpointStorage(temp_dir)
+        storage = FileCheckpointStorage(
+            temp_dir,
+            allowed_checkpoint_types=[
+                "tests.workflow.test_request_info_and_response:UserApprovalRequest",
+                "tests.workflow.test_request_info_and_response:CalculationRequest",
+                "tests.workflow.test_request_info_event_rehydrate:MockRequest",
+                "tests.workflow.test_request_info_event_rehydrate:SimpleApproval",
+                "tests.workflow.test_request_info_event_rehydrate:SlottedApproval",
+                "tests.workflow.test_request_info_event_rehydrate:TimedApproval",
+            ],
+        )
 
         # Create workflow with checkpointing enabled
         executor = ApprovalRequiredExecutor(id="approval_executor")
@@ -288,7 +308,17 @@ async def test_checkpoint_restore_with_partial_responses_reemits_unhandled_reque
     import tempfile
 
     with tempfile.TemporaryDirectory() as temp_dir:
-        storage = FileCheckpointStorage(temp_dir)
+        storage = FileCheckpointStorage(
+            temp_dir,
+            allowed_checkpoint_types=[
+                "tests.workflow.test_request_info_and_response:UserApprovalRequest",
+                "tests.workflow.test_request_info_and_response:CalculationRequest",
+                "tests.workflow.test_request_info_event_rehydrate:MockRequest",
+                "tests.workflow.test_request_info_event_rehydrate:SimpleApproval",
+                "tests.workflow.test_request_info_event_rehydrate:SlottedApproval",
+                "tests.workflow.test_request_info_event_rehydrate:TimedApproval",
+            ],
+        )
 
         # Create workflow with multiple requests
         executor = MultiRequestExecutor(id="multi_executor")
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large