Python: Restrict persisted checkpoint deserialization by default (#4941)

* Harden Python checkpoint persistence defaults

Add RestrictedUnpickler to _checkpoint_encoding.py that limits which
types may be instantiated during pickle deserialization.  By default
FileCheckpointStorage now uses the restricted unpickler, allowing only:

- Built-in Python value types (primitives, datetime, uuid, decimal,
  collections, etc.)
- All agent_framework.* internal types
- Additional types specified via the new allowed_checkpoint_types
  parameter on FileCheckpointStorage

This narrows the default type surface area for persisted checkpoints
while keeping framework-owned scenarios working without extra
configuration.  Developers can extend the allowed set by passing
"module:qualname" strings to allowed_checkpoint_types.
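
The allowlist mechanism described above can be sketched as follows. This is a minimal illustration, not the framework's implementation: the names SAFE_KEYS, AllowlistUnpickler and restricted_loads are hypothetical, and the allow set here is deliberately much smaller than the real built-in set.

```python
import io
import pickle
from datetime import datetime, timezone
from decimal import Decimal

# Hypothetical, deliberately tiny allow set for this sketch. Keys use the
# same "module:qualname" shape that the commit message describes.
SAFE_KEYS = frozenset({
    "datetime:datetime",
    "datetime:timezone",
    "datetime:timedelta",
})


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals whose key is in the allow set."""

    def __init__(self, data, extra=frozenset()):
        super().__init__(io.BytesIO(data))
        self._allowed = SAFE_KEYS | frozenset(extra)

    def find_class(self, module, name):
        key = f"{module}:{name}"
        if key in self._allowed:
            return super().find_class(module, name)
        # Anything not explicitly allowed is rejected before it is imported.
        raise pickle.UnpicklingError(f"blocked: {key}")


def restricted_loads(data, extra=frozenset()):
    """Unpickle bytes, permitting only SAFE_KEYS plus caller-supplied extras."""
    return AllowlistUnpickler(data, extra).load()


# A timezone-aware datetime only references allowed globals, so it loads.
stamp = datetime(2025, 1, 1, tzinfo=timezone.utc)
restored = restricted_loads(pickle.dumps(stamp))

# Decimal is not in this sketch's allow set, so it must be added explicitly.
price = restricted_loads(
    pickle.dumps(Decimal("1.5")),
    extra=frozenset({"decimal:Decimal"}),
)
```

Values made only of JSON-like primitives never reach find_class, so the allow set only matters for objects that pickle stores by reference to a class or function.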

The decode_checkpoint_value function retains backward-compatible
unrestricted behavior when called without the new allowed_types kwarg.

Fixes #4894

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: resolve mypy no-any-return error in checkpoint encoding

Add explicit type annotation for super().find_class() return value
to satisfy mypy's no-any-return check.

Fixes #4894

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Simplify find_class return in _RestrictedUnpickler (#4894)

Remove unnecessary intermediate variable and apply # noqa: S301 # nosec
directly on the super().find_class() call, matching the established
pattern used on the pickle.loads() call in the same file.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #4894: Python: Harden Python checkpoint persistence defaults

* Restore # noqa: S301 on line 102 of _checkpoint_encoding.py (#4894)

The review feedback correctly identified that removing the # noqa: S301
suppression from the find_class return statement would cause a ruff S301
lint failure, since the project enables bandit ("S") rules. This
restores consistency with lines 82 and 246 in the same file.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #4894: Python: Harden Python checkpoint persistence defaults

* Address PR review comments on checkpoint encoding (#4894)

- Move module docstring to proper position after __future__ import
- Fix find_class return type annotation to type[Any]
- Add missing # noqa: S301 pragma on find_class return
- Improve error message to reference both allowed_types param and
  FileCheckpointStorage.allowed_checkpoint_types
- Add -> None return annotation to FileCheckpointStorage.__init__
- Replace tempfile.mktemp with TemporaryDirectory in test
- Replace contextlib.suppress with pytest.raises for precise assertion
- Remove unused contextlib import

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address PR #4941 review comments: fix docstring position and return type

- Move module docstring before 'from __future__' import so it populates
  __doc__ (comment #4)
- Change find_class return annotation from type[Any] to type to avoid
  misleading callers about non-type returns like copyreg._reconstructor
  (comment #2)

Comments #1, #3, #5, #6, #7, #8 were already addressed in the current code.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #4894: review comment fixes

* fix: use pickle.UnpicklingError in RestrictedUnpickler and improve docstring (#4894)

- Change _RestrictedUnpickler.find_class to raise pickle.UnpicklingError
  instead of WorkflowCheckpointException, since it is a pickle-level concern
  that gets wrapped by the caller in _base64_to_unpickle.
- Remove now-unnecessary WorkflowCheckpointException re-raise in
  _base64_to_unpickle (pickle.UnpicklingError is caught by the generic
  except Exception handler and wrapped).
- Expand decode_checkpoint_value docstring to show a concrete example of
  the module:qualname format with a user-defined class.
- Add regression test verifying find_class raises pickle.UnpicklingError.
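
The layering described in this list (a pickle-level error raised in find_class, wrapped once by the caller) can be sketched like this. CheckpointError is a hypothetical stand-in for WorkflowCheckpointException, and DenyAllUnpickler and decode_b64_pickle are illustrative names rather than the framework's own.

```python
import base64
import io
import pickle


class CheckpointError(Exception):
    """Hypothetical stand-in for the framework's checkpoint exception."""


class DenyAllUnpickler(pickle.Unpickler):
    """Rejects every global lookup with a pickle-level error."""

    def find_class(self, module, name):
        raise pickle.UnpicklingError(
            f"Checkpoint deserialization blocked for type '{module}:{name}'."
        )


def decode_b64_pickle(encoded):
    """Decode base64, unpickle, and wrap any failure in CheckpointError."""
    try:
        raw = base64.b64decode(encoded.encode("ascii"))
        return DenyAllUnpickler(io.BytesIO(raw)).load()
    except Exception as exc:
        # A single generic handler covers corrupt base64, malformed pickle
        # data, and blocked types, so no separate re-raise branch is needed.
        raise CheckpointError(f"Failed to decode pickled checkpoint data: {exc}") from exc


# Plain containers of primitives involve no global lookups and decode fine.
plain = decode_b64_pickle(base64.b64encode(pickle.dumps([1, "a"])).decode("ascii"))

# A pickled reference to a function triggers find_class and is wrapped.
try:
    decode_b64_pickle(base64.b64encode(pickle.dumps(len)).decode("ascii"))
    wrapped = None
except CheckpointError as exc:
    wrapped = exc
```

Because the wrapper uses "raise ... from exc", the original pickle.UnpicklingError stays available on the wrapped exception's __cause__ attribute for callers that want to inspect it.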

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: address PR #4941 review comments for checkpoint encoding

- Comment 1 (line 103): Already resolved in prior commit — _RestrictedUnpickler
  now raises pickle.UnpicklingError instead of WorkflowCheckpointException.

- Comment 2 (line 140): Add concrete usage examples to decode_checkpoint_value
  docstring showing both direct allowed_types usage and FileCheckpointStorage
  allowed_checkpoint_types usage. Rename 'SafeState' to 'MyState' across all
  docstrings for consistency, making it clear this is a user-defined class name.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: replace deprecated 'builtin' repo with pre-commit-hooks in pre-commit config

pre-commit 4.x no longer supports 'repo: builtin'. Merge those hooks into
the existing pre-commit-hooks repo entry.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* style: apply pyupgrade formatting to docstring example

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: resolve pre-commit hook paths for monorepo git root

The poe-check and bandit hooks referenced paths relative to python/
but pre-commit runs hooks from the git root (monorepo root). Fix
poe-check entry to cd into python/ first, and update bandit config
path to python/pyproject.toml.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix pre-commit config paths for prek --cd python execution

Revert bandit config path from 'python/pyproject.toml' to 'pyproject.toml'
and poe-check entry from explicit 'cd python' wrapper to direct invocation,
since prek --cd python already sets the working directory to python/.

Also apply ruff formatting fixes to cosmos checkpoint storage files.

Fixes #4894

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: add builtins:getattr to checkpoint deserialization allowlist

Pickle uses builtins:getattr to reconstruct enum members (e.g.,
WorkflowMessage.type which is a MessageType enum). Without it in the
allowlist, checkpoint roundtrip tests fail with
WorkflowCheckpointException.
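
One way to find out which allowlist entries a value needs is to record every global that pickle resolves while loading it. The sketch below is a hypothetical diagnostic helper (RecordingUnpickler and referenced_globals are not part of the framework); it loads without any restriction, so it should only be run on objects created in the same process. The exact set of recorded keys for an enum member may differ between Python versions.

```python
import http
import io
import pickle


class RecordingUnpickler(pickle.Unpickler):
    """Records each 'module:name' key that pickle asks find_class to resolve."""

    def __init__(self, data):
        super().__init__(io.BytesIO(data))
        self.seen = []

    def find_class(self, module, name):
        self.seen.append(f"{module}:{name}")
        # No restriction here: this helper is for trusted, locally made data.
        return super().find_class(module, name)


def referenced_globals(value):
    """Round-trip a trusted value and return the global keys pickle looked up."""
    unpickler = RecordingUnpickler(pickle.dumps(value))
    unpickler.load()
    return unpickler.seen


# A standard-library enum member always references its enum class; some
# Python versions additionally reference builtins:getattr.
enum_keys = referenced_globals(http.HTTPStatus.OK)

# Primitive containers reference no globals at all.
primitive_keys = referenced_globals([1, "a", 2.5])
```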

Fixes #4894

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #4894: review comment fixes

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Evan Mattson
2026-04-10 09:04:17 +09:00
committed by GitHub
parent 5e8fe0be1f
commit 4dbe696e0e
8 changed files with 28479 additions and 43 deletions
+2 -4
@@ -1,7 +1,8 @@
fail_fast: true
exclude: ^scripts/
repos:
- repo: builtin
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-toml
name: Check TOML files
@@ -34,9 +35,6 @@ repos:
- id: no-commit-to-branch
name: Protect main branch
args: [--branch, main]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-ast
name: Check Valid Python Samples
types: ["python"]
@@ -244,14 +244,39 @@ class FileCheckpointStorage:
is serialized using pickle and embedded as base64-encoded strings within the JSON. This allows
for human-readable checkpoint files while preserving the ability to store complex Python objects.
SECURITY WARNING: Checkpoints use pickle for data serialization. Only load checkpoints
from trusted sources. Loading a malicious checkpoint file can execute arbitrary code.
By default, checkpoint deserialization is restricted to a built-in set of safe
Python types (primitives, datetime, uuid, ...) and all ``agent_framework``
internal types. To allow additional application-specific types, pass them via
the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
Example::
storage = FileCheckpointStorage(
"/tmp/checkpoints",
allowed_checkpoint_types=[
"my_app.models:MyState",
],
)
"""
def __init__(self, storage_path: str | Path):
"""Initialize the file storage."""
def __init__(
self,
storage_path: str | Path,
*,
allowed_checkpoint_types: list[str] | None = None,
) -> None:
"""Initialize the file storage.
Args:
storage_path: Directory path where checkpoint files will be stored.
allowed_checkpoint_types: Additional types (beyond the built-in safe set
and framework types) that are permitted during checkpoint
deserialization. Each entry should be a ``"module:qualname"``
string (e.g., ``"my_app.models:MyState"``).
"""
self.storage_path = Path(storage_path)
self.storage_path.mkdir(parents=True, exist_ok=True)
self._allowed_types: frozenset[str] = frozenset(allowed_checkpoint_types or [])
logger.info(f"Initialized file checkpoint storage at {self.storage_path}")
def _validate_file_path(self, checkpoint_id: CheckpointID) -> Path:
@@ -327,7 +352,7 @@ class FileCheckpointStorage:
from ._checkpoint_encoding import decode_checkpoint_value
try:
decoded_checkpoint_dict = decode_checkpoint_value(encoded_checkpoint)
decoded_checkpoint_dict = decode_checkpoint_value(encoded_checkpoint, allowed_types=self._allowed_types)
except WorkflowCheckpointException:
raise
checkpoint = WorkflowCheckpoint.from_dict(decoded_checkpoint_dict)
@@ -352,7 +377,9 @@ class FileCheckpointStorage:
encoded_checkpoint = json.load(f)
from ._checkpoint_encoding import decode_checkpoint_value
decoded_checkpoint_dict = decode_checkpoint_value(encoded_checkpoint)
decoded_checkpoint_dict = decode_checkpoint_value(
encoded_checkpoint, allowed_types=self._allowed_types
)
checkpoint = WorkflowCheckpoint.from_dict(decoded_checkpoint_dict)
if checkpoint.workflow_name == workflow_name:
checkpoints.append(checkpoint)
@@ -1,14 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import base64
import logging
import pickle # nosec # noqa: S403
from typing import Any
from ..exceptions import WorkflowCheckpointException
"""Checkpoint encoding using JSON structure with pickle+base64 for arbitrary data.
This hybrid approach provides:
@@ -16,10 +7,23 @@ This hybrid approach provides:
- Full Python object fidelity via pickle for data values (non-JSON-native types)
- Base64 encoding to embed binary pickle data in JSON strings
SECURITY WARNING: Checkpoints use pickle for data serialization. Only load checkpoints
from trusted sources. Loading a malicious checkpoint file can execute arbitrary code.
When ``allowed_types`` is supplied to :func:`decode_checkpoint_value`, a
``RestrictedUnpickler`` is used that limits which classes may be instantiated
during deserialization. The default built-in safe set covers common Python
value types (primitives, datetime, uuid, ...) and all ``agent_framework``
internal types. Callers can extend the set by passing additional
``"module:qualname"`` strings.
"""
from __future__ import annotations
import base64
import io
import logging
import pickle # nosec # noqa: S403
from typing import Any
from ..exceptions import WorkflowCheckpointException
logger = logging.getLogger("agent_framework")
@@ -30,6 +34,82 @@ _TYPE_MARKER = "__type__"
# Types that are natively JSON-serializable and don't need pickling
_JSON_NATIVE_TYPES = (str, int, float, bool, type(None))
# Module prefix for framework-internal types that are always allowed
_FRAMEWORK_MODULE_PREFIX = "agent_framework."
# Built-in types considered safe for checkpoint deserialization.
# Each entry is a ``module:qualname`` string matching the format produced by
# :func:`_type_to_key`. These are the classes for which pickle's
# ``find_class`` will be called when unpickling common Python value types.
_BUILTIN_ALLOWED_TYPE_KEYS: frozenset[str] = frozenset({
# builtins
"builtins:object",
"builtins:complex",
"builtins:range",
"builtins:slice",
"builtins:int",
"builtins:float",
"builtins:str",
"builtins:bytes",
"builtins:bytearray",
"builtins:bool",
"builtins:set",
"builtins:frozenset",
"builtins:list",
"builtins:dict",
"builtins:tuple",
"builtins:type",
# getattr is used by pickle to reconstruct enum members
"builtins:getattr",
# copyreg helpers used by pickle for object reconstruction
"copyreg:_reconstructor",
# datetime
"datetime:datetime",
"datetime:date",
"datetime:time",
"datetime:timedelta",
"datetime:timezone",
# uuid
"uuid:UUID",
# decimal
"decimal:Decimal",
# collections
"collections:OrderedDict",
"collections:defaultdict",
"collections:deque",
})
class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301
"""Unpickler that restricts which classes may be instantiated.
Only classes whose ``module:qualname`` key appears in the combined allow
set (built-in safe types + framework types + caller-specified extras) are
permitted. All other classes raise :class:`pickle.UnpicklingError`.
"""
def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None:
super().__init__(io.BytesIO(data))
self._allowed_types = allowed_types
def find_class(self, module: str, name: str) -> type:
type_key = f"{module}:{name}"
if (
type_key in _BUILTIN_ALLOWED_TYPE_KEYS
or type_key in self._allowed_types
or module.startswith(_FRAMEWORK_MODULE_PREFIX)
):
return super().find_class(module, name) # type: ignore[no-any-return] # nosec
raise pickle.UnpicklingError(
f"Checkpoint deserialization blocked for type '{type_key}'. "
f"To allow this type, either include its 'module:qualname' key in the "
f"'allowed_types' set passed to 'decode_checkpoint_value', or add it to "
f"'allowed_checkpoint_types' on your checkpoint storage "
f"(for example, 'FileCheckpointStorage.allowed_checkpoint_types')."
)
def encode_checkpoint_value(value: Any) -> Any:
"""Encode a Python value for checkpoint storage.
@@ -48,29 +128,51 @@ def encode_checkpoint_value(value: Any) -> Any:
return _encode(value)
def decode_checkpoint_value(value: Any) -> Any:
def decode_checkpoint_value(value: Any, *, allowed_types: frozenset[str] | None = None) -> Any:
"""Decode a value from checkpoint storage.
Reverses the encoding performed by encode_checkpoint_value.
Pickled values (identified by _PICKLE_MARKER) are decoded and unpickled.
WARNING: Only call this with trusted data. Pickle can execute
arbitrary code during deserialization. The post-unpickle type verification
detects accidental corruption or type mismatches, but cannot prevent
arbitrary code execution from malicious pickle payloads.
Args:
value: A JSON-deserialized value from checkpoint storage.
allowed_types: If not ``None``, restrict pickle deserialization to the
built-in safe set, framework types, and the types listed here.
Each entry should use ``"module:qualname"`` format — that is, the
dotted module path followed by a colon and the class
``__qualname__``. For example, given a user-defined class::
# my_app/models.py
class MyState: ...
the corresponding entry would be ``"my_app.models:MyState"``::
decode_checkpoint_value(
data,
allowed_types=frozenset({"my_app.models:MyState"}),
)
When using :class:`FileCheckpointStorage`, pass the same strings
via ``allowed_checkpoint_types``::
storage = FileCheckpointStorage(
"/tmp/checkpoints",
allowed_checkpoint_types=["my_app.models:MyState"],
)
If ``None``, no restriction is applied (backward-compatible
behavior).
Returns:
The original Python value.
Raises:
WorkflowCheckpointException: If the unpickled object's type doesn't match
the recorded type, indicating corruption, or if the base64/pickle
data is malformed.
the recorded type, indicating corruption, if the base64/pickle
data is malformed, or if a disallowed type is encountered during
restricted deserialization.
"""
return _decode(value)
return _decode(value, allowed_types=allowed_types)
def _encode(value: Any) -> Any:
@@ -94,7 +196,7 @@ def _encode(value: Any) -> Any:
}
def _decode(value: Any) -> Any:
def _decode(value: Any, *, allowed_types: frozenset[str] | None = None) -> Any:
"""Recursively decode a value from JSON storage."""
# JSON-native types pass through
if isinstance(value, _JSON_NATIVE_TYPES):
@@ -104,16 +206,16 @@ def _decode(value: Any) -> Any:
if isinstance(value, dict):
# Pickled value: decode, unpickle, and verify type
if _PICKLE_MARKER in value and _TYPE_MARKER in value:
obj = _base64_to_unpickle(value[_PICKLE_MARKER]) # type: ignore
obj = _base64_to_unpickle(value[_PICKLE_MARKER], allowed_types=allowed_types) # type: ignore
_verify_type(obj, value.get(_TYPE_MARKER)) # type: ignore
return obj
# Regular dict: decode values recursively
return {k: _decode(v) for k, v in value.items()} # type: ignore
return {k: _decode(v, allowed_types=allowed_types) for k, v in value.items()} # type: ignore
# Handle encoded lists
if isinstance(value, list):
return [_decode(item) for item in value] # type: ignore
return [_decode(item, allowed_types=allowed_types) for item in value] # type: ignore
return value
@@ -148,15 +250,23 @@ def _pickle_to_base64(value: Any) -> str:
return base64.b64encode(pickled).decode("ascii")
def _base64_to_unpickle(encoded: str) -> Any:
def _base64_to_unpickle(encoded: str, *, allowed_types: frozenset[str] | None = None) -> Any:
"""Decode base64 string and unpickle.
Args:
encoded: Base64-encoded pickle data.
allowed_types: If not ``None``, use restricted unpickling that only
permits built-in safe types, framework types, and the specified
extra types.
Raises:
WorkflowCheckpointException: If the base64 data is corrupted or the pickle
format is incompatible.
WorkflowCheckpointException: If the base64 data is corrupted, the pickle
format is incompatible, or a disallowed type is encountered.
"""
try:
pickled = base64.b64decode(encoded.encode("ascii"))
if allowed_types is not None:
return _RestrictedUnpickler(pickled, allowed_types).load()
return pickle.loads(pickled) # nosec # noqa: S301
except Exception as exc:
raise WorkflowCheckpointException(f"Failed to decode pickled checkpoint data: {exc}") from exc
@@ -1048,7 +1048,10 @@ async def test_file_checkpoint_storage_roundtrip_datetime():
async def test_file_checkpoint_storage_roundtrip_dataclass():
"""Test that dataclass objects roundtrip correctly via pickle encoding."""
with tempfile.TemporaryDirectory() as temp_dir:
storage = FileCheckpointStorage(temp_dir)
storage = FileCheckpointStorage(
temp_dir,
allowed_checkpoint_types=["tests.workflow.test_checkpoint:_TestCustomData"],
)
custom_obj = _TestCustomData(name="test", value=42, tags=["a", "b", "c"])
@@ -1238,7 +1241,10 @@ async def test_file_checkpoint_storage_roundtrip_messages_with_complex_data():
async def test_file_checkpoint_storage_roundtrip_pending_request_info_events():
"""Test that pending_request_info_events with WorkflowEvent objects roundtrip correctly."""
with tempfile.TemporaryDirectory() as temp_dir:
storage = FileCheckpointStorage(temp_dir)
storage = FileCheckpointStorage(
temp_dir,
allowed_checkpoint_types=["tests.workflow.test_checkpoint:_TestToolApprovalRequest"],
)
# Create request_info events using the proper WorkflowEvent factory
event1 = WorkflowEvent.request_info(
@@ -1300,7 +1306,13 @@ async def test_file_checkpoint_storage_roundtrip_pending_request_info_events():
async def test_file_checkpoint_storage_roundtrip_full_checkpoint():
"""Test complete WorkflowCheckpoint roundtrip with all fields populated using proper types."""
with tempfile.TemporaryDirectory() as temp_dir:
storage = FileCheckpointStorage(temp_dir)
storage = FileCheckpointStorage(
temp_dir,
allowed_checkpoint_types=[
"tests.workflow.test_checkpoint:_TestApprovalRequest",
"tests.workflow.test_checkpoint:_TestExecutorState",
],
)
# Create proper WorkflowMessage objects
msg1 = WorkflowMessage(data="msg1", source_id="s", target_id="t")
@@ -0,0 +1,218 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for restricted checkpoint deserialization.
These tests verify that persisted checkpoint loading uses a restricted
unpickler by default:
- Arbitrary callables are blocked during deserialization
- __reduce__ payloads cannot execute code during deserialization
- FileCheckpointStorage accepts allowed_checkpoint_types for extension
- User-defined types are blocked unless explicitly allowed
- Built-in safe types and framework types are always allowed
"""
import base64
import os
import pickle
import tempfile
from dataclasses import dataclass
from datetime import datetime, timezone
import pytest
from agent_framework import WorkflowCheckpointException
from agent_framework._workflows._checkpoint import FileCheckpointStorage
from agent_framework._workflows._checkpoint_encoding import (
_PICKLE_MARKER,
_TYPE_MARKER,
decode_checkpoint_value,
encode_checkpoint_value,
)
class MaliciousPayload:
"""A class whose __reduce__ executes code during unpickling."""
def __reduce__(self):
return (os.getpid, ())
def test_restricted_decode_blocks_arbitrary_callable():
"""Restricted decoding blocks arbitrary module-level callables."""
pickled = pickle.dumps(os.getpid, protocol=pickle.HIGHEST_PROTOCOL)
encoded_b64 = base64.b64encode(pickled).decode("ascii")
checkpoint_value = {
_PICKLE_MARKER: encoded_b64,
_TYPE_MARKER: "builtins:builtin_function_or_method",
}
with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
decode_checkpoint_value(checkpoint_value, allowed_types=frozenset())
def test_restricted_decode_blocks_reduce_payload():
"""__reduce__-based payloads are blocked before code can execute."""
payload = MaliciousPayload()
pickled = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL)
encoded_b64 = base64.b64encode(pickled).decode("ascii")
checkpoint_value = {
_PICKLE_MARKER: encoded_b64,
_TYPE_MARKER: f"{MaliciousPayload.__module__}:{MaliciousPayload.__qualname__}",
}
with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
decode_checkpoint_value(checkpoint_value, allowed_types=frozenset())
def test_restricted_decode_prevents_code_execution():
"""Restricted deserialization prevents __reduce__ code from running."""
with tempfile.TemporaryDirectory() as tmpdir:
marker_file = os.path.join(tmpdir, "checkpoint_test_marker")
payload_bytes = pickle.dumps(
type(
"Exploit",
(),
{
"__reduce__": lambda self: (
eval,
(f"open({marker_file!r}, 'w').write('pwned')",),
)
},
)(),
protocol=pickle.HIGHEST_PROTOCOL,
)
encoded_b64 = base64.b64encode(payload_bytes).decode("ascii")
checkpoint_value = {
_PICKLE_MARKER: encoded_b64,
_TYPE_MARKER: "builtins:int",
}
with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
decode_checkpoint_value(checkpoint_value, allowed_types=frozenset())
assert not os.path.exists(marker_file), (
"Restricted unpickler should have prevented code execution, but the marker file was created."
)
def test_file_checkpoint_storage_accepts_allowed_types():
"""FileCheckpointStorage.__init__ accepts allowed_checkpoint_types."""
with tempfile.TemporaryDirectory() as tmpdir:
storage = FileCheckpointStorage(
tmpdir,
allowed_checkpoint_types=["some.module:SomeType"],
)
assert storage is not None
@dataclass
class _AllowedTestState:
"""Test dataclass that will be explicitly allowed."""
name: str
value: int
def test_restricted_decode_blocks_unlisted_user_type():
"""User-defined types are blocked when not in allowed_checkpoint_types."""
original = _AllowedTestState(name="test", value=42)
encoded = encode_checkpoint_value(original)
with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
decode_checkpoint_value(encoded, allowed_types=frozenset())
def test_restricted_decode_allows_listed_user_type():
"""User-defined types are allowed when listed in allowed_types."""
original = _AllowedTestState(name="test", value=42)
encoded = encode_checkpoint_value(original)
type_key = f"{_AllowedTestState.__module__}:{_AllowedTestState.__qualname__}"
decoded = decode_checkpoint_value(encoded, allowed_types=frozenset({type_key}))
assert isinstance(decoded, _AllowedTestState)
assert decoded.name == "test"
assert decoded.value == 42
def test_restricted_decode_allows_builtin_safe_types():
"""Built-in safe types (datetime, set, etc.) are always allowed."""
test_values = [
datetime(2025, 1, 1, tzinfo=timezone.utc),
{1, 2, 3},
frozenset({4, 5, 6}),
(1, "two", 3.0),
complex(1, 2),
]
for original in test_values:
encoded = encode_checkpoint_value(original)
decoded = decode_checkpoint_value(encoded, allowed_types=frozenset())
assert decoded == original
def test_unrestricted_decode_allows_arbitrary_types():
"""Without allowed_types, decode_checkpoint_value remains unrestricted."""
original = _AllowedTestState(name="test", value=42)
encoded = encode_checkpoint_value(original)
decoded = decode_checkpoint_value(encoded)
assert isinstance(decoded, _AllowedTestState)
assert decoded.name == "test"
async def test_file_storage_blocks_unlisted_user_type():
"""FileCheckpointStorage blocks user types not in allowed_checkpoint_types."""
from agent_framework import WorkflowCheckpoint
with tempfile.TemporaryDirectory() as tmpdir:
# Save with a storage that allows the type
type_key = f"{_AllowedTestState.__module__}:{_AllowedTestState.__qualname__}"
save_storage = FileCheckpointStorage(tmpdir, allowed_checkpoint_types=[type_key])
checkpoint = WorkflowCheckpoint(
workflow_name="test",
graph_signature_hash="hash",
state={"data": _AllowedTestState(name="test", value=1)},
)
await save_storage.save(checkpoint)
# Load with a storage that does NOT allow the type
load_storage = FileCheckpointStorage(tmpdir)
with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
await load_storage.load(checkpoint.checkpoint_id)
async def test_file_storage_allows_listed_user_type():
"""FileCheckpointStorage allows user types listed in allowed_checkpoint_types."""
from agent_framework import WorkflowCheckpoint
with tempfile.TemporaryDirectory() as tmpdir:
type_key = f"{_AllowedTestState.__module__}:{_AllowedTestState.__qualname__}"
storage = FileCheckpointStorage(tmpdir, allowed_checkpoint_types=[type_key])
checkpoint = WorkflowCheckpoint(
workflow_name="test",
graph_signature_hash="hash",
state={"data": _AllowedTestState(name="allowed", value=99)},
)
await storage.save(checkpoint)
loaded = await storage.load(checkpoint.checkpoint_id)
assert isinstance(loaded.state["data"], _AllowedTestState)
assert loaded.state["data"].name == "allowed"
assert loaded.state["data"].value == 99
def test_restricted_unpickler_raises_pickle_error():
"""_RestrictedUnpickler.find_class raises pickle.UnpicklingError, not a framework exception."""
from agent_framework._workflows._checkpoint_encoding import _RestrictedUnpickler
pickled = pickle.dumps(os.getpid, protocol=pickle.HIGHEST_PROTOCOL)
unpickler = _RestrictedUnpickler(pickled, frozenset())
with pytest.raises(pickle.UnpicklingError, match="deserialization blocked"):
unpickler.load()
@@ -130,7 +130,17 @@ async def test_checkpoint_with_pending_request_info_events():
with tempfile.TemporaryDirectory() as temp_dir:
# Use file-based storage to test full serialization
storage = FileCheckpointStorage(temp_dir)
storage = FileCheckpointStorage(
temp_dir,
allowed_checkpoint_types=[
"tests.workflow.test_request_info_and_response:UserApprovalRequest",
"tests.workflow.test_request_info_and_response:CalculationRequest",
"tests.workflow.test_request_info_event_rehydrate:MockRequest",
"tests.workflow.test_request_info_event_rehydrate:SimpleApproval",
"tests.workflow.test_request_info_event_rehydrate:SlottedApproval",
"tests.workflow.test_request_info_event_rehydrate:TimedApproval",
],
)
# Create workflow with checkpointing enabled
executor = ApprovalRequiredExecutor(id="approval_executor")
@@ -225,7 +235,17 @@ async def test_checkpoint_restore_with_responses_does_not_reemit_handled_request
with tempfile.TemporaryDirectory() as temp_dir:
# Use file-based storage to test full serialization
storage = FileCheckpointStorage(temp_dir)
storage = FileCheckpointStorage(
temp_dir,
allowed_checkpoint_types=[
"tests.workflow.test_request_info_and_response:UserApprovalRequest",
"tests.workflow.test_request_info_and_response:CalculationRequest",
"tests.workflow.test_request_info_event_rehydrate:MockRequest",
"tests.workflow.test_request_info_event_rehydrate:SimpleApproval",
"tests.workflow.test_request_info_event_rehydrate:SlottedApproval",
"tests.workflow.test_request_info_event_rehydrate:TimedApproval",
],
)
# Create workflow with checkpointing enabled
executor = ApprovalRequiredExecutor(id="approval_executor")
@@ -288,7 +308,17 @@ async def test_checkpoint_restore_with_partial_responses_reemits_unhandled_reque
import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
storage = FileCheckpointStorage(temp_dir)
storage = FileCheckpointStorage(
temp_dir,
allowed_checkpoint_types=[
"tests.workflow.test_request_info_and_response:UserApprovalRequest",
"tests.workflow.test_request_info_and_response:CalculationRequest",
"tests.workflow.test_request_info_event_rehydrate:MockRequest",
"tests.workflow.test_request_info_event_rehydrate:SimpleApproval",
"tests.workflow.test_request_info_event_rehydrate:SlottedApproval",
"tests.workflow.test_request_info_event_rehydrate:TimedApproval",
],
)
# Create workflow with multiple requests
executor = MultiRequestExecutor(id="multi_executor")