Python: Fix: Verify types during checkpoint deserialization to prevent marker spoofing (#3243)

* Initial plan

* Add validation for reserved keywords in checkpoint encoding/decoding

Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com>

* Refactor to eliminate duplicate code in model protocol detection

Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com>

* Fix pyright type narrowing issue for dataclass check

Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com>

* Add comprehensive unit tests for checkpoint encoding

Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com>

* Remove serialization-time reserved keyword validation to fix failing tests

The serialization-time validation was too aggressive and blocked legitimate use cases
where encoded data was being re-encoded. Security is now enforced only at deserialization
time by validating that classes marked with DATACLASS_MARKER are actual dataclasses and
classes marked with MODEL_MARKER actually support the model protocol.

Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com>

* Apply ruff formatting to checkpoint encoding file

Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com>

* Changes before error encountered

Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com>

* Revert "Changes before error encountered"

This reverts commit f515b880dc.

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com>
Co-authored-by: Tao Chen <taochen@microsoft.com>
This commit is contained in:
Copilot
2026-01-22 20:07:39 +01:00
committed by GitHub
Unverified
parent 958e6d27ce
commit 87c9d74bd7
3 changed files with 558 additions and 7 deletions
@@ -146,6 +146,10 @@ def decode_checkpoint_value(value: Any) -> Any:
cls = None
if cls is not None:
# Verify the class actually supports the model protocol
if not _class_supports_model_protocol(cls):
logger.debug(f"Class {type_key} does not support model protocol; returning raw value")
return decoded_payload
if strategy == "to_dict" and hasattr(cls, "from_dict"):
with contextlib.suppress(Exception):
return cls.from_dict(decoded_payload)
@@ -169,6 +173,10 @@ def decode_checkpoint_value(value: Any) -> Any:
if module is None:
module = importlib.import_module(module_name)
cls_dc: Any = getattr(module, class_name)
# Verify the class is actually a dataclass type (not an instance)
if not isinstance(cls_dc, type) or not is_dataclass(cls_dc):
logger.debug(f"Class {type_key_dc} is not a dataclass type; returning raw value")
return decoded_raw
constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw)
if constructed is not None:
return constructed
@@ -188,6 +196,22 @@ def decode_checkpoint_value(value: Any) -> Any:
return value
def _class_supports_model_protocol(cls: type[Any]) -> bool:
"""Check if a class type supports the model serialization protocol.
Checks for pairs of serialization/deserialization methods:
- to_dict/from_dict
- to_json/from_json
"""
has_to_dict = hasattr(cls, "to_dict") and callable(getattr(cls, "to_dict", None))
has_from_dict = hasattr(cls, "from_dict") and callable(getattr(cls, "from_dict", None))
has_to_json = hasattr(cls, "to_json") and callable(getattr(cls, "to_json", None))
has_from_json = hasattr(cls, "from_json") and callable(getattr(cls, "from_json", None))
return (has_to_dict and has_from_dict) or (has_to_json and has_from_json)
def _supports_model_protocol(obj: object) -> bool:
"""Detect objects that expose dictionary serialization hooks."""
try:
@@ -195,13 +219,7 @@ def _supports_model_protocol(obj: object) -> bool:
except Exception:
return False
has_to_dict = hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict", None)) # type: ignore[arg-type]
has_from_dict = hasattr(obj_type, "from_dict") and callable(getattr(obj_type, "from_dict", None))
has_to_json = hasattr(obj, "to_json") and callable(getattr(obj, "to_json", None)) # type: ignore[arg-type]
has_from_json = hasattr(obj_type, "from_json") and callable(getattr(obj_type, "from_json", None))
return (has_to_dict and has_from_dict) or (has_to_json and has_from_json)
return _class_supports_model_protocol(obj_type)
def _import_qualified_name(qualname: str) -> type[Any] | None:
@@ -3,7 +3,10 @@
from dataclasses import dataclass # noqa: I001
from typing import Any, cast
from agent_framework._workflows._checkpoint_encoding import (
DATACLASS_MARKER,
MODEL_MARKER,
decode_checkpoint_value,
encode_checkpoint_value,
)
@@ -126,3 +129,110 @@ def test_encode_decode_nested_structures() -> None:
assert response.data == "first response"
assert isinstance(response.original_request, SampleRequest)
assert response.original_request.request_id == "req-1"
def test_encode_allows_marker_key_without_value_key() -> None:
"""Test that encoding a dict with only the marker key (no 'value') is allowed."""
dict_with_marker_only = {
MODEL_MARKER: "some.module:FakeClass",
"other_key": "test",
}
encoded = encode_checkpoint_value(dict_with_marker_only)
assert MODEL_MARKER in encoded
assert "other_key" in encoded
def test_encode_allows_value_key_without_marker_key() -> None:
"""Test that encoding a dict with only 'value' key (no marker) is allowed."""
dict_with_value_only = {
"value": {"data": "test"},
"other_key": "test",
}
encoded = encode_checkpoint_value(dict_with_value_only)
assert "value" in encoded
assert "other_key" in encoded
def test_encode_allows_marker_with_value_key() -> None:
"""Test that encoding a dict with marker and 'value' keys is allowed.
This is allowed because legitimate encoded data may contain these keys,
and security is enforced at deserialization time by validating class types.
"""
dict_with_both = {
MODEL_MARKER: "some.module:SomeClass",
"value": {"data": "test"},
"strategy": "to_dict",
}
encoded = encode_checkpoint_value(dict_with_both)
assert MODEL_MARKER in encoded
assert "value" in encoded
class NotADataclass:
"""A regular class that is not a dataclass."""
def __init__(self, value: str) -> None:
self.value = value
def get_value(self) -> str:
return self.value
class NotAModel:
"""A regular class that does not support the model protocol."""
def __init__(self, value: str) -> None:
self.value = value
def get_value(self) -> str:
return self.value
def test_decode_rejects_non_dataclass_with_dataclass_marker() -> None:
"""Test that decode returns raw value when marked class is not a dataclass."""
# Manually construct a payload that claims NotADataclass is a dataclass
fake_payload = {
DATACLASS_MARKER: f"{NotADataclass.__module__}:{NotADataclass.__name__}",
"value": {"value": "test_value"},
}
decoded = decode_checkpoint_value(fake_payload)
# Should return the raw decoded value, not an instance of NotADataclass
assert isinstance(decoded, dict)
assert decoded["value"] == "test_value"
def test_decode_rejects_non_model_with_model_marker() -> None:
"""Test that decode returns raw value when marked class doesn't support model protocol."""
# Manually construct a payload that claims NotAModel supports the model protocol
fake_payload = {
MODEL_MARKER: f"{NotAModel.__module__}:{NotAModel.__name__}",
"strategy": "to_dict",
"value": {"value": "test_value"},
}
decoded = decode_checkpoint_value(fake_payload)
# Should return the raw decoded value, not an instance of NotAModel
assert isinstance(decoded, dict)
assert decoded["value"] == "test_value"
def test_encode_allows_nested_dict_with_marker_keys() -> None:
"""Test that encoding allows nested dicts containing marker patterns.
Security is enforced at deserialization time, not serialization time,
so legitimate encoded data can contain markers at any nesting level.
"""
nested_data = {
"outer": {
MODEL_MARKER: "some.module:SomeClass",
"value": {"data": "test"},
}
}
encoded = encode_checkpoint_value(nested_data)
assert "outer" in encoded
assert MODEL_MARKER in encoded["outer"]
@@ -0,0 +1,423 @@
# Copyright (c) Microsoft. All rights reserved.
from dataclasses import dataclass
from typing import Any
from agent_framework._workflows._checkpoint_encoding import (
_CYCLE_SENTINEL,
DATACLASS_MARKER,
MODEL_MARKER,
encode_checkpoint_value,
)
@dataclass
class SimpleDataclass:
"""A simple dataclass for testing encoding."""
name: str
value: int
@dataclass
class NestedDataclass:
"""A dataclass with nested dataclass field."""
outer_name: str
inner: SimpleDataclass
class ModelWithToDict:
"""A class that implements to_dict/from_dict protocol."""
def __init__(self, data: str) -> None:
self.data = data
def to_dict(self) -> dict[str, Any]:
return {"data": self.data}
@classmethod
def from_dict(cls, d: dict[str, Any]) -> "ModelWithToDict":
return cls(data=d["data"])
class ModelWithToJson:
"""A class that implements to_json/from_json protocol."""
def __init__(self, data: str) -> None:
self.data = data
def to_json(self) -> str:
return f'{{"data": "{self.data}"}}'
@classmethod
def from_json(cls, json_str: str) -> "ModelWithToJson":
import json
d = json.loads(json_str)
return cls(data=d["data"])
class UnknownObject:
"""A class that doesn't support any serialization protocol."""
def __init__(self, value: str) -> None:
self.value = value
def __str__(self) -> str:
return f"UnknownObject({self.value})"
# --- Tests for primitive encoding ---
def test_encode_string() -> None:
"""Test encoding a string value."""
result = encode_checkpoint_value("hello")
assert result == "hello"
def test_encode_integer() -> None:
"""Test encoding an integer value."""
result = encode_checkpoint_value(42)
assert result == 42
def test_encode_float() -> None:
"""Test encoding a float value."""
result = encode_checkpoint_value(3.14)
assert result == 3.14
def test_encode_boolean_true() -> None:
"""Test encoding a True boolean value."""
result = encode_checkpoint_value(True)
assert result is True
def test_encode_boolean_false() -> None:
"""Test encoding a False boolean value."""
result = encode_checkpoint_value(False)
assert result is False
def test_encode_none() -> None:
"""Test encoding a None value."""
result = encode_checkpoint_value(None)
assert result is None
# --- Tests for collection encoding ---
def test_encode_empty_dict() -> None:
"""Test encoding an empty dictionary."""
result = encode_checkpoint_value({})
assert result == {}
def test_encode_simple_dict() -> None:
"""Test encoding a simple dictionary with primitive values."""
data = {"name": "test", "count": 5, "active": True}
result = encode_checkpoint_value(data)
assert result == {"name": "test", "count": 5, "active": True}
def test_encode_dict_with_non_string_keys() -> None:
"""Test encoding a dictionary with non-string keys (converted to strings)."""
data = {1: "one", 2: "two"}
result = encode_checkpoint_value(data)
assert result == {"1": "one", "2": "two"}
def test_encode_empty_list() -> None:
"""Test encoding an empty list."""
result = encode_checkpoint_value([])
assert result == []
def test_encode_simple_list() -> None:
"""Test encoding a simple list with primitive values."""
data = [1, 2, 3, "four"]
result = encode_checkpoint_value(data)
assert result == [1, 2, 3, "four"]
def test_encode_tuple() -> None:
"""Test encoding a tuple (converted to list)."""
data = (1, 2, 3)
result = encode_checkpoint_value(data)
assert result == [1, 2, 3]
def test_encode_set() -> None:
"""Test encoding a set (converted to list)."""
data = {1, 2, 3}
result = encode_checkpoint_value(data)
assert isinstance(result, list)
assert sorted(result) == [1, 2, 3]
def test_encode_nested_dict() -> None:
"""Test encoding a nested dictionary structure."""
data = {
"outer": {
"inner": {
"value": 42,
}
}
}
result = encode_checkpoint_value(data)
assert result == {"outer": {"inner": {"value": 42}}}
def test_encode_list_of_dicts() -> None:
"""Test encoding a list containing dictionaries."""
data = [{"a": 1}, {"b": 2}]
result = encode_checkpoint_value(data)
assert result == [{"a": 1}, {"b": 2}]
# --- Tests for dataclass encoding ---
def test_encode_simple_dataclass() -> None:
"""Test encoding a simple dataclass."""
obj = SimpleDataclass(name="test", value=42)
result = encode_checkpoint_value(obj)
assert isinstance(result, dict)
assert DATACLASS_MARKER in result
assert "value" in result
assert result["value"] == {"name": "test", "value": 42}
def test_encode_nested_dataclass() -> None:
"""Test encoding a dataclass with nested dataclass fields."""
inner = SimpleDataclass(name="inner", value=10)
outer = NestedDataclass(outer_name="outer", inner=inner)
result = encode_checkpoint_value(outer)
assert isinstance(result, dict)
assert DATACLASS_MARKER in result
assert "value" in result
outer_value = result["value"]
assert outer_value["outer_name"] == "outer"
assert DATACLASS_MARKER in outer_value["inner"]
def test_encode_list_of_dataclasses() -> None:
"""Test encoding a list containing dataclass instances."""
data = [
SimpleDataclass(name="first", value=1),
SimpleDataclass(name="second", value=2),
]
result = encode_checkpoint_value(data)
assert isinstance(result, list)
assert len(result) == 2
for item in result:
assert DATACLASS_MARKER in item
def test_encode_dict_with_dataclass_values() -> None:
"""Test encoding a dictionary with dataclass values."""
data = {
"item1": SimpleDataclass(name="first", value=1),
"item2": SimpleDataclass(name="second", value=2),
}
result = encode_checkpoint_value(data)
assert isinstance(result, dict)
assert DATACLASS_MARKER in result["item1"]
assert DATACLASS_MARKER in result["item2"]
# --- Tests for model protocol encoding ---
def test_encode_model_with_to_dict() -> None:
"""Test encoding an object implementing to_dict/from_dict protocol."""
obj = ModelWithToDict(data="test_data")
result = encode_checkpoint_value(obj)
assert isinstance(result, dict)
assert MODEL_MARKER in result
assert result["strategy"] == "to_dict"
assert result["value"] == {"data": "test_data"}
def test_encode_model_with_to_json() -> None:
"""Test encoding an object implementing to_json/from_json protocol."""
obj = ModelWithToJson(data="test_data")
result = encode_checkpoint_value(obj)
assert isinstance(result, dict)
assert MODEL_MARKER in result
assert result["strategy"] == "to_json"
assert '"data": "test_data"' in result["value"]
# --- Tests for unknown object encoding ---
def test_encode_unknown_object_fallback_to_string() -> None:
"""Test that unknown objects are encoded as strings."""
obj = UnknownObject(value="test")
result = encode_checkpoint_value(obj)
assert isinstance(result, str)
assert "UnknownObject" in result
# --- Tests for cycle detection ---
def test_encode_dict_with_self_reference() -> None:
"""Test that dict self-references are detected and handled."""
data: dict[str, Any] = {"name": "test"}
data["self"] = data # Create circular reference
result = encode_checkpoint_value(data)
assert result["name"] == "test"
assert result["self"] == _CYCLE_SENTINEL
def test_encode_list_with_self_reference() -> None:
"""Test that list self-references are detected and handled."""
data: list[Any] = [1, 2]
data.append(data) # Create circular reference
result = encode_checkpoint_value(data)
assert result[0] == 1
assert result[1] == 2
assert result[2] == _CYCLE_SENTINEL
# --- Tests for reserved keyword handling ---
# Note: Security is enforced at deserialization time by validating class types,
# not at serialization time. This allows legitimate encoded data to be re-encoded.
def test_encode_allows_dict_with_model_marker_and_value() -> None:
"""Test that encoding a dict with MODEL_MARKER and 'value' is allowed.
Security is enforced at deserialization time, not serialization time.
"""
data = {
MODEL_MARKER: "some.module:SomeClass",
"value": {"data": "test"},
}
result = encode_checkpoint_value(data)
assert MODEL_MARKER in result
assert "value" in result
def test_encode_allows_dict_with_dataclass_marker_and_value() -> None:
"""Test that encoding a dict with DATACLASS_MARKER and 'value' is allowed.
Security is enforced at deserialization time, not serialization time.
"""
data = {
DATACLASS_MARKER: "some.module:SomeClass",
"value": {"field": "test"},
}
result = encode_checkpoint_value(data)
assert DATACLASS_MARKER in result
assert "value" in result
def test_encode_allows_nested_dict_with_marker_keys() -> None:
"""Test that encoding nested dict with marker keys is allowed.
Security is enforced at deserialization time, not serialization time.
"""
nested_data = {
"outer": {
MODEL_MARKER: "some.module:SomeClass",
"value": {"data": "test"},
}
}
result = encode_checkpoint_value(nested_data)
assert "outer" in result
assert MODEL_MARKER in result["outer"]
def test_encode_allows_marker_without_value() -> None:
"""Test that a dict with marker key but without 'value' key is allowed."""
data = {
MODEL_MARKER: "some.module:SomeClass",
"other_key": "allowed",
}
result = encode_checkpoint_value(data)
assert MODEL_MARKER in result
assert result["other_key"] == "allowed"
def test_encode_allows_value_without_marker() -> None:
"""Test that a dict with 'value' key but without marker is allowed."""
data = {
"value": {"nested": "data"},
"other_key": "allowed",
}
result = encode_checkpoint_value(data)
assert "value" in result
assert result["other_key"] == "allowed"
# --- Tests for max depth protection ---
def test_encode_deep_nesting_triggers_max_depth() -> None:
"""Test that very deep nesting triggers max depth protection."""
# Create a deeply nested structure (over 100 levels)
data: dict[str, Any] = {"level": 0}
current = data
for i in range(105):
current["nested"] = {"level": i + 1}
current = current["nested"]
result = encode_checkpoint_value(data)
# Navigate to find the max_depth sentinel
current_result = result
found_max_depth = False
for _ in range(110):
if isinstance(current_result, dict) and "nested" in current_result:
current_result = current_result["nested"]
if current_result == "<max_depth>":
found_max_depth = True
break
else:
break
assert found_max_depth, "Expected <max_depth> sentinel to be found in deeply nested structure"
# --- Tests for mixed complex structures ---
def test_encode_complex_mixed_structure() -> None:
"""Test encoding a complex structure with mixed types."""
data = {
"string_value": "hello",
"int_value": 42,
"float_value": 3.14,
"bool_value": True,
"none_value": None,
"list_value": [1, 2, 3],
"nested_dict": {"a": 1, "b": 2},
"dataclass_value": SimpleDataclass(name="test", value=100),
}
result = encode_checkpoint_value(data)
assert result["string_value"] == "hello"
assert result["int_value"] == 42
assert result["float_value"] == 3.14
assert result["bool_value"] is True
assert result["none_value"] is None
assert result["list_value"] == [1, 2, 3]
assert result["nested_dict"] == {"a": 1, "b": 2}
assert DATACLASS_MARKER in result["dataclass_value"]