Python: Add OpenAI types to default checkpoint encoding allow list (#5297)

* Add OpenAI types to default checkpoint encoding allow list

* Address comments
This commit is contained in:
Tao Chen
2026-04-15 20:58:28 -07:00
committed by GitHub
Unverified
parent 69697065ab
commit 8f7fd9525d
3 changed files with 61 additions and 9 deletions
@@ -244,10 +244,10 @@ class FileCheckpointStorage:
is serialized using pickle and embedded as base64-encoded strings within the JSON. This allows
for human-readable checkpoint files while preserving the ability to store complex Python objects.
By default, checkpoint deserialization is restricted to a built-in set of safe
Python types (primitives, datetime, uuid, ...) and all ``agent_framework``
internal types. To allow additional application-specific types, pass them via
the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
By default, checkpoint deserialization is restricted to a built-in set of safe Python types
(primitives, datetime, uuid, ...), all ``agent_framework`` internal types, and OpenAI SDK types
(``openai.types``). To allow additional application-specific types, pass them via the
``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
Example::
@@ -10,9 +10,9 @@ This hybrid approach provides:
When ``allowed_types`` is supplied to :func:`decode_checkpoint_value`, a
``RestrictedUnpickler`` is used that limits which classes may be instantiated
during deserialization. The default built-in safe set covers common Python
value types (primitives, datetime, uuid, ...) and all ``agent_framework``
internal types. Callers can extend the set by passing additional
``"module:qualname"`` strings.
value types (primitives, datetime, uuid, ...), all ``agent_framework`` internal
types, and all ``openai.types`` types. Callers can extend the set by passing
additional ``"module:qualname"`` strings.
"""
from __future__ import annotations
@@ -37,6 +37,9 @@ _JSON_NATIVE_TYPES = (str, int, float, bool, type(None))
# Module prefix for framework-internal types that are always allowed
_FRAMEWORK_MODULE_PREFIX = "agent_framework."
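# Module prefix for OpenAI SDK types that are always allowed. The trailing dot
# means only modules nested under ``openai.types`` match the ``startswith`` check.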
_OPENAI_MODULE_PREFIX = "openai.types."
# Built-in types considered safe for checkpoint deserialization.
# Each entry is a ``module:qualname`` string matching the format produced by
# :func:`_type_to_key`. These are the classes for which pickle's
@@ -84,8 +87,9 @@ class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301
"""Unpickler that restricts which classes may be instantiated.
Only classes whose ``module:qualname`` key appears in the combined allow
set (built-in safe types + framework types + caller-specified extras) are
permitted. All other classes raise :class:`pickle.UnpicklingError`.
set (built-in safe types + framework types + OpenAI SDK types +
caller-specified extras) are permitted. All other classes raise
:class:`pickle.UnpicklingError`.
"""
def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None:
@@ -99,6 +103,7 @@ class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301
type_key in _BUILTIN_ALLOWED_TYPE_KEYS
or type_key in self._allowed_types
or module.startswith(_FRAMEWORK_MODULE_PREFIX)
or module.startswith(_OPENAI_MODULE_PREFIX)
):
return super().find_class(module, name) # type: ignore[no-any-return] # nosec
@@ -216,3 +216,50 @@ def test_restricted_unpickler_raises_pickle_error():
unpickler = _RestrictedUnpickler(pickled, frozenset())
with pytest.raises(pickle.UnpicklingError, match="deserialization blocked"):
unpickler.load()
def test_restricted_decode_allows_openai_types():
    """Round-trip an OpenAI ``ChatCompletion`` through restricted decoding.

    ``allowed_types`` is an empty frozenset, so no caller-specified extras are
    in play; a successful decode shows the chat-completion model and the
    models nested inside it are accepted without being listed explicitly.
    """
    from openai.types.chat.chat_completion import ChatCompletion, Choice
    from openai.types.chat.chat_completion_message import ChatCompletionMessage
    from openai.types.completion_usage import CompletionUsage

    # Assemble the nested models first so the top-level constructor stays flat.
    reply = ChatCompletionMessage(role="assistant", content="hello")
    only_choice = Choice(finish_reason="stop", index=0, message=reply)
    token_usage = CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
    original = ChatCompletion(
        id="chatcmpl-test",
        choices=[only_choice],
        created=1700000000,
        model="gpt-4",
        object="chat.completion",
        usage=token_usage,
    )

    # Encode, then decode with an empty caller-supplied allow set.
    restored = decode_checkpoint_value(
        encode_checkpoint_value(original),
        allowed_types=frozenset(),
    )

    # Check the outer type plus one top-level and one deeply nested field.
    assert isinstance(restored, ChatCompletion)
    assert restored.id == "chatcmpl-test"
    assert restored.choices[0].message.content == "hello"
def test_restricted_decode_allows_openai_response_types():
    """Round-trip an OpenAI Responses API ``ResponseUsage`` through restricted decoding.

    ``allowed_types`` is an empty frozenset, so the usage model and its two
    nested detail models must be accepted without any caller-specified extras.
    """
    from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage

    # Build the nested detail models up front to keep the main constructor short.
    input_details = InputTokensDetails(cached_tokens=0)
    output_details = OutputTokensDetails(reasoning_tokens=0)
    original = ResponseUsage(
        input_tokens=10,
        output_tokens=20,
        total_tokens=30,
        input_tokens_details=input_details,
        output_tokens_details=output_details,
    )

    # Encode, then decode with an empty caller-supplied allow set.
    restored = decode_checkpoint_value(
        encode_checkpoint_value(original),
        allowed_types=frozenset(),
    )

    assert isinstance(restored, ResponseUsage)
    assert restored.input_tokens == 10
    assert restored.output_tokens == 20