diff --git a/python/packages/core/agent_framework/_workflows/_checkpoint.py b/python/packages/core/agent_framework/_workflows/_checkpoint.py index f9a940a7db..22b4a1ea24 100644 --- a/python/packages/core/agent_framework/_workflows/_checkpoint.py +++ b/python/packages/core/agent_framework/_workflows/_checkpoint.py @@ -244,10 +244,10 @@ class FileCheckpointStorage: is serialized using pickle and embedded as base64-encoded strings within the JSON. This allows for human-readable checkpoint files while preserving the ability to store complex Python objects. - By default, checkpoint deserialization is restricted to a built-in set of safe - Python types (primitives, datetime, uuid, ...) and all ``agent_framework`` - internal types. To allow additional application-specific types, pass them via - the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format. + By default, checkpoint deserialization is restricted to a built-in set of safe Python types + (primitives, datetime, uuid, ...), all ``agent_framework`` internal types, and OpenAI SDK types + (``openai.types``). To allow additional application-specific types, pass them via the + ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format. Example:: diff --git a/python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py b/python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py index a25a08c66a..dd1fb3d704 100644 --- a/python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py +++ b/python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py @@ -10,9 +10,9 @@ This hybrid approach provides: When ``allowed_types`` is supplied to :func:`decode_checkpoint_value`, a ``RestrictedUnpickler`` is used that limits which classes may be instantiated during deserialization. The default built-in safe set covers common Python -value types (primitives, datetime, uuid, ...) and all ``agent_framework`` -internal types. Callers can extend the set by passing additional -``"module:qualname"`` strings. +value types (primitives, datetime, uuid, ...), all ``agent_framework`` internal +types, and all ``openai.types`` types. Callers can extend the set by passing +additional ``"module:qualname"`` strings. """ from __future__ import annotations @@ -37,6 +37,9 @@ _JSON_NATIVE_TYPES = (str, int, float, bool, type(None)) # Module prefix for framework-internal types that are always allowed _FRAMEWORK_MODULE_PREFIX = "agent_framework." +# Module prefix for OpenAI SDK types that are always allowed +_OPENAI_MODULE_PREFIX = "openai.types." + # Built-in types considered safe for checkpoint deserialization. # Each entry is a ``module:qualname`` string matching the format produced by # :func:`_type_to_key`. These are the classes for which pickle's @@ -84,8 +87,9 @@ class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301 """Unpickler that restricts which classes may be instantiated. Only classes whose ``module:qualname`` key appears in the combined allow - set (built-in safe types + framework types + caller-specified extras) are - permitted. All other classes raise :class:`pickle.UnpicklingError`. + set (built-in safe types + framework types + OpenAI SDK types + + caller-specified extras) are permitted. All other classes raise + :class:`pickle.UnpicklingError`. """ def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None: @@ -99,6 +103,7 @@ class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301 type_key in _BUILTIN_ALLOWED_TYPE_KEYS or type_key in self._allowed_types or module.startswith(_FRAMEWORK_MODULE_PREFIX) + or module.startswith(_OPENAI_MODULE_PREFIX) ): return super().find_class(module, name) # type: ignore[no-any-return] # nosec diff --git a/python/packages/core/tests/workflow/test_checkpoint_unrestricted_pickle.py b/python/packages/core/tests/workflow/test_checkpoint_unrestricted_pickle.py index c70d8c85c3..77304841b2 100644 --- a/python/packages/core/tests/workflow/test_checkpoint_unrestricted_pickle.py +++ b/python/packages/core/tests/workflow/test_checkpoint_unrestricted_pickle.py @@ -216,3 +216,50 @@ def test_restricted_unpickler_raises_pickle_error(): unpickler = _RestrictedUnpickler(pickled, frozenset()) with pytest.raises(pickle.UnpicklingError, match="deserialization blocked"): unpickler.load() + + +def test_restricted_decode_allows_openai_types(): + """OpenAI SDK types are always allowed during restricted deserialization.""" + from openai.types.chat.chat_completion import ChatCompletion, Choice + from openai.types.chat.chat_completion_message import ChatCompletionMessage + from openai.types.completion_usage import CompletionUsage + + completion = ChatCompletion( + id="chatcmpl-test", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content="hello"), + ) + ], + created=1700000000, + model="gpt-4", + object="chat.completion", + usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2), + ) + encoded = encode_checkpoint_value(completion) + decoded = decode_checkpoint_value(encoded, allowed_types=frozenset()) + + assert isinstance(decoded, ChatCompletion) + assert decoded.id == "chatcmpl-test" + assert decoded.choices[0].message.content == "hello" + + +def test_restricted_decode_allows_openai_response_types(): + """OpenAI Responses API types are always allowed during restricted deserialization.""" + from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage + + usage = ResponseUsage( + input_tokens=10, + output_tokens=20, + total_tokens=30, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ) + encoded = encode_checkpoint_value(usage) + decoded = decode_checkpoint_value(encoded, allowed_types=frozenset()) + + assert isinstance(decoded, ResponseUsage) + assert decoded.input_tokens == 10 + assert decoded.output_tokens == 20