mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: improved thread serialization and deser with better tests (#1316)
* improved thread serialization and deser with better tests * update redis with the same
This commit is contained in:
committed by
GitHub
Unverified
parent
c2c8ec3d4e
commit
a36e183600
@@ -212,17 +212,18 @@ class SerializationMixin:
|
||||
|
||||
return result
|
||||
|
||||
def to_json(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> str:
|
||||
def to_json(self, *, exclude: set[str] | None = None, exclude_none: bool = True, **kwargs: Any) -> str:
|
||||
"""Convert the instance to a JSON string.
|
||||
|
||||
Keyword Args:
|
||||
exclude: The set of field names to exclude from serialization.
|
||||
exclude_none: Whether to exclude None values from the output. Defaults to True.
|
||||
**kwargs: passed through to the json.dumps method.
|
||||
|
||||
Returns:
|
||||
JSON string representation of the instance.
|
||||
"""
|
||||
return json.dumps(self.to_dict(exclude=exclude, exclude_none=exclude_none))
|
||||
return json.dumps(self.to_dict(exclude=exclude, exclude_none=exclude_none), **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import MutableMapping, Sequence
|
||||
from typing import Any, Protocol, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from ._memory import AggregateContextProvider
|
||||
from ._serialization import SerializationMixin
|
||||
from ._types import ChatMessage
|
||||
from .exceptions import AgentThreadException
|
||||
|
||||
@@ -73,7 +72,9 @@ class ChatMessageStoreProtocol(Protocol):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "ChatMessageStoreProtocol":
|
||||
async def deserialize(
|
||||
cls, serialized_store_state: MutableMapping[str, Any], **kwargs: Any
|
||||
) -> "ChatMessageStoreProtocol":
|
||||
"""Creates a new instance of the store from previously serialized state.
|
||||
|
||||
This method, together with ``serialize()`` can be used to save and load messages from a persistent store
|
||||
@@ -90,7 +91,7 @@ class ChatMessageStoreProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
|
||||
async def update_from_state(self, serialized_store_state: MutableMapping[str, Any], **kwargs: Any) -> None:
|
||||
"""Update the current ChatMessageStore instance from serialized state data.
|
||||
|
||||
Args:
|
||||
@@ -101,7 +102,7 @@ class ChatMessageStoreProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
async def serialize(self, **kwargs: Any) -> Any:
|
||||
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
|
||||
"""Serializes the current object's state.
|
||||
|
||||
This method, together with ``deserialize()`` can be used to save and load messages from a persistent store
|
||||
@@ -116,40 +117,66 @@ class ChatMessageStoreProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class ChatMessageStoreState(BaseModel):
|
||||
class ChatMessageStoreState(SerializationMixin):
|
||||
"""State model for serializing and deserializing chat message store data.
|
||||
|
||||
Attributes:
|
||||
messages: List of chat messages stored in the message store.
|
||||
"""
|
||||
|
||||
messages: list[ChatMessage]
|
||||
def __init__(
|
||||
self,
|
||||
messages: Sequence[ChatMessage] | Sequence[MutableMapping[str, Any]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create the store state.
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
Args:
|
||||
messages: a list of messages or a list of the dict representation of messages.
|
||||
|
||||
Keyword Args:
|
||||
**kwargs: not used for this, but might be used by subclasses.
|
||||
|
||||
"""
|
||||
if not messages:
|
||||
self.messages: list[ChatMessage] = []
|
||||
if not isinstance(messages, list):
|
||||
raise TypeError("Messages should be a list")
|
||||
new_messages: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, ChatMessage):
|
||||
new_messages.append(msg)
|
||||
else:
|
||||
new_messages.append(ChatMessage.from_dict(msg))
|
||||
self.messages = new_messages
|
||||
|
||||
|
||||
class AgentThreadState(BaseModel):
|
||||
"""State model for serializing and deserializing thread information.
|
||||
class AgentThreadState(SerializationMixin):
|
||||
"""State model for serializing and deserializing thread information."""
|
||||
|
||||
Attributes:
|
||||
service_thread_id: Optional ID of the thread managed by the agent service.
|
||||
chat_message_store_state: Optional serialized state of the chat message store.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
service_thread_id: str | None = None,
|
||||
chat_message_store_state: ChatMessageStoreState | MutableMapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Create a AgentThread state.
|
||||
|
||||
service_thread_id: str | None = None
|
||||
chat_message_store_state: ChatMessageStoreState | None = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_only_one(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
if (
|
||||
isinstance(values, dict)
|
||||
and values.get("service_thread_id") is not None
|
||||
and values.get("chat_message_store_state") is not None
|
||||
):
|
||||
raise AgentThreadException("Only one of service_thread_id or chat_message_store_state may be set.")
|
||||
return values
|
||||
Keyword Args:
|
||||
service_thread_id: Optional ID of the thread managed by the agent service.
|
||||
chat_message_store_state: Optional serialized state of the chat message store.
|
||||
"""
|
||||
if service_thread_id is not None and chat_message_store_state is not None:
|
||||
raise AgentThreadException("A thread cannot have both a service_thread_id and a chat_message_store.")
|
||||
self.service_thread_id = service_thread_id
|
||||
self.chat_message_store_state: ChatMessageStoreState | None = None
|
||||
if chat_message_store_state is not None:
|
||||
if isinstance(chat_message_store_state, dict):
|
||||
self.chat_message_store_state = ChatMessageStoreState.from_dict(chat_message_store_state)
|
||||
elif isinstance(chat_message_store_state, ChatMessageStoreState):
|
||||
self.chat_message_store_state = chat_message_store_state
|
||||
else:
|
||||
raise TypeError("Could not parse ChatMessageStoreState.")
|
||||
|
||||
|
||||
TChatMessageStore = TypeVar("TChatMessageStore", bound="ChatMessageStore")
|
||||
@@ -213,7 +240,7 @@ class ChatMessageStore:
|
||||
|
||||
@classmethod
|
||||
async def deserialize(
|
||||
cls: type[TChatMessageStore], serialized_store_state: Any, **kwargs: Any
|
||||
cls: type[TChatMessageStore], serialized_store_state: MutableMapping[str, Any], **kwargs: Any
|
||||
) -> TChatMessageStore:
|
||||
"""Create a new ChatMessageStore instance from serialized state data.
|
||||
|
||||
@@ -226,12 +253,12 @@ class ChatMessageStore:
|
||||
Returns:
|
||||
A new ChatMessageStore instance populated with messages from the serialized state.
|
||||
"""
|
||||
state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
|
||||
state = ChatMessageStoreState.from_dict(serialized_store_state, **kwargs)
|
||||
if state.messages:
|
||||
return cls(messages=state.messages)
|
||||
return cls()
|
||||
|
||||
async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
|
||||
async def update_from_state(self, serialized_store_state: MutableMapping[str, Any], **kwargs: Any) -> None:
|
||||
"""Update the current ChatMessageStore instance from serialized state data.
|
||||
|
||||
Args:
|
||||
@@ -242,11 +269,11 @@ class ChatMessageStore:
|
||||
"""
|
||||
if not serialized_store_state:
|
||||
return
|
||||
state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
|
||||
state = ChatMessageStoreState.from_dict(serialized_store_state, **kwargs)
|
||||
if state.messages:
|
||||
self.messages = state.messages
|
||||
|
||||
async def serialize(self, **kwargs: Any) -> Any:
|
||||
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
|
||||
"""Serialize the current store state for persistence.
|
||||
|
||||
Keyword Args:
|
||||
@@ -256,7 +283,7 @@ class ChatMessageStore:
|
||||
Serialized state data that can be used with deserialize_state.
|
||||
"""
|
||||
state = ChatMessageStoreState(messages=self.messages)
|
||||
return state.model_dump(**kwargs)
|
||||
return state.to_dict()
|
||||
|
||||
|
||||
TAgentThread = TypeVar("TAgentThread", bound="AgentThread")
|
||||
@@ -403,12 +430,12 @@ class AgentThread:
|
||||
state = AgentThreadState(
|
||||
service_thread_id=self._service_thread_id, chat_message_store_state=chat_message_store_state
|
||||
)
|
||||
return state.model_dump()
|
||||
return state.to_dict(exclude_none=False)
|
||||
|
||||
@classmethod
|
||||
async def deserialize(
|
||||
cls: type[TAgentThread],
|
||||
serialized_thread_state: dict[str, Any],
|
||||
serialized_thread_state: MutableMapping[str, Any],
|
||||
*,
|
||||
message_store: ChatMessageStoreProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -426,7 +453,7 @@ class AgentThread:
|
||||
Returns:
|
||||
A new AgentThread instance with properties set from the serialized state.
|
||||
"""
|
||||
state = AgentThreadState.model_validate(serialized_thread_state)
|
||||
state = AgentThreadState.from_dict(serialized_thread_state)
|
||||
|
||||
if state.service_thread_id is not None:
|
||||
return cls(service_thread_id=state.service_thread_id)
|
||||
@@ -437,19 +464,19 @@ class AgentThread:
|
||||
|
||||
if message_store is not None:
|
||||
try:
|
||||
await message_store.update_from_state(state.chat_message_store_state, **kwargs)
|
||||
await message_store.add_messages(state.chat_message_store_state.messages, **kwargs)
|
||||
except Exception as ex:
|
||||
raise AgentThreadException("Failed to deserialize the provided message store.") from ex
|
||||
return cls(message_store=message_store)
|
||||
try:
|
||||
message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs)
|
||||
message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs)
|
||||
except Exception as ex:
|
||||
raise AgentThreadException("Failed to deserialize the message store.") from ex
|
||||
return cls(message_store=message_store)
|
||||
|
||||
async def update_from_thread_state(
|
||||
self,
|
||||
serialized_thread_state: dict[str, Any],
|
||||
serialized_thread_state: MutableMapping[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Deserializes the state from a dictionary into the thread properties.
|
||||
@@ -460,7 +487,7 @@ class AgentThread:
|
||||
Keyword Args:
|
||||
**kwargs: Additional arguments for deserialization.
|
||||
"""
|
||||
state = AgentThreadState.model_validate(serialized_thread_state)
|
||||
state = AgentThreadState.from_dict(serialized_thread_state)
|
||||
|
||||
if state.service_thread_id is not None:
|
||||
self.service_thread_id = state.service_thread_id
|
||||
@@ -470,8 +497,8 @@ class AgentThread:
|
||||
if state.chat_message_store_state is None:
|
||||
return
|
||||
if self.message_store is not None:
|
||||
await self.message_store.update_from_state(state.chat_message_store_state, **kwargs)
|
||||
await self.message_store.add_messages(state.chat_message_store_state.messages, **kwargs)
|
||||
# If we don't have a chat message store yet, create an in-memory one.
|
||||
return
|
||||
# Create the message store from the default.
|
||||
self.message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs) # type: ignore
|
||||
self.message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs)
|
||||
|
||||
@@ -224,11 +224,15 @@ class TestAgentThread:
|
||||
"""Test _deserialize with existing message store."""
|
||||
store = MockChatMessageStore()
|
||||
thread = AgentThread(message_store=store)
|
||||
serialized_data: dict[str, Any] = {"service_thread_id": None, "chat_message_store_state": {"messages": []}}
|
||||
serialized_data: dict[str, Any] = {
|
||||
"service_thread_id": None,
|
||||
"chat_message_store_state": {"messages": [ChatMessage(role="user", text="test")]},
|
||||
}
|
||||
|
||||
await thread.update_from_thread_state(serialized_data)
|
||||
|
||||
assert store._deserialize_calls == 1 # pyright: ignore[reportPrivateUsage]
|
||||
assert store._messages
|
||||
assert store._messages[0].text == "test"
|
||||
|
||||
async def test_serialize_with_service_thread_id(self) -> None:
|
||||
"""Test serialize with service_thread_id."""
|
||||
@@ -268,6 +272,23 @@ class TestAgentThread:
|
||||
|
||||
assert store._serialize_calls == 1 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
async def test_serialize_round_trip_messages(self, sample_messages: list[ChatMessage]) -> None:
|
||||
"""Test a roundtrip of the serialization."""
|
||||
store = ChatMessageStore(sample_messages)
|
||||
thread = AgentThread(message_store=store)
|
||||
new_thread = await AgentThread.deserialize(await thread.serialize())
|
||||
assert new_thread.message_store is not None
|
||||
new_messages = await new_thread.message_store.list_messages()
|
||||
assert len(new_messages) == len(sample_messages)
|
||||
assert {new.text for new in new_messages} == {orig.text for orig in sample_messages}
|
||||
|
||||
async def test_serialize_round_trip_thread_id(self) -> None:
|
||||
"""Test a roundtrip of the serialization."""
|
||||
thread = AgentThread(service_thread_id="test-1234")
|
||||
new_thread = await AgentThread.deserialize(await thread.serialize())
|
||||
assert new_thread.message_store is None
|
||||
assert new_thread.service_thread_id == "test-1234"
|
||||
|
||||
|
||||
class TestChatMessageList:
|
||||
"""Test cases for ChatMessageStore class."""
|
||||
@@ -377,7 +398,7 @@ class TestThreadState:
|
||||
def test_init_with_chat_message_store_state(self) -> None:
|
||||
"""Test AgentThreadState initialization with chat_message_store_state."""
|
||||
store_data: dict[str, Any] = {"messages": []}
|
||||
state = AgentThreadState.model_validate({"chat_message_store_state": store_data})
|
||||
state = AgentThreadState.from_dict({"chat_message_store_state": store_data})
|
||||
|
||||
assert state.service_thread_id is None
|
||||
assert state.chat_message_store_state.messages == []
|
||||
@@ -385,9 +406,7 @@ class TestThreadState:
|
||||
def test_init_with_both(self) -> None:
|
||||
"""Test AgentThreadState initialization with both parameters."""
|
||||
store_data: dict[str, Any] = {"messages": []}
|
||||
with pytest.raises(
|
||||
AgentThreadException, match="Only one of service_thread_id or chat_message_store_state may be set"
|
||||
):
|
||||
with pytest.raises(AgentThreadException):
|
||||
AgentThreadState(service_thread_id="test-conv-123", chat_message_store_state=store_data)
|
||||
|
||||
def test_init_defaults(self) -> None:
|
||||
|
||||
@@ -2,23 +2,30 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import redis.asyncio as redis
|
||||
from agent_framework import ChatMessage
|
||||
from pydantic import BaseModel
|
||||
from agent_framework._serialization import SerializationMixin
|
||||
|
||||
|
||||
class RedisStoreState(BaseModel):
|
||||
class RedisStoreState(SerializationMixin):
|
||||
"""State model for serializing and deserializing Redis chat message store data."""
|
||||
|
||||
thread_id: str
|
||||
redis_url: str | None = None
|
||||
key_prefix: str = "chat_messages"
|
||||
max_messages: int | None = None
|
||||
def __init__(
|
||||
self,
|
||||
thread_id: str,
|
||||
redis_url: str | None = None,
|
||||
key_prefix: str = "chat_messages",
|
||||
max_messages: int | None = None,
|
||||
) -> None:
|
||||
"""State model for serializing and deserializing Redis chat message store data."""
|
||||
self.thread_id = thread_id
|
||||
self.redis_url = redis_url
|
||||
self.key_prefix = key_prefix
|
||||
self.max_messages = max_messages
|
||||
|
||||
|
||||
class RedisChatMessageStore:
|
||||
@@ -241,7 +248,7 @@ class RedisChatMessageStore:
|
||||
key_prefix=self.key_prefix,
|
||||
max_messages=self.max_messages,
|
||||
)
|
||||
return state.model_dump(**kwargs)
|
||||
return state.to_dict(exclude_none=False, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> RedisChatMessageStore:
|
||||
@@ -268,7 +275,7 @@ class RedisChatMessageStore:
|
||||
raise ValueError("serialized_store_state is required for deserialization")
|
||||
|
||||
# Validate and parse the serialized state using Pydantic
|
||||
state = RedisStoreState.model_validate(serialized_store_state, **kwargs)
|
||||
state = RedisStoreState.from_dict(serialized_store_state, **kwargs)
|
||||
|
||||
# Create and return a new store instance with the deserialized configuration
|
||||
return cls(
|
||||
@@ -296,7 +303,7 @@ class RedisChatMessageStore:
|
||||
return
|
||||
|
||||
# Validate and parse the serialized state using Pydantic
|
||||
state = RedisStoreState.model_validate(serialized_store_state, **kwargs)
|
||||
state = RedisStoreState.from_dict(serialized_store_state, **kwargs)
|
||||
|
||||
# Update store configuration from deserialized state
|
||||
self.thread_id = state.thread_id
|
||||
@@ -344,10 +351,8 @@ class RedisChatMessageStore:
|
||||
Returns:
|
||||
JSON string representation of the message.
|
||||
"""
|
||||
# Convert ChatMessage to dictionary using custom serialization
|
||||
message_dict = message.to_dict()
|
||||
# Serialize to compact JSON (no extra whitespace for Redis efficiency)
|
||||
return json.dumps(message_dict, separators=(",", ":"))
|
||||
return message.to_json(separators=(",", ":"))
|
||||
|
||||
def _deserialize_message(self, serialized_message: str) -> ChatMessage:
|
||||
"""Deserialize a JSON string to ChatMessage.
|
||||
@@ -358,10 +363,8 @@ class RedisChatMessageStore:
|
||||
Returns:
|
||||
ChatMessage object.
|
||||
"""
|
||||
# Parse JSON string back to dictionary
|
||||
message_dict = json.loads(serialized_message)
|
||||
# Reconstruct ChatMessage using custom deserialization
|
||||
return ChatMessage.from_dict(message_dict)
|
||||
return ChatMessage.from_json(serialized_message)
|
||||
|
||||
# ============================================================================
|
||||
# List-like Convenience Methods (Redis-optimized async versions)
|
||||
|
||||
@@ -242,6 +242,7 @@ class TestRedisChatMessageStore:
|
||||
state = await redis_store.serialize()
|
||||
|
||||
expected_state = {
|
||||
"type": "redis_store_state",
|
||||
"thread_id": "test_thread_123",
|
||||
"redis_url": "redis://localhost:6379",
|
||||
"key_prefix": "chat_messages",
|
||||
|
||||
Reference in New Issue
Block a user