mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Prevent pickle deserialization of untrusted HITL HTTP input (#4566)
* fix: prevent pickle deserialization of untrusted HITL input
Add strip_pickle_markers() to sanitize HTTP input before it reaches
pickle.loads() via the checkpoint decoding path. Applied as a 3-layer
defence-in-depth:
1. _app.py: sanitize req.get_json() at the HTTP boundary
2. _workflow.py: sanitize in _deserialize_hitl_response() before decode
3. _serialization.py: sanitize in reconstruct_to_type() as final guard
Any dict containing __pickled__ or __type__ markers from untrusted
sources is replaced with None, blocking arbitrary code execution via
crafted payloads to POST /workflow/respond/{instanceId}/{requestId}.
Includes 12 new unit tests covering the sanitizer and end-to-end
attack prevention.
* refactor: address review concerns for pickle fix
1. Remove deserialize_value() fallback in _deserialize_hitl_response
untrusted HITL data now returns as-is when no type hint is available,
never flowing into pickle.loads().
2. Move strip_pickle_markers() out of reconstruct_to_type() the function
is general-purpose again; untrusted-data callers are responsible for
sanitizing first (documented with NOTE comment).
3. Define _PICKLE_MARKER/_TYPE_MARKER as local constants with import-time
assertions against core's values decouples from private names while
failing loudly if core ever changes them.
4. Update tests to reflect new responsibility boundaries.
* fix: simplify warning message and fix ruff RUF001 lint
* fix: suppress pyright reportPrivateUsage on core marker imports
* Lower marker-strip log from warning to debug to avoid log flooding
* Replace assert with RuntimeError for marker sync checks (ruff S101)
* Fix pyright and ruff CI errors in security fix
- Use cast() for dict/list comprehensions in strip_pickle_markers (pyright)
- type: ignore for narrowed dict return in _workflow.py (pyright)
- Simplify marker imports: use core constants directly, remove local copies
- Remove duplicate pyright ignore comment
* Remove duplicate end-to-end test in TestStripPickleMarkers
* Suppress mypy redundant-cast on list cast needed by pyright
This commit is contained in:
committed by
GitHub
Unverified
parent
55fc882ca8
commit
09b3e2e4f0
@@ -44,7 +44,7 @@ from ._context import CapturingRunnerContext
|
||||
from ._entities import create_agent_entity
|
||||
from ._errors import IncomingRequestError
|
||||
from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor
|
||||
from ._serialization import deserialize_value, serialize_value
|
||||
from ._serialization import deserialize_value, serialize_value, strip_pickle_markers
|
||||
from ._workflow import (
|
||||
SOURCE_HITL_RESPONSE,
|
||||
SOURCE_ORCHESTRATOR,
|
||||
@@ -515,6 +515,10 @@ class AgentFunctionApp(DFAppBase):
|
||||
except ValueError:
|
||||
return self._build_error_response("Request body must be valid JSON.")
|
||||
|
||||
# Sanitize untrusted HTTP input before it reaches pickle.loads().
|
||||
# See strip_pickle_markers() docstring for details on the attack vector.
|
||||
response_data = strip_pickle_markers(response_data)
|
||||
|
||||
# Send the response as an external event
|
||||
# The request_id is used as the event name for correlation
|
||||
await client.raise_event(
|
||||
|
||||
@@ -22,9 +22,14 @@ import importlib
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from dataclasses import is_dataclass
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
|
||||
from agent_framework._workflows._checkpoint_encoding import (
|
||||
_PICKLE_MARKER, # pyright: ignore[reportPrivateUsage]
|
||||
_TYPE_MARKER, # pyright: ignore[reportPrivateUsage]
|
||||
decode_checkpoint_value,
|
||||
encode_checkpoint_value,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -48,6 +53,41 @@ def resolve_type(type_key: str) -> type | None:
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Pickle marker sanitization (security)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def strip_pickle_markers(data: Any) -> Any:
|
||||
"""Recursively strip pickle/type markers from untrusted data.
|
||||
|
||||
The core checkpoint encoding uses ``__pickled__`` and ``__type__`` markers to
|
||||
roundtrip arbitrary Python objects via *pickle*. If an attacker crafts an
|
||||
HTTP payload that contains these markers, the data would flow into
|
||||
``pickle.loads()`` and enable **arbitrary code execution**.
|
||||
|
||||
This function walks the incoming data structure and replaces any ``dict``
|
||||
that contains either marker key with ``None``, neutralising the attack
|
||||
vector while leaving all other data untouched.
|
||||
|
||||
It **must** be called on every value that originates from an untrusted
|
||||
source (e.g. ``req.get_json()``) *before* the value is passed to
|
||||
``deserialize_value`` / ``decode_checkpoint_value``.
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
if _PICKLE_MARKER in data or _TYPE_MARKER in data:
|
||||
logger.debug("Stripped pickle/type markers from untrusted input.")
|
||||
return None
|
||||
typed_dict = cast(dict[str, Any], data)
|
||||
return {k: strip_pickle_markers(v) for k, v in typed_dict.items()}
|
||||
|
||||
if isinstance(data, list):
|
||||
typed_list = cast(list[Any], data) # type: ignore[redundant-cast]
|
||||
return [strip_pickle_markers(item) for item in typed_list]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Serialize / Deserialize
|
||||
# ============================================================================
|
||||
@@ -117,7 +157,10 @@ def reconstruct_to_type(value: Any, target_type: type) -> Any:
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
# Try decoding if data has pickle markers (from checkpoint encoding)
|
||||
# Try decoding if data has pickle markers (from checkpoint encoding).
|
||||
# NOTE: This function is general-purpose. Callers that handle untrusted
|
||||
# data (e.g. HITL responses) MUST call strip_pickle_markers() before
|
||||
# passing data here. See _deserialize_hitl_response in _workflow.py.
|
||||
decoded = deserialize_value(value)
|
||||
if not isinstance(decoded, dict):
|
||||
return decoded
|
||||
|
||||
@@ -50,7 +50,7 @@ from azure.durable_functions import DurableOrchestrationContext
|
||||
|
||||
from ._context import CapturingRunnerContext
|
||||
from ._orchestration import AzureFunctionsAgentExecutor
|
||||
from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value
|
||||
from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value, strip_pickle_markers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -961,6 +961,13 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
|
||||
type(response_data).__name__,
|
||||
)
|
||||
|
||||
if response_data is None:
|
||||
return None
|
||||
|
||||
# Sanitize untrusted external input before deserialization.
|
||||
# HITL response data originates from an HTTP POST and must not contain
|
||||
# pickle/type markers that would reach pickle.loads().
|
||||
response_data = strip_pickle_markers(response_data)
|
||||
if response_data is None:
|
||||
return None
|
||||
|
||||
@@ -969,7 +976,7 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
|
||||
logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__)
|
||||
return response_data
|
||||
|
||||
# Try to deserialize using the type hint
|
||||
# Try to reconstruct using the type hint (Pydantic / dataclass)
|
||||
if response_type_str:
|
||||
response_type = resolve_type(response_type_str)
|
||||
if response_type:
|
||||
@@ -979,6 +986,8 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None
|
||||
return result
|
||||
logger.warning("Could not resolve response type: %s", response_type_str)
|
||||
|
||||
# Fall back to generic deserialization
|
||||
logger.debug("Falling back to generic deserialization")
|
||||
return deserialize_value(response_data)
|
||||
# No type hint available - return the sanitized dict as-is.
|
||||
# We intentionally do NOT call deserialize_value() here because HITL
|
||||
# response data is untrusted and must never flow into pickle.loads().
|
||||
logger.debug("No type hint; returning sanitized data as-is")
|
||||
return response_data # type: ignore[reportUnknownVariableType]
|
||||
|
||||
@@ -21,6 +21,7 @@ from agent_framework_azurefunctions._serialization import (
|
||||
deserialize_value,
|
||||
reconstruct_to_type,
|
||||
serialize_value,
|
||||
strip_pickle_markers,
|
||||
)
|
||||
|
||||
|
||||
@@ -353,7 +354,11 @@ class TestReconstructToType:
|
||||
assert result.comment == "Great"
|
||||
|
||||
def test_reconstruct_from_checkpoint_markers(self) -> None:
|
||||
"""Test that data with checkpoint markers is decoded via deserialize_value."""
|
||||
"""Test that data with checkpoint markers is decoded via deserialize_value.
|
||||
|
||||
reconstruct_to_type is general-purpose and handles trusted checkpoint
|
||||
data. Untrusted HITL callers must call strip_pickle_markers() first.
|
||||
"""
|
||||
original = SampleData(value=99, name="marker-test")
|
||||
encoded = serialize_value(original)
|
||||
|
||||
@@ -372,3 +377,73 @@ class TestReconstructToType:
|
||||
result = reconstruct_to_type(data, Unrelated)
|
||||
|
||||
assert result == data
|
||||
|
||||
def test_reconstruct_strips_injected_pickle_markers(self) -> None:
|
||||
"""End-to-end: strip_pickle_markers + reconstruct_to_type blocks attack.
|
||||
|
||||
This mirrors the real HITL flow where callers sanitize before reconstruction.
|
||||
"""
|
||||
malicious = {"__pickled__": "gASVDgAAAAAAAACMBHRlc3SULg==", "__type__": "builtins:str"}
|
||||
sanitized = strip_pickle_markers(malicious)
|
||||
result = reconstruct_to_type(sanitized, str)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestStripPickleMarkers:
|
||||
"""Security tests for strip_pickle_markers — the defence-in-depth layer
|
||||
that prevents untrusted HTTP input from reaching pickle.loads()."""
|
||||
|
||||
def test_strips_top_level_pickle_marker(self) -> None:
|
||||
"""A dict containing __pickled__ must be replaced with None."""
|
||||
data = {"__pickled__": "PAYLOAD", "__type__": "os:system"}
|
||||
assert strip_pickle_markers(data) is None
|
||||
|
||||
def test_strips_top_level_type_marker_only(self) -> None:
|
||||
"""Even __type__ alone (without __pickled__) must be neutralised."""
|
||||
data = {"__type__": "os:system", "other": "value"}
|
||||
assert strip_pickle_markers(data) is None
|
||||
|
||||
def test_strips_nested_pickle_marker(self) -> None:
|
||||
"""Pickle markers nested inside a dict must be neutralised."""
|
||||
data = {"safe": "value", "nested": {"__pickled__": "PAYLOAD", "__type__": "os:system"}}
|
||||
result = strip_pickle_markers(data)
|
||||
assert result == {"safe": "value", "nested": None}
|
||||
|
||||
def test_strips_pickle_marker_in_list(self) -> None:
|
||||
"""Pickle markers inside a list element must be neutralised."""
|
||||
data = [{"__pickled__": "PAYLOAD"}, "safe"]
|
||||
result = strip_pickle_markers(data)
|
||||
assert result == [None, "safe"]
|
||||
|
||||
def test_strips_deeply_nested_marker(self) -> None:
|
||||
"""Deeply nested pickle markers must be neutralised."""
|
||||
data = {"a": {"b": {"c": {"__pickled__": "deep"}}}}
|
||||
result = strip_pickle_markers(data)
|
||||
assert result == {"a": {"b": {"c": None}}}
|
||||
|
||||
def test_preserves_safe_dict(self) -> None:
|
||||
"""Dicts without pickle markers must be left untouched."""
|
||||
data = {"approved": True, "reason": "Looks good"}
|
||||
assert strip_pickle_markers(data) == data
|
||||
|
||||
def test_preserves_primitives(self) -> None:
|
||||
"""Primitive values must pass through unchanged."""
|
||||
assert strip_pickle_markers("hello") == "hello"
|
||||
assert strip_pickle_markers(42) == 42
|
||||
assert strip_pickle_markers(None) is None
|
||||
assert strip_pickle_markers(True) is True
|
||||
|
||||
def test_preserves_safe_list(self) -> None:
|
||||
"""Lists without pickle markers must be left untouched."""
|
||||
data = [1, "two", {"key": "value"}]
|
||||
assert strip_pickle_markers(data) == data
|
||||
|
||||
def test_mixed_safe_and_malicious(self) -> None:
|
||||
"""Only the malicious entries should be stripped; safe entries remain."""
|
||||
data = {
|
||||
"user_input": "hello",
|
||||
"evil": {"__pickled__": "PAYLOAD", "__type__": "os:system"},
|
||||
"count": 42,
|
||||
}
|
||||
result = strip_pickle_markers(data)
|
||||
assert result == {"user_input": "hello", "evil": None, "count": 42}
|
||||
|
||||
Reference in New Issue
Block a user