mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
10d10364a9
* cleanup of threads and serialization * fix for sliding window * fix redis test * updated from comments * updated context provider and threads * updated lock * add asyncio default * fix redis tests * fix tests * fix tests * renamed to invoking * fixed tests * fix for instructions
356 lines
14 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
from collections.abc import Sequence
|
|
from typing import Any, Protocol, TypeVar
|
|
|
|
from pydantic import model_validator
|
|
|
|
from ._memory import AggregateContextProvider
|
|
from ._pydantic import AFBaseModel
|
|
from ._types import ChatMessage
|
|
from .exceptions import AgentThreadException
|
|
|
|
# Public API of this module. The pydantic state models (AgentThreadState, ChatMessageStoreState)
# are intentionally not exported.
__all__ = ["AgentThread", "ChatMessageStore", "ChatMessageStoreProtocol"]
|
|
|
|
|
|
class ChatMessageStoreProtocol(Protocol):
    """Defines methods for storing and retrieving chat messages associated with a specific thread.

    Implementations of this protocol are responsible for managing the storage of chat messages,
    including handling large volumes of data by truncating or summarizing messages as necessary.

    Implementations may hold state that is specific to a single thread, so a new instance
    should be created for each thread rather than shared between threads.
    """

    async def list_messages(self) -> list[ChatMessage]:
        """Gets all the messages from the store that should be used for the next agent invocation.

        Messages are returned in ascending chronological order, with the oldest message first.

        If the messages stored in the store become very large, it is up to the store to
        truncate, summarize or otherwise limit the number of messages returned.

        Returns:
            The messages to send with the next agent invocation, oldest first.
        """
        ...

    async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
        """Adds messages to the store.

        Args:
            messages: The messages to add.
        """
        ...

    @classmethod
    async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "ChatMessageStoreProtocol":
        """Creates a new instance of the store from previously serialized state.

        This method, together with ``serialize``, can be used to save and load messages from a
        persistent store when this store only keeps messages in memory.

        Args:
            serialized_store_state: State previously produced by ``serialize``.
            **kwargs: Implementation-specific deserialization options.

        Returns:
            A new store instance populated from the serialized state.
        """
        ...

    async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
        """Update this store instance in place from serialized state data.

        Args:
            serialized_store_state: Previously serialized state data containing messages.
            **kwargs: Additional arguments for deserialization.
        """
        ...

    async def serialize(self, **kwargs: Any) -> Any:
        """Serializes the current object's state.

        This method, together with ``deserialize`` (or ``update_from_state``), can be used to save
        and load messages from a persistent store when this store only keeps messages in memory.

        Args:
            **kwargs: Implementation-specific serialization options.

        Returns:
            The serialized state of the store.
        """
        ...
|
|
|
|
|
|
class AgentThreadState(AFBaseModel):
    """State model for serializing and deserializing thread information.

    At most one of the two fields may be set. This mirrors the rule in ``AgentThread`` that a
    thread is either service-managed or backed by a local message store, never both.

    Attributes:
        service_thread_id: Optional ID of the thread managed by the agent service.
        chat_message_store_state: Optional serialized state of the chat message store.
    """

    service_thread_id: str | None = None
    chat_message_store_state: Any | None = None

    @model_validator(mode="before")
    @classmethod
    def validate_only_one(cls, values: Any) -> Any:
        """Reject input that sets both ``service_thread_id`` and ``chat_message_store_state``.

        This runs before field validation, so ``values`` is the raw input given to the model.
        Only ``dict`` input is inspected; any other input is passed through unchanged.

        Raises:
            AgentThreadException: If both keys are present with non-None values.
                NOTE(review): if AgentThreadException derives from ValueError, pydantic will
                surface it wrapped in a ValidationError; confirm which callers expect.
        """
        if (
            isinstance(values, dict)
            and values.get("service_thread_id") is not None
            and values.get("chat_message_store_state") is not None
        ):
            raise AgentThreadException("Only one of service_thread_id or chat_message_store_state may be set.")
        return values
|
|
|
|
|
|
class ChatMessageStoreState(AFBaseModel):
    """Serializable snapshot of an in-memory chat message store.

    Attributes:
        messages: The chat messages held by the store, in the store's order.
    """

    # No default is given, so validating a payload without this field fails.
    messages: list[ChatMessage]
