Python: improved thread serialization and deser with better tests (#1316)

* Improved thread serialization and deserialization, with better tests

* Updated the Redis store state to use the same serialization approach
This commit is contained in:
Eduard van Valkenburg
2025-10-08 21:04:11 +02:00
committed by GitHub
Unverified
parent c2c8ec3d4e
commit a36e183600
5 changed files with 119 additions and 68 deletions
@@ -212,17 +212,18 @@ class SerializationMixin:
return result
def to_json(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> str:
def to_json(self, *, exclude: set[str] | None = None, exclude_none: bool = True, **kwargs: Any) -> str:
"""Convert the instance to a JSON string.
Keyword Args:
exclude: The set of field names to exclude from serialization.
exclude_none: Whether to exclude None values from the output. Defaults to True.
**kwargs: passed through to the json.dumps method.
Returns:
JSON string representation of the instance.
"""
return json.dumps(self.to_dict(exclude=exclude, exclude_none=exclude_none))
return json.dumps(self.to_dict(exclude=exclude, exclude_none=exclude_none), **kwargs)
@classmethod
def from_dict(
@@ -1,11 +1,10 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import Sequence
from collections.abc import MutableMapping, Sequence
from typing import Any, Protocol, TypeVar
from pydantic import BaseModel, ConfigDict, model_validator
from ._memory import AggregateContextProvider
from ._serialization import SerializationMixin
from ._types import ChatMessage
from .exceptions import AgentThreadException
@@ -73,7 +72,9 @@ class ChatMessageStoreProtocol(Protocol):
...
@classmethod
async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "ChatMessageStoreProtocol":
async def deserialize(
cls, serialized_store_state: MutableMapping[str, Any], **kwargs: Any
) -> "ChatMessageStoreProtocol":
"""Creates a new instance of the store from previously serialized state.
This method, together with ``serialize()`` can be used to save and load messages from a persistent store
@@ -90,7 +91,7 @@ class ChatMessageStoreProtocol(Protocol):
"""
...
async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
async def update_from_state(self, serialized_store_state: MutableMapping[str, Any], **kwargs: Any) -> None:
"""Update the current ChatMessageStore instance from serialized state data.
Args:
@@ -101,7 +102,7 @@ class ChatMessageStoreProtocol(Protocol):
"""
...
async def serialize(self, **kwargs: Any) -> Any:
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
"""Serializes the current object's state.
This method, together with ``deserialize()`` can be used to save and load messages from a persistent store
@@ -116,40 +117,66 @@ class ChatMessageStoreProtocol(Protocol):
...
class ChatMessageStoreState(BaseModel):
class ChatMessageStoreState(SerializationMixin):
"""State model for serializing and deserializing chat message store data.
Attributes:
messages: List of chat messages stored in the message store.
"""
messages: list[ChatMessage]
def __init__(
    self,
    messages: Sequence[ChatMessage] | Sequence[MutableMapping[str, Any]] | None = None,
    **kwargs: Any,
) -> None:
    """Create the store state.

    Args:
        messages: a list of messages or a list of the dict representation of messages.
            ``None`` or an empty value yields an empty store state.

    Keyword Args:
        **kwargs: not used for this, but might be used by subclasses.

    Raises:
        TypeError: if a non-empty ``messages`` value is not a list.
    """
    if not messages:
        # Nothing to convert: stop here so the default ``None`` never reaches
        # the list check below (which would otherwise raise a TypeError).
        self.messages: list[ChatMessage] = []
        return
    if not isinstance(messages, list):
        raise TypeError("Messages should be a list")
    # Keep ChatMessage instances as they are; rebuild everything else through
    # ChatMessage.from_dict.
    self.messages = [msg if isinstance(msg, ChatMessage) else ChatMessage.from_dict(msg) for msg in messages]
class AgentThreadState(BaseModel):
"""State model for serializing and deserializing thread information.
class AgentThreadState(SerializationMixin):
"""State model for serializing and deserializing thread information."""
Attributes:
service_thread_id: Optional ID of the thread managed by the agent service.
chat_message_store_state: Optional serialized state of the chat message store.
"""
def __init__(
    self,
    *,
    service_thread_id: str | None = None,
    chat_message_store_state: ChatMessageStoreState | MutableMapping[str, Any] | None = None,
) -> None:
    """Create an AgentThread state.

    Keyword Args:
        service_thread_id: Optional ID of the thread managed by the agent service.
        chat_message_store_state: Optional serialized state of the chat message store,
            given either as a ChatMessageStoreState or as its mapping representation.

    Raises:
        AgentThreadException: if both service_thread_id and chat_message_store_state are set.
        TypeError: if chat_message_store_state is neither a ChatMessageStoreState nor a mapping.
    """
    if service_thread_id is not None and chat_message_store_state is not None:
        raise AgentThreadException("A thread cannot have both a service_thread_id and a chat_message_store.")
    self.service_thread_id = service_thread_id
    self.chat_message_store_state: ChatMessageStoreState | None = None
    if chat_message_store_state is None:
        return
    if isinstance(chat_message_store_state, ChatMessageStoreState):
        self.chat_message_store_state = chat_message_store_state
    elif isinstance(chat_message_store_state, MutableMapping):
        # The signature accepts any mutable mapping, so do not restrict this to plain dicts.
        self.chat_message_store_state = ChatMessageStoreState.from_dict(chat_message_store_state)
    else:
        raise TypeError("Could not parse ChatMessageStoreState.")
TChatMessageStore = TypeVar("TChatMessageStore", bound="ChatMessageStore")
@@ -213,7 +240,7 @@ class ChatMessageStore:
@classmethod
async def deserialize(
cls: type[TChatMessageStore], serialized_store_state: Any, **kwargs: Any
cls: type[TChatMessageStore], serialized_store_state: MutableMapping[str, Any], **kwargs: Any
) -> TChatMessageStore:
"""Create a new ChatMessageStore instance from serialized state data.
@@ -226,12 +253,12 @@ class ChatMessageStore:
Returns:
A new ChatMessageStore instance populated with messages from the serialized state.
"""
state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
state = ChatMessageStoreState.from_dict(serialized_store_state, **kwargs)
if state.messages:
return cls(messages=state.messages)
return cls()
async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
async def update_from_state(self, serialized_store_state: MutableMapping[str, Any], **kwargs: Any) -> None:
"""Update the current ChatMessageStore instance from serialized state data.
Args:
@@ -242,11 +269,11 @@ class ChatMessageStore:
"""
if not serialized_store_state:
return
state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
state = ChatMessageStoreState.from_dict(serialized_store_state, **kwargs)
if state.messages:
self.messages = state.messages
async def serialize(self, **kwargs: Any) -> Any:
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
"""Serialize the current store state for persistence.
Keyword Args:
@@ -256,7 +283,7 @@ class ChatMessageStore:
Serialized state data that can be used with deserialize_state.
"""
state = ChatMessageStoreState(messages=self.messages)
return state.model_dump(**kwargs)
return state.to_dict()
TAgentThread = TypeVar("TAgentThread", bound="AgentThread")
@@ -403,12 +430,12 @@ class AgentThread:
state = AgentThreadState(
service_thread_id=self._service_thread_id, chat_message_store_state=chat_message_store_state
)
return state.model_dump()
return state.to_dict(exclude_none=False)
@classmethod
async def deserialize(
cls: type[TAgentThread],
serialized_thread_state: dict[str, Any],
serialized_thread_state: MutableMapping[str, Any],
*,
message_store: ChatMessageStoreProtocol | None = None,
**kwargs: Any,
@@ -426,7 +453,7 @@ class AgentThread:
Returns:
A new AgentThread instance with properties set from the serialized state.
"""
state = AgentThreadState.model_validate(serialized_thread_state)
state = AgentThreadState.from_dict(serialized_thread_state)
if state.service_thread_id is not None:
return cls(service_thread_id=state.service_thread_id)
@@ -437,19 +464,19 @@ class AgentThread:
if message_store is not None:
try:
await message_store.update_from_state(state.chat_message_store_state, **kwargs)
await message_store.add_messages(state.chat_message_store_state.messages, **kwargs)
except Exception as ex:
raise AgentThreadException("Failed to deserialize the provided message store.") from ex
return cls(message_store=message_store)
try:
message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs)
message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs)
except Exception as ex:
raise AgentThreadException("Failed to deserialize the message store.") from ex
return cls(message_store=message_store)
async def update_from_thread_state(
self,
serialized_thread_state: dict[str, Any],
serialized_thread_state: MutableMapping[str, Any],
**kwargs: Any,
) -> None:
"""Deserializes the state from a dictionary into the thread properties.
@@ -460,7 +487,7 @@ class AgentThread:
Keyword Args:
**kwargs: Additional arguments for deserialization.
"""
state = AgentThreadState.model_validate(serialized_thread_state)
state = AgentThreadState.from_dict(serialized_thread_state)
if state.service_thread_id is not None:
self.service_thread_id = state.service_thread_id
@@ -470,8 +497,8 @@ class AgentThread:
if state.chat_message_store_state is None:
return
if self.message_store is not None:
await self.message_store.update_from_state(state.chat_message_store_state, **kwargs)
await self.message_store.add_messages(state.chat_message_store_state.messages, **kwargs)
# If we don't have a chat message store yet, create an in-memory one.
return
# Create the message store from the default.
self.message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs) # type: ignore
self.message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs)
@@ -224,11 +224,15 @@ class TestAgentThread:
"""Test _deserialize with existing message store."""
store = MockChatMessageStore()
thread = AgentThread(message_store=store)
serialized_data: dict[str, Any] = {"service_thread_id": None, "chat_message_store_state": {"messages": []}}
serialized_data: dict[str, Any] = {
"service_thread_id": None,
"chat_message_store_state": {"messages": [ChatMessage(role="user", text="test")]},
}
await thread.update_from_thread_state(serialized_data)
assert store._deserialize_calls == 1 # pyright: ignore[reportPrivateUsage]
assert store._messages
assert store._messages[0].text == "test"
async def test_serialize_with_service_thread_id(self) -> None:
"""Test serialize with service_thread_id."""
@@ -268,6 +272,23 @@ class TestAgentThread:
assert store._serialize_calls == 1 # pyright: ignore[reportPrivateUsage]
async def test_serialize_round_trip_messages(self, sample_messages: list[ChatMessage]) -> None:
    """Serialize a store-backed thread, deserialize it, and compare the messages."""
    original_thread = AgentThread(message_store=ChatMessageStore(sample_messages))
    serialized = await original_thread.serialize()
    new_thread = await AgentThread.deserialize(serialized)
    assert new_thread.message_store is not None
    restored = await new_thread.message_store.list_messages()
    # Same number of messages, and the same set of texts, as the input fixture.
    assert len(restored) == len(sample_messages)
    restored_texts = {message.text for message in restored}
    assert restored_texts == {orig.text for orig in sample_messages}
