Files
agent-framework/python/packages/main/agent_framework/_threads.py
T
Eduard van Valkenburg 10d10364a9 Python: [BREAKING] cleanup of thread API and serialization (#893)
* cleanup of threads and serialization

* fix for sliding window

* fix redis test

* updated from comments

* updated context provider and threads

* updated lock

* add asyncio default

* fix redis tests

* fix tests

* fix tests

* renamed to invoking

* fixed tests

* fix for instructions
2025-09-29 16:22:34 +00:00

356 lines
14 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
from collections.abc import Sequence
from typing import Any, Protocol, TypeVar
from pydantic import model_validator
from ._memory import AggregateContextProvider
from ._pydantic import AFBaseModel
from ._types import ChatMessage
from .exceptions import AgentThreadException
# Names exported by `from ... import *`; everything else in this module is internal.
__all__ = [
    "AgentThread",
    "ChatMessageStore",
    "ChatMessageStoreProtocol",
]
class ChatMessageStoreProtocol(Protocol):
    """Defines methods for storing and retrieving chat messages associated with a specific thread.

    Implementations of this protocol are responsible for managing the storage of chat messages,
    including handling large volumes of data by truncating or summarizing messages as necessary.
    """

    async def list_messages(self) -> list[ChatMessage]:
        """Get all the messages from the store that should be used for the next agent invocation.

        Messages are returned in ascending chronological order, with the oldest message first.
        If the messages stored in the store become very large, it is up to the store to
        truncate, summarize or otherwise limit the number of messages returned.

        When using implementations of ChatMessageStoreProtocol, a new one should be created for each thread
        since they may contain state that is specific to a thread.

        Returns:
            The messages to use for the next invocation, oldest first.
        """
        ...

    async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
        """Add messages to the store.

        Args:
            messages: The messages to append to the store.
        """
        ...

    @classmethod
    async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "ChatMessageStoreProtocol":
        """Create a new instance of the store from previously serialized state.

        This method, together with ``serialize``, can be used to save and load messages from a
        persistent store if this store only has messages in memory.

        Args:
            serialized_store_state: State previously produced by ``serialize``.
            **kwargs: Implementation-specific deserialization arguments.

        Returns:
            A new store instance populated from the serialized state.
        """
        ...

    async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
        """Update the current store instance in place from serialized state data.

        Args:
            serialized_store_state: Previously serialized state data containing messages.
            **kwargs: Additional arguments for deserialization.
        """
        ...

    async def serialize(self, **kwargs: Any) -> Any:
        """Serialize the current object's state.

        This method, together with ``deserialize``, can be used to save and load messages from a
        persistent store if this store only has messages in memory.

        Args:
            **kwargs: Implementation-specific serialization arguments.

        Returns:
            Serialized state that ``deserialize`` / ``update_from_state`` can consume.
        """
        ...
class AgentThreadState(AFBaseModel):
    """State model for serializing and deserializing thread information.

    At most one of the two fields may be set: a thread either refers to a conversation held by
    the agent service, or carries the serialized state of a local chat message store.

    Attributes:
        service_thread_id: Optional ID of the thread managed by the agent service.
        chat_message_store_state: Optional serialized state of the chat message store.
    """

    service_thread_id: str | None = None
    chat_message_store_state: Any | None = None

    @model_validator(mode="before")
    @classmethod
    def validate_only_one(cls, values: Any) -> Any:
        """Reject raw input that sets both ``service_thread_id`` and ``chat_message_store_state``.

        This runs before field validation, so ``values`` is the raw input. Non-dict input is
        returned unchanged and left to pydantic's normal validation.

        Raises:
            AgentThreadException: If both keys are present with non-None values.
                NOTE(review): if AgentThreadException subclasses ValueError, pydantic will wrap
                it in a ValidationError — confirm which type callers should expect.
        """
        if (
            isinstance(values, dict)
            and values.get("service_thread_id") is not None
            and values.get("chat_message_store_state") is not None
        ):
            raise AgentThreadException("Only one of service_thread_id or chat_message_store_state may be set.")
        return values
class ChatMessageStoreState(AFBaseModel):
    """Serializable snapshot of an in-memory chat message store.

    Used by ``ChatMessageStore`` as the payload of ``serialize`` and as the schema that
    ``deserialize`` / ``update_from_state`` validate against.

    Attributes:
        messages: The chat messages held by the store.
    """

    messages: list[ChatMessage]
TChatMessageStore = TypeVar("TChatMessageStore", bound="ChatMessageStore")


class ChatMessageStore:
    """An in-memory implementation of ChatMessageStoreProtocol that stores messages in a list.

    This implementation provides a simple, list-based storage for chat messages
    with support for serialization and deserialization. It implements all the
    required methods of the ChatMessageStoreProtocol protocol.

    The store maintains messages in memory and provides methods to serialize
    and deserialize the state for persistence purposes.

    Args:
        messages: Optional initial list of ChatMessage objects to populate the store.
    """

    def __init__(self, messages: Sequence[ChatMessage] | None = None) -> None:
        """Create a ChatMessageStore for use in a thread.

        Args:
            messages: The messages to store. They are copied into a new list, so later
                additions to the store do not mutate the caller's sequence.
        """
        self.messages = list(messages) if messages else []

    async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
        """Add messages to the store.

        Args:
            messages: Sequence of ChatMessage objects to append, in order.
        """
        self.messages.extend(messages)

    async def list_messages(self) -> list[ChatMessage]:
        """Get all messages from the store in chronological order.

        Returns:
            List of ChatMessage objects, ordered from oldest to newest.

        Note:
            The store's own list is returned, not a copy; mutating the result mutates the store.
        """
        return self.messages

    @classmethod
    async def deserialize(
        cls: type[TChatMessageStore], serialized_store_state: Any, **kwargs: Any
    ) -> TChatMessageStore:
        """Create a new ChatMessageStore instance from serialized state data.

        Args:
            serialized_store_state: Previously serialized state data containing messages.
            **kwargs: Forwarded to ``ChatMessageStoreState.model_validate``.

        Returns:
            A new ChatMessageStore instance populated with messages from the serialized state.
        """
        state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
        if state.messages:
            return cls(messages=state.messages)
        return cls()

    async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
        """Update the current ChatMessageStore instance from serialized state data.

        A falsy payload, or a payload whose ``messages`` list is empty, leaves the current
        messages unchanged. NOTE(review): an empty snapshot therefore does not clear the
        store — confirm this is intended.

        Args:
            serialized_store_state: Previously serialized state data containing messages.
            **kwargs: Forwarded to ``ChatMessageStoreState.model_validate``.
        """
        if not serialized_store_state:
            return
        state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
        if state.messages:
            self.messages = state.messages

    async def serialize(self, **kwargs: Any) -> Any:
        """Serialize the current store state for persistence.

        Args:
            **kwargs: Forwarded to ``ChatMessageStoreState.model_dump``.

        Returns:
            Serialized state data that can be passed to ``deserialize`` or ``update_from_state``.
        """
        state = ChatMessageStoreState(messages=self.messages)
        return state.model_dump(**kwargs)
TAgentThread = TypeVar("TAgentThread", bound="AgentThread")


class AgentThread:
    """The Agent thread class, this can represent both a locally managed thread or a thread managed by the service.

    A thread holds at most one of:

    * ``service_thread_id`` - the conversation is stored by the agent service, or
    * ``message_store`` - the conversation is stored locally in a ChatMessageStoreProtocol.

    Once one of them is set, attempting to set the other raises AgentThreadException.
    """

    def __init__(
        self,
        *,
        service_thread_id: str | None = None,
        message_store: ChatMessageStoreProtocol | None = None,
        context_provider: AggregateContextProvider | None = None,
    ) -> None:
        """Initialize an AgentThread, do not use this method manually, always use: agent.get_new_thread().

        Args:
            service_thread_id: Optional ID of the thread managed by the agent service.
            message_store: Optional ChatMessageStore implementation for managing chat messages.
            context_provider: Optional ContextProvider for the thread.

        Raises:
            AgentThreadException: If both service_thread_id and message_store are given.

        Note:
            Either service_thread_id or message_store may be set, but not both.
        """
        if service_thread_id is not None and message_store is not None:
            raise AgentThreadException("Only the service_thread_id or message_store may be set, but not both.")

        self._service_thread_id = service_thread_id
        self._message_store = message_store
        self.context_provider = context_provider

    @property
    def is_initialized(self) -> bool:
        """Indicates if the thread is initialized.

        This means either the service_thread_id or the message_store is set.
        """
        return self._service_thread_id is not None or self._message_store is not None

    @property
    def service_thread_id(self) -> str | None:
        """Gets the ID of the current thread to support cases where the thread is owned by the agent service."""
        return self._service_thread_id

    @service_thread_id.setter
    def service_thread_id(self, service_thread_id: str | None) -> None:
        """Sets the ID of the current thread to support cases where the thread is owned by the agent service.

        Note that either service_thread_id or message_store may be set, but not both.
        Assigning None is a no-op, so an existing ID cannot be cleared through this setter.

        Raises:
            AgentThreadException: If a message store is already attached.
        """
        if service_thread_id is None:
            return

        if self._message_store is not None:
            raise AgentThreadException(
                "Only the service_thread_id or message_store may be set, "
                "but not both and switching from one to another is not supported."
            )

        self._service_thread_id = service_thread_id

    @property
    def message_store(self) -> ChatMessageStoreProtocol | None:
        """Gets the ChatMessageStoreProtocol used by this thread."""
        return self._message_store

    @message_store.setter
    def message_store(self, message_store: ChatMessageStoreProtocol | None) -> None:
        """Sets the ChatMessageStoreProtocol used by this thread.

        Note that either service_thread_id or message_store may be set, but not both.
        Assigning None is a no-op, so an existing store cannot be detached through this setter.

        Raises:
            AgentThreadException: If a service_thread_id is already set.
        """
        if message_store is None:
            return

        if self._service_thread_id is not None:
            raise AgentThreadException(
                "Only the service_thread_id or message_store may be set, "
                "but not both and switching from one to another is not supported."
            )

        self._message_store = message_store

    async def on_new_messages(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
        """Invoked when a new message has been contributed to the chat by any participant.

        Args:
            new_messages: A single message or a sequence of messages to record.
        """
        if self._service_thread_id is not None:
            # If the thread messages are stored in the service there is nothing to do here,
            # since invoking the service should already update the thread.
            return

        if isinstance(new_messages, ChatMessage):
            new_messages = [new_messages]

        if not new_messages:
            # Nothing to record. Creating a default store here would lock the thread into
            # local storage and make a later service_thread_id assignment raise.
            return

        if self._message_store is None:
            # If there is no conversation id, and no store we can
            # create a default in memory store.
            self._message_store = ChatMessageStore()

        await self._message_store.add_messages(new_messages)

    async def serialize(self, **kwargs: Any) -> dict[str, Any]:
        """Serializes the current object's state.

        Args:
            **kwargs: Forwarded to the message store's ``serialize`` method.

        Returns:
            A dict produced from AgentThreadState holding either the service thread ID or
            the serialized message store state (or neither).
        """
        chat_message_store_state = None
        if self._message_store is not None:
            chat_message_store_state = await self._message_store.serialize(**kwargs)

        state = AgentThreadState(
            service_thread_id=self._service_thread_id, chat_message_store_state=chat_message_store_state
        )
        return state.model_dump()

    @classmethod
    async def deserialize(
        cls: type[TAgentThread],
        serialized_thread_state: dict[str, Any],
        *,
        message_store: ChatMessageStoreProtocol | None = None,
        **kwargs: Any,
    ) -> TAgentThread:
        """Deserializes the state from a dictionary into a new AgentThread instance.

        Args:
            serialized_thread_state: The serialized thread state as a dictionary.
            message_store: Optional ChatMessageStoreProtocol to use for managing messages.
                When given, it is updated from any serialized store state and attached to the
                returned thread (unless the state refers to a service-managed thread).
                If not provided, a new ChatMessageStore will be created if needed.
            **kwargs: Additional arguments for deserialization.

        Returns:
            A new AgentThread instance with properties set from the serialized state.

        Raises:
            AgentThreadException: If restoring the message store fails.
        """
        state = AgentThreadState.model_validate(serialized_thread_state)

        if state.service_thread_id is not None:
            return cls(service_thread_id=state.service_thread_id)

        # If we don't have any ChatMessageStoreProtocol state there is nothing to restore,
        # but an explicitly provided store must still be attached; otherwise the first new
        # message would silently go to a default in-memory store instead.
        if state.chat_message_store_state is None:
            return cls(message_store=message_store)

        if message_store is not None:
            try:
                await message_store.update_from_state(state.chat_message_store_state, **kwargs)
            except Exception as ex:
                raise AgentThreadException("Failed to deserialize the provided message store.") from ex
            return cls(message_store=message_store)

        try:
            message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs)
        except Exception as ex:
            raise AgentThreadException("Failed to deserialize the message store.") from ex
        return cls(message_store=message_store)

    async def update_from_thread_state(
        self,
        serialized_thread_state: dict[str, Any],
        **kwargs: Any,
    ) -> None:
        """Deserializes the state from a dictionary into the thread properties.

        Args:
            serialized_thread_state: The serialized thread state as a dictionary.
            **kwargs: Additional arguments forwarded to the message store deserialization.

        Raises:
            AgentThreadException: If the state's kind (service ID vs. message store) conflicts
                with what this thread already holds.
        """
        state = AgentThreadState.model_validate(serialized_thread_state)

        if state.service_thread_id is not None:
            self.service_thread_id = state.service_thread_id
            # Since we have an ID, we should not have a chat message store and we can return here.
            return

        # If we don't have any ChatMessageStoreProtocol state return here.
        if state.chat_message_store_state is None:
            return

        if self.message_store is not None:
            # Restore into the store we already have.
            await self.message_store.update_from_state(state.chat_message_store_state, **kwargs)
            return

        # If we don't have a chat message store yet, create an in-memory one from the state.
        self.message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs)  # type: ignore