|
|
|
|
|
|
# Bound to ChatMessageStore so that ChatMessageStore.deserialize, called on a subclass,
# is typed as returning that subclass.
TChatMessageStore = TypeVar("TChatMessageStore", bound="ChatMessageStore")
|
|
|
|
|
|
class ChatMessageStore:
    """An in-memory implementation of ChatMessageStoreProtocol that stores messages in a list.

    Messages are kept in insertion order in the public ``messages`` attribute.
    ``serialize`` and ``deserialize`` / ``update_from_state`` convert that list to and from a
    plain serialized form (via ``ChatMessageStoreState``) so it can be persisted elsewhere.

    Args:
        messages: Optional initial list of ChatMessage objects to populate the store.
    """

    def __init__(self, messages: Sequence[ChatMessage] | None = None):
        """Create a ChatMessageStore for use in a thread.

        Args:
            messages: The initial messages to store. They are copied into a new list, so later
                changes to the passed-in sequence do not affect the store.
        """
        self.messages = list(messages) if messages else []

    async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
        """Append messages to the end of the store.

        Args:
            messages: Sequence of ChatMessage objects to add to the store.
        """
        self.messages.extend(messages)

    async def list_messages(self) -> list[ChatMessage]:
        """Get all messages from the store in chronological order.

        No truncation or summarization is applied; every stored message is returned.

        Note:
            The returned list is the store's own ``messages`` list, not a copy, so mutating
            it also mutates the store.

        Returns:
            List of ChatMessage objects, ordered from oldest to newest.
        """
        return self.messages

    @classmethod
    async def deserialize(
        cls: type[TChatMessageStore], serialized_store_state: Any, **kwargs: Any
    ) -> TChatMessageStore:
        """Create a new ChatMessageStore instance from serialized state data.

        An empty or falsy ``serialized_store_state`` (for example ``None`` or ``{}``) produces an
        empty store. This matches ``update_from_state``, which treats such input as nothing to load.

        Args:
            serialized_store_state: Previously serialized state data containing messages.
            **kwargs: Additional arguments forwarded to ``ChatMessageStoreState.model_validate``.

        Returns:
            A new ChatMessageStore instance populated with messages from the serialized state.
        """
        if not serialized_store_state:
            return cls()
        state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
        # __init__ copies the list and already handles an empty one.
        return cls(messages=state.messages)

    async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
        """Replace the stored messages with those from serialized state data.

        A falsy ``serialized_store_state`` is a no-op. If the validated state contains no
        messages, the current messages are also left unchanged.

        Args:
            serialized_store_state: Previously serialized state data containing messages.
            **kwargs: Additional arguments forwarded to ``ChatMessageStoreState.model_validate``.
        """
        if not serialized_store_state:
            return
        state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
        # NOTE(review): an explicit empty message list does not clear the store; confirm
        # this is the intended semantics.
        if state.messages:
            self.messages = state.messages

    async def serialize(self, **kwargs: Any) -> Any:
        """Serialize the current store state for persistence.

        Args:
            **kwargs: Additional arguments forwarded to ``model_dump`` of the state model.

        Returns:
            Serialized state data that can be passed to ``deserialize`` or ``update_from_state``.
        """
        state = ChatMessageStoreState(messages=self.messages)
        return state.model_dump(**kwargs)