async def test_serialize_round_trip_thread_id(self) -> None:
    """Serialize a service-managed thread, deserialize it, and check the thread id."""
    serialized = await AgentThread(service_thread_id="test-1234").serialize()
    restored_thread = await AgentThread.deserialize(serialized)
    # A thread identified by a service thread id carries no local message store.
    assert restored_thread.message_store is None
    assert restored_thread.service_thread_id == "test-1234"
class TestChatMessageList:
"""Test cases for ChatMessageStore class."""
@@ -377,7 +398,7 @@ class TestThreadState:
def test_init_with_chat_message_store_state(self) -> None:
"""Test AgentThreadState initialization with chat_message_store_state."""
store_data: dict[str, Any] = {"messages": []}
state = AgentThreadState.model_validate({"chat_message_store_state": store_data})
state = AgentThreadState.from_dict({"chat_message_store_state": store_data})
assert state.service_thread_id is None
assert state.chat_message_store_state.messages == []
@@ -385,9 +406,7 @@ class TestThreadState:
def test_init_with_both(self) -> None:
"""Test AgentThreadState initialization with both parameters."""
store_data: dict[str, Any] = {"messages": []}
with pytest.raises(
AgentThreadException, match="Only one of service_thread_id or chat_message_store_state may be set"
):
with pytest.raises(AgentThreadException):
AgentThreadState(service_thread_id="test-conv-123", chat_message_store_state=store_data)
def test_init_defaults(self) -> None:
@@ -2,23 +2,30 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any
from uuid import uuid4
import redis.asyncio as redis
from agent_framework import ChatMessage
from pydantic import BaseModel
from agent_framework._serialization import SerializationMixin
class RedisStoreState(BaseModel):
class RedisStoreState(SerializationMixin):
"""State model for serializing and deserializing Redis chat message store data."""
thread_id: str
redis_url: str | None = None
key_prefix: str = "chat_messages"
max_messages: int | None = None
def __init__(
    self,
    thread_id: str,
    redis_url: str | None = None,
    key_prefix: str = "chat_messages",
    max_messages: int | None = None,
) -> None:
    """Record the Redis chat message store settings on this state object.

    Args:
        thread_id: stored as ``self.thread_id``.
        redis_url: stored as ``self.redis_url``. Defaults to None.
        key_prefix: stored as ``self.key_prefix``. Defaults to "chat_messages".
        max_messages: stored as ``self.max_messages``. Defaults to None.
    """
    self.thread_id, self.redis_url = thread_id, redis_url
    self.key_prefix, self.max_messages = key_prefix, max_messages
