Python: Prevent pickle deserialization of untrusted HITL HTTP input (#4566)

* fix: prevent pickle deserialization of untrusted HITL input

Add strip_pickle_markers() to sanitize HTTP input before it reaches
pickle.loads() via the checkpoint decoding path. Applied as a 3-layer
defence-in-depth:

1. _app.py: sanitize req.get_json() at the HTTP boundary
2. _workflow.py: sanitize in _deserialize_hitl_response() before decode
3. _serialization.py: sanitize in reconstruct_to_type() as final guard

Any dict containing __pickled__ or __type__ markers from untrusted
sources is replaced with None, blocking arbitrary code execution via
crafted payloads to POST /workflow/respond/{instanceId}/{requestId}.

Includes 12 new unit tests covering the sanitizer and end-to-end
attack prevention.

* refactor: address review concerns for pickle fix

1. Remove deserialize_value() fallback in _deserialize_hitl_response
   untrusted HITL data now returns as-is when no type hint is available,
   never flowing into pickle.loads().

2. Move strip_pickle_markers() out of reconstruct_to_type()  the function
   is general-purpose again; untrusted-data callers are responsible for
   sanitizing first (documented with NOTE comment).

3. Define _PICKLE_MARKER/_TYPE_MARKER as local constants with import-time
   assertions against core's values  decouples from private names while
   failing loudly if core ever changes them.

4. Update tests to reflect new responsibility boundaries.

* fix: simplify warning message and fix ruff RUF001 lint

* fix: suppress pyright reportPrivateUsage on core marker imports

* Lower marker-strip log from warning to debug to avoid log flooding

* Replace assert with RuntimeError for marker sync checks (ruff S101)

* Fix pyright and ruff CI errors in security fix

- Use cast() for dict/list comprehensions in strip_pickle_markers (pyright)
- type: ignore for narrowed dict return in _workflow.py (pyright)
- Simplify marker imports: use core constants directly, remove local copies
- Remove duplicate pyright ignore comment

* Remove duplicate end-to-end test in TestStripPickleMarkers

* Suppress mypy redundant-cast on list cast needed by pyright
This commit is contained in:
Ahmed Muhsin
2026-03-10 14:29:33 -05:00
committed by GitHub
Unverified
parent 55fc882ca8
commit 09b3e2e4f0
4 changed files with 141 additions and 10 deletions
@@ -44,7 +44,7 @@ from ._context import CapturingRunnerContext
from ._entities import create_agent_entity
from ._errors import IncomingRequestError
from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor
from ._serialization import deserialize_value, serialize_value
from ._serialization import deserialize_value, serialize_value, strip_pickle_markers
from ._workflow import (
SOURCE_HITL_RESPONSE,
SOURCE_ORCHESTRATOR,
@@ -515,6 +515,10 @@ class AgentFunctionApp(DFAppBase):
except ValueError:
return self._build_error_response("Request body must be valid JSON.")
# Sanitize untrusted HTTP input before it reaches pickle.loads().
# See strip_pickle_markers() docstring for details on the attack vector.
response_data = strip_pickle_markers(response_data)
# Send the response as an external event
# The request_id is used as the event name for correlation
await client.raise_event(
@@ -22,9 +22,14 @@ import importlib
import logging
from contextlib import suppress
from dataclasses import is_dataclass
from typing import Any
from typing import Any, cast
from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
from agent_framework._workflows._checkpoint_encoding import (
_PICKLE_MARKER, # pyright: ignore[reportPrivateUsage]
_TYPE_MARKER, # pyright: ignore[reportPrivateUsage]
decode_checkpoint_value,
encode_checkpoint_value,
)
from pydantic import BaseModel
logger = logging.getLogger(__name__)
@@ -48,6 +53,41 @@ def resolve_type(type_key: str) -> type | None:
return None
# ============================================================================
# Pickle marker sanitization (security)
# ============================================================================
def strip_pickle_markers(data: Any) -> Any:
"""Recursively strip pickle/type markers from untrusted data.
The core checkpoint encoding uses ``__pickled__`` and ``__type__`` markers to
roundtrip arbitrary Python objects via *pickle*. If an attacker crafts an
HTTP payload that contains these markers, the data would flow into
``pickle.loads()`` and enable **arbitrary code execution**.
This function walks the incoming data structure and replaces any ``dict``
that contains either marker key with ``None``, neutralising the attack
vector while leaving all other data untouched.
It **must** be called on every value that originates from an untrusted
source (e.g. ``req.get_json()``) *before* the value is passed to
``deserialize_value`` / ``decode_checkpoint_value``.
"""
if isinstance(data, dict):
if _PICKLE_MARKER in data or _TYPE_MARKER in data:
logger.debug("Stripped pickle/type markers from untrusted input.")
return None
typed_dict = cast(dict[str, Any], data)
return {k: strip_pickle_markers(v) for k, v in typed_dict.items()}
if isinstance(data, list):
typed_list = cast(list[Any], data) # type: ignore[redundant-cast]
return [strip_pickle_markers(item) for item in typed_list]
return data
# ============================================================================
# Serialize / Deserialize
# ============================================================================
@@ -117,7 +157,10 @@ def reconstruct_to_type(value: Any, target_type: type) -> Any:
if not isinstance(value, dict):
return value
# Try decoding if data has pickle markers (from checkpoint encoding)
# Try decoding if data has pickle markers (from checkpoint encoding).
# NOTE: This function is general-purpose. Callers that handle untrusted
# data (e.g. HITL responses) MUST call strip_pickle_markers() before
# passing data here. See _deserialize_hitl_response in _workflow.py.
decoded = deserialize_value(value)
if not isinstance(decoded, dict):
return decoded
@@ -50,7 +50,7 @@ from azure.durable_functions import DurableOrchestrationContext
from ._context import CapturingRunnerContext
from ._orchestration import AzureFunctionsAgentExecutor
from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value
from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value, strip_pickle_markers
logger = logging.getLogger(__name__)
@@ -961,6 +961,13 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
type(response_data).__name__,
)
if response_data is None:
return None
# Sanitize untrusted external input before deserialization.
# HITL response data originates from an HTTP POST and must not contain
# pickle/type markers that would reach pickle.loads().
response_data = strip_pickle_markers(response_data)
if response_data is None:
return None
@@ -969,7 +976,7 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__)
return response_data
# Try to deserialize using the type hint
# Try to reconstruct using the type hint (Pydantic / dataclass)
if response_type_str:
response_type = resolve_type(response_type_str)
if response_type:
@@ -979,6 +986,8 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
return result
logger.warning("Could not resolve response type: %s", response_type_str)
# Fall back to generic deserialization
logger.debug("Falling back to generic deserialization")
return deserialize_value(response_data)
# No type hint available - return the sanitized dict as-is.
# We intentionally do NOT call deserialize_value() here because HITL
# response data is untrusted and must never flow into pickle.loads().
logger.debug("No type hint; returning sanitized data as-is")
return response_data # type: ignore[reportUnknownVariableType]
@@ -21,6 +21,7 @@ from agent_framework_azurefunctions._serialization import (
deserialize_value,
reconstruct_to_type,
serialize_value,
strip_pickle_markers,
)
@@ -353,7 +354,11 @@ class TestReconstructToType:
assert result.comment == "Great"
def test_reconstruct_from_checkpoint_markers(self) -> None:
"""Test that data with checkpoint markers is decoded via deserialize_value."""
"""Test that data with checkpoint markers is decoded via deserialize_value.
reconstruct_to_type is general-purpose and handles trusted checkpoint
data. Untrusted HITL callers must call strip_pickle_markers() first.
"""
original = SampleData(value=99, name="marker-test")
encoded = serialize_value(original)
@@ -372,3 +377,73 @@ class TestReconstructToType:
result = reconstruct_to_type(data, Unrelated)
assert result == data
def test_reconstruct_strips_injected_pickle_markers(self) -> None:
"""End-to-end: strip_pickle_markers + reconstruct_to_type blocks attack.
This mirrors the real HITL flow where callers sanitize before reconstruction.
"""
malicious = {"__pickled__": "gASVDgAAAAAAAACMBHRlc3SULg==", "__type__": "builtins:str"}
sanitized = strip_pickle_markers(malicious)
result = reconstruct_to_type(sanitized, str)
assert result is None
class TestStripPickleMarkers:
"""Security tests for strip_pickle_markers — the defence-in-depth layer
that prevents untrusted HTTP input from reaching pickle.loads()."""
def test_strips_top_level_pickle_marker(self) -> None:
"""A dict containing __pickled__ must be replaced with None."""
data = {"__pickled__": "PAYLOAD", "__type__": "os:system"}
assert strip_pickle_markers(data) is None
def test_strips_top_level_type_marker_only(self) -> None:
"""Even __type__ alone (without __pickled__) must be neutralised."""
data = {"__type__": "os:system", "other": "value"}
assert strip_pickle_markers(data) is None
def test_strips_nested_pickle_marker(self) -> None:
"""Pickle markers nested inside a dict must be neutralised."""
data = {"safe": "value", "nested": {"__pickled__": "PAYLOAD", "__type__": "os:system"}}
result = strip_pickle_markers(data)
assert result == {"safe": "value", "nested": None}
def test_strips_pickle_marker_in_list(self) -> None:
"""Pickle markers inside a list element must be neutralised."""
data = [{"__pickled__": "PAYLOAD"}, "safe"]
result = strip_pickle_markers(data)
assert result == [None, "safe"]
def test_strips_deeply_nested_marker(self) -> None:
"""Deeply nested pickle markers must be neutralised."""
data = {"a": {"b": {"c": {"__pickled__": "deep"}}}}
result = strip_pickle_markers(data)
assert result == {"a": {"b": {"c": None}}}
def test_preserves_safe_dict(self) -> None:
"""Dicts without pickle markers must be left untouched."""
data = {"approved": True, "reason": "Looks good"}
assert strip_pickle_markers(data) == data
def test_preserves_primitives(self) -> None:
"""Primitive values must pass through unchanged."""
assert strip_pickle_markers("hello") == "hello"
assert strip_pickle_markers(42) == 42
assert strip_pickle_markers(None) is None
assert strip_pickle_markers(True) is True
def test_preserves_safe_list(self) -> None:
"""Lists without pickle markers must be left untouched."""
data = [1, "two", {"key": "value"}]
assert strip_pickle_markers(data) == data
def test_mixed_safe_and_malicious(self) -> None:
"""Only the malicious entries should be stripped; safe entries remain."""
data = {
"user_input": "hello",
"evil": {"__pickled__": "PAYLOAD", "__type__": "os:system"},
"count": 42,
}
result = strip_pickle_markers(data)
assert result == {"user_input": "hello", "evil": None, "count": 42}