|
|
|
|
|
|
# Bound to AgentThread so that AgentThread.deserialize, called on a subclass,
# is typed as returning that subclass.
TAgentThread = TypeVar("TAgentThread", bound="AgentThread")
|
|
|
|
|
|
class AgentThread:
    """The Agent thread class, this can represent both a locally managed thread or a thread managed by the service.

    A thread is in one of three modes:

    * uninitialized: neither ``service_thread_id`` nor ``message_store`` is set;
    * service-managed: ``service_thread_id`` is set and invoking the service updates the history;
    * locally managed: ``message_store`` is set and new messages are added to that store.

    Setting both, or switching from one initialized mode to the other, raises ``AgentThreadException``.
    """

    def __init__(
        self,
        *,
        service_thread_id: str | None = None,
        message_store: ChatMessageStoreProtocol | None = None,
        context_provider: AggregateContextProvider | None = None,
    ) -> None:
        """Initialize an AgentThread, do not use this method manually, always use: agent.get_new_thread().

        Args:
            service_thread_id: Optional ID of the thread managed by the agent service.
            message_store: Optional ChatMessageStore implementation for managing chat messages.
            context_provider: Optional ContextProvider for the thread.

        Raises:
            AgentThreadException: If both service_thread_id and message_store are given.

        Note:
            Either service_thread_id or message_store may be set, but not both.
        """
        if service_thread_id is not None and message_store is not None:
            raise AgentThreadException("Only the service_thread_id or message_store may be set, but not both.")

        self._service_thread_id = service_thread_id
        self._message_store = message_store
        self.context_provider = context_provider

    @property
    def is_initialized(self) -> bool:
        """Indicates if the thread is initialized.

        This means either the service_thread_id or the message_store is set.
        """
        return self._service_thread_id is not None or self._message_store is not None

    @property
    def service_thread_id(self) -> str | None:
        """Gets the ID of the current thread to support cases where the thread is owned by the agent service."""
        return self._service_thread_id

    @service_thread_id.setter
    def service_thread_id(self, service_thread_id: str | None) -> None:
        """Sets the ID of the current thread to support cases where the thread is owned by the agent service.

        Assigning ``None`` is ignored and the current ID is kept. A non-None value replaces
        any existing ID.

        Raises:
            AgentThreadException: If a message_store is already set, since either
                service_thread_id or message_store may be set, but not both.
        """
        if service_thread_id is None:
            return

        if self._message_store is not None:
            raise AgentThreadException(
                "Only the service_thread_id or message_store may be set, "
                "but not both and switching from one to another is not supported."
            )
        self._service_thread_id = service_thread_id

    @property
    def message_store(self) -> ChatMessageStoreProtocol | None:
        """Gets the ChatMessageStoreProtocol used by this thread."""
        return self._message_store

    @message_store.setter
    def message_store(self, message_store: ChatMessageStoreProtocol | None) -> None:
        """Sets the ChatMessageStoreProtocol used by this thread.

        Assigning ``None`` is ignored and the current store is kept.

        Raises:
            AgentThreadException: If a service_thread_id is already set, since either
                service_thread_id or message_store may be set, but not both.
        """
        if message_store is None:
            return

        if self._service_thread_id is not None:
            raise AgentThreadException(
                "Only the service_thread_id or message_store may be set, "
                "but not both and switching from one to another is not supported."
            )

        self._message_store = message_store

    async def on_new_messages(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
        """Invoked when a new message has been contributed to the chat by any participant.

        For a service-managed thread this is a no-op. Otherwise the messages are added to the
        message store, and an in-memory ChatMessageStore is created first if none is set.

        Args:
            new_messages: A single message or a sequence of messages to record.
        """
        if self._service_thread_id is not None:
            # If the thread messages are stored in the service there is nothing to do here,
            # since invoking the service should already update the thread.
            return
        if self._message_store is None:
            # If there is no conversation id, and no store we can
            # create a default in memory store.
            self._message_store = ChatMessageStore()
        # Normalize a single message to a list before handing it to the store.
        if isinstance(new_messages, ChatMessage):
            new_messages = [new_messages]
        await self._message_store.add_messages(new_messages)

    async def serialize(self, **kwargs: Any) -> dict[str, Any]:
        """Serializes the current object's state.

        Args:
            **kwargs: Arguments forwarded to the message store's ``serialize``; they are not
                applied to the outer thread state dump.

        Returns:
            A dict produced by ``AgentThreadState.model_dump()`` holding ``service_thread_id`` and
            ``chat_message_store_state`` (at most one of which is non-None).
        """
        chat_message_store_state = None
        if self._message_store is not None:
            chat_message_store_state = await self._message_store.serialize(**kwargs)

        state = AgentThreadState(
            service_thread_id=self._service_thread_id, chat_message_store_state=chat_message_store_state
        )
        return state.model_dump()

    @classmethod
    async def deserialize(
        cls: type[TAgentThread],
        serialized_thread_state: dict[str, Any],
        *,
        message_store: ChatMessageStoreProtocol | None = None,
        **kwargs: Any,
    ) -> TAgentThread:
        """Deserializes the state from a dictionary into a new AgentThread instance.

        Args:
            serialized_thread_state: The serialized thread state as a dictionary.
            message_store: Optional ChatMessageStoreProtocol to use for managing messages.
                It is ignored when the state describes a service-managed thread. Otherwise it
                is attached to the new thread, and updated from the serialized store state
                when there is one. If not provided, a new ChatMessageStore will be created
                if needed.
            **kwargs: Additional arguments forwarded to the message store's deserialization.

        Returns:
            A new AgentThread instance with properties set from the serialized state.

        Raises:
            AgentThreadException: If the message store state cannot be deserialized.
        """
        state = AgentThreadState.model_validate(serialized_thread_state)

        if state.service_thread_id is not None:
            return cls(service_thread_id=state.service_thread_id)

        if state.chat_message_store_state is None:
            # Nothing to restore, but keep a caller-supplied store rather than silently dropping
            # it; cls(message_store=None) is equivalent to cls().
            return cls(message_store=message_store)

        if message_store is not None:
            try:
                await message_store.update_from_state(state.chat_message_store_state, **kwargs)
            except Exception as ex:
                raise AgentThreadException("Failed to deserialize the provided message store.") from ex
            return cls(message_store=message_store)

        try:
            message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs)
        except Exception as ex:
            raise AgentThreadException("Failed to deserialize the message store.") from ex
        return cls(message_store=message_store)

    async def update_from_thread_state(
        self,
        serialized_thread_state: dict[str, Any],
        **kwargs: Any,
    ) -> None:
        """Deserializes the state from a dictionary into the thread properties.

        Args:
            serialized_thread_state: The serialized thread state as a dictionary.
            **kwargs: Additional arguments forwarded to the message store's deserialization.

        Raises:
            AgentThreadException: If the state would switch this thread between service-managed
                and locally managed (raised by the property setters).

        Errors raised by state validation or by the message store are not wrapped here,
        unlike in ``deserialize``.
        """
        state = AgentThreadState.model_validate(serialized_thread_state)

        if state.service_thread_id is not None:
            self.service_thread_id = state.service_thread_id
            # Since we have an ID, we should not have a chat message store and we can return here.
            return

        # If we don't have any ChatMessageStoreProtocol state return here.
        if state.chat_message_store_state is None:
            return

        if self.message_store is not None:
            await self.message_store.update_from_state(state.chat_message_store_state, **kwargs)
            return

        # If we don't have a chat message store yet, create the default in-memory one.
        self.message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs)  # type: ignore
|