class RedisChatMessageStore:
@@ -241,7 +248,7 @@ class RedisChatMessageStore:
key_prefix=self.key_prefix,
max_messages=self.max_messages,
)
return state.model_dump(**kwargs)
return state.to_dict(exclude_none=False, **kwargs)
@classmethod
async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> RedisChatMessageStore:
@@ -268,7 +275,7 @@ class RedisChatMessageStore:
raise ValueError("serialized_store_state is required for deserialization")
# Validate and parse the serialized state using Pydantic
state = RedisStoreState.model_validate(serialized_store_state, **kwargs)
state = RedisStoreState.from_dict(serialized_store_state, **kwargs)
# Create and return a new store instance with the deserialized configuration
return cls(
@@ -296,7 +303,7 @@ class RedisChatMessageStore:
return
# Validate and parse the serialized state using Pydantic
state = RedisStoreState.model_validate(serialized_store_state, **kwargs)
state = RedisStoreState.from_dict(serialized_store_state, **kwargs)
# Update store configuration from deserialized state
self.thread_id = state.thread_id
@@ -344,10 +351,8 @@ class RedisChatMessageStore:
Returns:
JSON string representation of the message.
"""
# Convert ChatMessage to dictionary using custom serialization
message_dict = message.to_dict()
# Serialize to compact JSON (no extra whitespace for Redis efficiency)
return json.dumps(message_dict, separators=(",", ":"))
return message.to_json(separators=(",", ":"))
def _deserialize_message(self, serialized_message: str) -> ChatMessage:
"""Deserialize a JSON string to ChatMessage.
@@ -358,10 +363,8 @@ class RedisChatMessageStore:
Returns:
ChatMessage object.
"""
# Parse JSON string back to dictionary
message_dict = json.loads(serialized_message)
# Reconstruct ChatMessage using custom deserialization
return ChatMessage.from_dict(message_dict)
return ChatMessage.from_json(serialized_message)
# ============================================================================
# List-like Convenience Methods (Redis-optimized async versions)
@@ -242,6 +242,7 @@ class TestRedisChatMessageStore:
state = await redis_store.serialize()
expected_state = {
"type": "redis_store_state",
"thread_id": "test_thread_123",
"redis_url": "redis://localhost:6379",
"key_prefix": "chat_messages",