mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add OpenAI types to default checkpoint encoding allow list (#5297)
* Add OpenAI types to default checkpoint encoding allow list * Address comments
This commit is contained in:
committed by
GitHub
Unverified
parent
69697065ab
commit
8f7fd9525d
@@ -244,10 +244,10 @@ class FileCheckpointStorage:
|
||||
is serialized using pickle and embedded as base64-encoded strings within the JSON. This allows
|
||||
for human-readable checkpoint files while preserving the ability to store complex Python objects.
|
||||
|
||||
By default, checkpoint deserialization is restricted to a built-in set of safe
|
||||
Python types (primitives, datetime, uuid, ...) and all ``agent_framework``
|
||||
internal types. To allow additional application-specific types, pass them via
|
||||
the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
|
||||
By default, checkpoint deserialization is restricted to a built-in set of safe Python types
|
||||
(primitives, datetime, uuid, ...), all ``agent_framework`` internal types, and OpenAI SDK types
|
||||
(``openai.types``). To allow additional application-specific types, pass them via the
|
||||
``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
|
||||
|
||||
Example::
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ This hybrid approach provides:
|
||||
When ``allowed_types`` is supplied to :func:`decode_checkpoint_value`, a
|
||||
``RestrictedUnpickler`` is used that limits which classes may be instantiated
|
||||
during deserialization. The default built-in safe set covers common Python
|
||||
value types (primitives, datetime, uuid, ...) and all ``agent_framework``
|
||||
internal types. Callers can extend the set by passing additional
|
||||
``"module:qualname"`` strings.
|
||||
value types (primitives, datetime, uuid, ...), all ``agent_framework`` internal
|
||||
types, and all ``openai.types`` types. Callers can extend the set by passing
|
||||
additional ``"module:qualname"`` strings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -37,6 +37,9 @@ _JSON_NATIVE_TYPES = (str, int, float, bool, type(None))
|
||||
# Module prefix for framework-internal types that are always allowed
|
||||
_FRAMEWORK_MODULE_PREFIX = "agent_framework."
|
||||
|
||||
# Module prefix for OpenAI SDK types that are always allowed
|
||||
_OPENAI_MODULE_PREFIX = "openai.types."
|
||||
|
||||
# Built-in types considered safe for checkpoint deserialization.
|
||||
# Each entry is a ``module:qualname`` string matching the format produced by
|
||||
# :func:`_type_to_key`. These are the classes for which pickle's
|
||||
@@ -84,8 +87,9 @@ class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301
|
||||
"""Unpickler that restricts which classes may be instantiated.
|
||||
|
||||
Only classes whose ``module:qualname`` key appears in the combined allow
|
||||
set (built-in safe types + framework types + caller-specified extras) are
|
||||
permitted. All other classes raise :class:`pickle.UnpicklingError`.
|
||||
set (built-in safe types + framework types + OpenAI SDK types +
|
||||
caller-specified extras) are permitted. All other classes raise
|
||||
:class:`pickle.UnpicklingError`.
|
||||
"""
|
||||
|
||||
def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None:
|
||||
@@ -99,6 +103,7 @@ class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301
|
||||
type_key in _BUILTIN_ALLOWED_TYPE_KEYS
|
||||
or type_key in self._allowed_types
|
||||
or module.startswith(_FRAMEWORK_MODULE_PREFIX)
|
||||
or module.startswith(_OPENAI_MODULE_PREFIX)
|
||||
):
|
||||
return super().find_class(module, name) # type: ignore[no-any-return] # nosec
|
||||
|
||||
|
||||
@@ -216,3 +216,50 @@ def test_restricted_unpickler_raises_pickle_error():
|
||||
unpickler = _RestrictedUnpickler(pickled, frozenset())
|
||||
with pytest.raises(pickle.UnpicklingError, match="deserialization blocked"):
|
||||
unpickler.load()
|
||||
|
||||
|
||||
def test_restricted_decode_allows_openai_types():
|
||||
"""OpenAI SDK types are always allowed during restricted deserialization."""
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
completion = ChatCompletion(
|
||||
id="chatcmpl-test",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content="hello"),
|
||||
)
|
||||
],
|
||||
created=1700000000,
|
||||
model="gpt-4",
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
|
||||
)
|
||||
encoded = encode_checkpoint_value(completion)
|
||||
decoded = decode_checkpoint_value(encoded, allowed_types=frozenset())
|
||||
|
||||
assert isinstance(decoded, ChatCompletion)
|
||||
assert decoded.id == "chatcmpl-test"
|
||||
assert decoded.choices[0].message.content == "hello"
|
||||
|
||||
|
||||
def test_restricted_decode_allows_openai_response_types():
|
||||
"""OpenAI Responses API types are always allowed during restricted deserialization."""
|
||||
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage
|
||||
|
||||
usage = ResponseUsage(
|
||||
input_tokens=10,
|
||||
output_tokens=20,
|
||||
total_tokens=30,
|
||||
input_tokens_details=InputTokensDetails(cached_tokens=0),
|
||||
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
|
||||
)
|
||||
encoded = encode_checkpoint_value(usage)
|
||||
decoded = decode_checkpoint_value(encoded, allowed_types=frozenset())
|
||||
|
||||
assert isinstance(decoded, ResponseUsage)
|
||||
assert decoded.input_tokens == 10
|
||||
assert decoded.output_tokens == 20
|
||||
|
||||
Reference in New Issue
Block a user