mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Context providers abstraction and Mem0 implementation (#631)
* Added context provider abstractions * Added mem0 implementation * Example and small fixes * Added unit tests for agent * Added unit tests for mem0 provider * Updated README * Small doc updates * Update python/packages/mem0/agent_framework_mem0/_provider.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Small fixes in tests * Renaming based on PR feedback * Small fixes * Added tests for AggregateContextProvider * Small improvements * More improvements based on PR feedback * Small constant update * Added more examples * Added README for Mem0 examples * Small updates to API * Updated initialization logic * Updates for context manager * Updated Context class * Dependency update * Revert changes * Fixed tests --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Chris <66376200+crickman@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
89c8418705
commit
57d09afe04
@@ -15,3 +15,5 @@ AGENT_FRAMEWORK_OTLP_ENDPOINT="http://localhost:4317/"
|
||||
AGENT_FRAMEWORK_ENABLE_OTEL=true
|
||||
AGENT_FRAMEWORK_ENABLE_SENSITIVE_DATA=true
|
||||
AGENT_FRAMEWORK_WORKFLOW_ENABLE_OTEL=true
|
||||
# Mem0
|
||||
MEM0_API_KEY=""
|
||||
|
||||
@@ -12,6 +12,7 @@ from ._agents import * # noqa: F403
|
||||
from ._clients import * # noqa: F403
|
||||
from ._logging import * # noqa: F403
|
||||
from ._mcp import * # noqa: F403
|
||||
from ._memory import * # noqa: F403
|
||||
from ._threads import * # noqa: F403
|
||||
from ._tools import * # noqa: F403
|
||||
from ._types import * # noqa: F403
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Callable, MutableMapping, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||
from itertools import chain
|
||||
from typing import Any, ClassVar, Literal, Protocol, TypeVar, runtime_checkable
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -12,6 +11,7 @@ from pydantic import BaseModel, Field, PrivateAttr
|
||||
from ._clients import BaseChatClient, ChatClientProtocol
|
||||
from ._logging import get_logger
|
||||
from ._mcp import MCPTool
|
||||
from ._memory import AggregateContextProvider, Context, ContextProvider
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._threads import AgentThread, ChatMessageStore, deserialize_thread_state, thread_on_new_messages
|
||||
from ._tools import FUNCTION_INVOKING_CHAT_CLIENT_MARKER, ToolProtocol
|
||||
@@ -137,12 +137,13 @@ class BaseAgent(AFBaseModel):
|
||||
name: The name of the agent, can be None.
|
||||
description: The description of the agent.
|
||||
display_name: The display name of the agent, which is either the name or id.
|
||||
|
||||
context_providers: The collection of multiple context providers to include during agent invocation.
|
||||
"""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
context_providers: AggregateContextProvider | None = None
|
||||
|
||||
async def _notify_thread_of_new_messages(
|
||||
self, thread: AgentThread, new_messages: ChatMessage | Sequence[ChatMessage]
|
||||
@@ -214,6 +215,7 @@ class ChatAgent(BaseAgent):
|
||||
user: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
|
||||
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a ChatAgent.
|
||||
@@ -248,6 +250,7 @@ class ChatAgent(BaseAgent):
|
||||
additional_properties: additional properties to include in the request.
|
||||
chat_message_store_factory: factory function to create an instance of ChatMessageStore. If not provided,
|
||||
the default in-memory store will be used.
|
||||
context_providers: The collection of multiple context providers to include during agent invocation.
|
||||
kwargs: any additional keyword arguments.
|
||||
Unused, can be used by subclasses of this Agent.
|
||||
"""
|
||||
@@ -258,6 +261,8 @@ class ChatAgent(BaseAgent):
|
||||
|
||||
kwargs.update(additional_properties or {})
|
||||
|
||||
aggregate_context_providers = self._prepare_context_providers(context_providers)
|
||||
|
||||
# We ignore the MCP Servers here and store them separately,
|
||||
# we add their functions to the tools list at runtime
|
||||
normalized_tools = [] if tools is None else tools if isinstance(tools, list) else [tools]
|
||||
@@ -266,6 +271,7 @@ class ChatAgent(BaseAgent):
|
||||
args: dict[str, Any] = {
|
||||
"chat_client": chat_client,
|
||||
"chat_message_store_factory": chat_message_store_factory,
|
||||
"context_providers": aggregate_context_providers,
|
||||
"chat_options": ChatOptions(
|
||||
ai_model_id=model,
|
||||
frequency_penalty=frequency_penalty,
|
||||
@@ -301,12 +307,16 @@ class ChatAgent(BaseAgent):
|
||||
async def __aenter__(self) -> "Self":
|
||||
"""Async context manager entry.
|
||||
|
||||
If either the chat_client or the local_mcp_tools are context managers,
|
||||
If any of the chat_client, local_mcp_tools, or context_providers are context managers,
|
||||
they will be entered into the async exit stack to ensure proper cleanup.
|
||||
|
||||
This list might be extended in the future.
|
||||
"""
|
||||
for context_manager in chain([self.chat_client], self._local_mcp_tools):
|
||||
context_managers = [self.chat_client, *self._local_mcp_tools]
|
||||
if self.context_providers:
|
||||
context_managers.append(self.context_providers)
|
||||
|
||||
for context_manager in context_managers:
|
||||
if isinstance(context_manager, AbstractAsyncContextManager):
|
||||
await self._async_exit_stack.enter_async_context(context_manager)
|
||||
return self
|
||||
@@ -388,7 +398,10 @@ class ChatAgent(BaseAgent):
|
||||
will only be passed to functions that are called.
|
||||
"""
|
||||
input_messages = self._normalize_messages(messages)
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
|
||||
context = await self.context_providers.model_invoking(input_messages) if self.context_providers else None
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(
|
||||
thread=thread, context=context, input_messages=input_messages
|
||||
)
|
||||
agent_name = self._get_agent_name()
|
||||
|
||||
# Resolve final tool list (runtime provided tools + local MCP server tools)
|
||||
@@ -441,6 +454,10 @@ class ChatAgent(BaseAgent):
|
||||
await self._notify_thread_of_new_messages(thread, input_messages)
|
||||
await self._notify_thread_of_new_messages(thread, response.messages)
|
||||
|
||||
if self.context_providers:
|
||||
await self.context_providers.thread_created(response.conversation_id)
|
||||
await self.context_providers.messages_adding(thread.service_thread_id, input_messages + response.messages)
|
||||
|
||||
return AgentRunResponse(
|
||||
messages=response.messages,
|
||||
response_id=response.response_id,
|
||||
@@ -509,7 +526,10 @@ class ChatAgent(BaseAgent):
|
||||
|
||||
"""
|
||||
input_messages = self._normalize_messages(messages)
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
|
||||
context = await self.context_providers.model_invoking(input_messages) if self.context_providers else None
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(
|
||||
thread=thread, context=context, input_messages=input_messages
|
||||
)
|
||||
agent_name = self._get_agent_name()
|
||||
response_updates: list[ChatResponseUpdate] = []
|
||||
|
||||
@@ -575,6 +595,10 @@ class ChatAgent(BaseAgent):
|
||||
await self._notify_thread_of_new_messages(thread, input_messages)
|
||||
await self._notify_thread_of_new_messages(thread, response.messages)
|
||||
|
||||
if self.context_providers:
|
||||
await self.context_providers.thread_created(response.conversation_id)
|
||||
await self.context_providers.messages_adding(thread.service_thread_id, input_messages + response.messages)
|
||||
|
||||
def get_new_thread(self) -> AgentThread:
|
||||
message_store: ChatMessageStore | None = None
|
||||
|
||||
@@ -617,12 +641,14 @@ class ChatAgent(BaseAgent):
|
||||
self,
|
||||
*,
|
||||
thread: AgentThread | None,
|
||||
context: Context | None,
|
||||
input_messages: list[ChatMessage] | None = None,
|
||||
) -> tuple[AgentThread, list[ChatMessage]]:
|
||||
"""Prepare the messages for agent execution.
|
||||
|
||||
Args:
|
||||
thread: The conversation thread.
|
||||
context: Context to include in messages.
|
||||
input_messages: Messages to process.
|
||||
|
||||
Returns:
|
||||
@@ -636,6 +662,8 @@ class ChatAgent(BaseAgent):
|
||||
messages: list[ChatMessage] = []
|
||||
if self.instructions:
|
||||
messages.append(ChatMessage(role=Role.SYSTEM, text=self.instructions))
|
||||
if context and context.contents:
|
||||
messages.append(ChatMessage(role=Role.SYSTEM, contents=context.contents))
|
||||
if thread.message_store:
|
||||
messages.extend(await thread.message_store.list_messages() or [])
|
||||
messages.extend(input_messages or [])
|
||||
@@ -658,3 +686,18 @@ class ChatAgent(BaseAgent):
|
||||
|
||||
def _get_agent_name(self) -> str:
|
||||
return self.name or "UnnamedAgent"
|
||||
|
||||
def _prepare_context_providers(
|
||||
self,
|
||||
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
|
||||
) -> AggregateContextProvider | None:
|
||||
if not context_providers:
|
||||
return None
|
||||
|
||||
if isinstance(context_providers, AggregateContextProvider):
|
||||
return context_providers
|
||||
|
||||
if isinstance(context_providers, ContextProvider):
|
||||
return AggregateContextProvider([context_providers])
|
||||
|
||||
return AggregateContextProvider(context_providers)
|
||||
|
||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from ._logging import get_logger
|
||||
from ._mcp import MCPTool
|
||||
from ._memory import AggregateContextProvider, ContextProvider
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._threads import ChatMessageStore
|
||||
from ._tools import ToolProtocol
|
||||
@@ -463,6 +464,7 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
|
||||
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "ChatAgent":
|
||||
"""Create an agent with the given name and instructions.
|
||||
@@ -473,6 +475,7 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
tools: Optional list of tools to associate with the agent.
|
||||
chat_message_store_factory: Factory function to create an instance of ChatMessageStore. If not provided,
|
||||
the default in-memory store will be used.
|
||||
context_providers: Context providers to include during agent invocation.
|
||||
**kwargs: Additional keyword arguments to pass to the agent.
|
||||
See ChatAgent for all the available options.
|
||||
|
||||
@@ -487,6 +490,7 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
instructions=instructions,
|
||||
tools=tools,
|
||||
chat_message_store_factory=chat_message_store_factory,
|
||||
context_providers=context_providers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import MutableSequence, Sequence
|
||||
from contextlib import AsyncExitStack
|
||||
from types import TracebackType
|
||||
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._types import ChatMessage, Contents
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
|
||||
# region Context
|
||||
|
||||
|
||||
class Context(AFBaseModel):
|
||||
"""A class containing any context that should be provided to the AI model as supplied by an ContextProvider.
|
||||
|
||||
Each ContextProvider has the ability to provide its own context for each invocation.
|
||||
The Context class contains the additional context supplied by the ContextProvider.
|
||||
This context will be combined with context supplied by other providers before being passed to the AI model.
|
||||
This context is per invocation, and will not be stored as part of the chat history.
|
||||
"""
|
||||
|
||||
contents: list[Contents] | None = None
|
||||
"""
|
||||
Any content to pass to the AI model in addition to any other prompts
|
||||
that it may already have (in the case of an agent), or chat history that may already exist.
|
||||
"""
|
||||
|
||||
|
||||
# region ContextProvider
|
||||
|
||||
|
||||
class ContextProvider(AFBaseModel, ABC):
|
||||
"""Base class for all context providers.
|
||||
|
||||
A context provider is a component that can be used to enhance the AI's context management.
|
||||
It can listen to changes in the conversation and provide additional context to the AI model
|
||||
just before invocation.
|
||||
"""
|
||||
|
||||
async def thread_created(self, thread_id: str | None) -> None:
|
||||
"""Called just after a new thread is created.
|
||||
|
||||
Implementers can use this method to do any operations required at the creation of a new thread.
|
||||
For example, checking long term storage for any data that is relevant
|
||||
to the current session based on the input text.
|
||||
|
||||
Args:
|
||||
thread_id: The ID of the new thread.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
|
||||
"""Called just before messages are added to the chat by any participant.
|
||||
|
||||
Inheritors can use this method to update their context based on new messages.
|
||||
|
||||
Args:
|
||||
thread_id: The ID of the thread for the new message.
|
||||
new_messages: New messages to add.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
|
||||
"""Called just before the Model/Agent/etc. is invoked.
|
||||
|
||||
Implementers can load any additional context required at this time,
|
||||
and they should return any context that should be passed to the agent.
|
||||
|
||||
Args:
|
||||
messages: The most recent messages that the agent is being invoked with.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def __aenter__(self) -> "Self":
|
||||
"""Async context manager entry.
|
||||
|
||||
Override this method to perform any setup operations when the context provider is entered.
|
||||
|
||||
Returns:
|
||||
Self for chaining.
|
||||
"""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Async context manager exit.
|
||||
|
||||
Override this method to perform any cleanup operations when the context provider is exited.
|
||||
|
||||
Args:
|
||||
exc_type: Exception type if an exception occurred, None otherwise.
|
||||
exc_val: Exception value if an exception occurred, None otherwise.
|
||||
exc_tb: Exception traceback if an exception occurred, None otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# region AggregateContextProvider
|
||||
|
||||
|
||||
class AggregateContextProvider(ContextProvider):
|
||||
"""A ContextProvider that contains multiple context providers.
|
||||
|
||||
It delegates events to multiple context providers and aggregates responses from those events before returning.
|
||||
"""
|
||||
|
||||
providers: list[ContextProvider]
|
||||
"""List of registered context providers."""
|
||||
|
||||
def __init__(self, context_providers: Sequence[ContextProvider] | None = None) -> None:
|
||||
"""Initialize AggregateContextProvider with context providers.
|
||||
|
||||
Args:
|
||||
context_providers: Context providers to add.
|
||||
"""
|
||||
super().__init__(providers=list(context_providers or [])) # type: ignore
|
||||
self._exit_stack: AsyncExitStack | None = None
|
||||
|
||||
def add(self, context_provider: ContextProvider) -> None:
|
||||
"""Adds new context provider.
|
||||
|
||||
Args:
|
||||
context_provider: Context provider to add.
|
||||
"""
|
||||
self.providers.append(context_provider)
|
||||
|
||||
async def thread_created(self, thread_id: str | None = None) -> None:
|
||||
await asyncio.gather(*[x.thread_created(thread_id) for x in self.providers])
|
||||
|
||||
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
|
||||
await asyncio.gather(*[x.messages_adding(thread_id, new_messages) for x in self.providers])
|
||||
|
||||
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
|
||||
sub_contexts = await asyncio.gather(*[x.model_invoking(messages) for x in self.providers])
|
||||
combined_context = Context()
|
||||
# Flatten the list of lists and filter out None values
|
||||
all_contents = []
|
||||
for ctx in sub_contexts:
|
||||
if ctx.contents:
|
||||
all_contents.extend(ctx.contents)
|
||||
|
||||
combined_context.contents = all_contents if all_contents else None
|
||||
return combined_context
|
||||
|
||||
async def __aenter__(self) -> "Self":
|
||||
"""Enter async context manager and set up all providers.
|
||||
|
||||
Returns:
|
||||
Self for chaining.
|
||||
"""
|
||||
self._exit_stack = AsyncExitStack()
|
||||
await self._exit_stack.__aenter__()
|
||||
|
||||
# Enter all context providers
|
||||
for provider in self.providers:
|
||||
await self._exit_stack.enter_async_context(provider)
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit async context manager and clean up all providers.
|
||||
|
||||
Args:
|
||||
exc_type: Exception type if an exception occurred, None otherwise.
|
||||
exc_val: Exception value if an exception occurred, None otherwise.
|
||||
exc_tb: Exception traceback if an exception occurred, None otherwise.
|
||||
"""
|
||||
if self._exit_stack is not None:
|
||||
await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
|
||||
self._exit_stack = None
|
||||
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
PACKAGE_NAME = "agent_framework_mem0"
|
||||
PACKAGE_EXTRA = "mem0"
|
||||
_IMPORTS = ["__version__", "Mem0Provider"]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in _IMPORTS:
|
||||
try:
|
||||
return getattr(importlib.import_module(PACKAGE_NAME), name)
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ModuleNotFoundError(
|
||||
f"The '{PACKAGE_EXTRA}' extra is not installed, "
|
||||
f"please do `pip install agent-framework[{PACKAGE_EXTRA}]`"
|
||||
) from exc
|
||||
raise AttributeError(f"Module {PACKAGE_NAME} has no attribute {name}.")
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return _IMPORTS
|
||||
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from agent_framework_mem0 import Mem0Provider, __version__
|
||||
|
||||
__all__ = ["Mem0Provider", "__version__"]
|
||||
@@ -455,7 +455,7 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
|
||||
)
|
||||
response_tools.append(
|
||||
WebSearchToolParam(
|
||||
type="web_search",
|
||||
type="web_search_preview",
|
||||
user_location=WebSearchUserLocation(
|
||||
type="approximate",
|
||||
city=location.get("city", None),
|
||||
|
||||
@@ -23,7 +23,7 @@ classifiers = [
|
||||
"Typing :: Typed",
|
||||
]
|
||||
dependencies = [
|
||||
"openai>=1.103.0",
|
||||
"openai>=1.99.0",
|
||||
"pydantic>=2.11.7",
|
||||
"pydantic-settings>=2.10.1",
|
||||
"typing-extensions>=4.14.0",
|
||||
|
||||
@@ -151,11 +151,19 @@ class MockBaseChatClient(BaseChatClient):
|
||||
logger.debug(f"Running base chat client inner, with: {messages=}, {chat_options=}, {kwargs=}")
|
||||
if not self.run_responses:
|
||||
return ChatResponse(messages=ChatMessage(role="assistant", text=f"test response - {messages[0].text}"))
|
||||
|
||||
response = self.run_responses.pop(0)
|
||||
|
||||
if chat_options.tool_choice == "none":
|
||||
return ChatResponse(
|
||||
messages=ChatMessage(role="assistant", text="I broke out of the function invocation loop...")
|
||||
messages=ChatMessage(
|
||||
role="assistant",
|
||||
text="I broke out of the function invocation loop...",
|
||||
),
|
||||
conversation_id=response.conversation_id,
|
||||
)
|
||||
return self.run_responses.pop(0)
|
||||
|
||||
return response
|
||||
|
||||
@override
|
||||
async def _inner_get_streaming_response(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import AsyncIterable, MutableSequence, Sequence
|
||||
from uuid import uuid4
|
||||
|
||||
from pytest import raises
|
||||
@@ -15,10 +15,12 @@ from agent_framework import (
|
||||
ChatMessage,
|
||||
ChatMessageList,
|
||||
ChatResponse,
|
||||
Contents,
|
||||
HostedCodeInterpreterTool,
|
||||
Role,
|
||||
TextContent,
|
||||
)
|
||||
from agent_framework._memory import AggregateContextProvider, Context, ContextProvider
|
||||
from agent_framework.exceptions import AgentExecutionException
|
||||
|
||||
|
||||
@@ -100,6 +102,7 @@ async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatCl
|
||||
|
||||
_, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
|
||||
thread=thread,
|
||||
context=Context(),
|
||||
input_messages=[ChatMessage(role=Role.USER, text="Test")],
|
||||
)
|
||||
|
||||
@@ -184,3 +187,210 @@ async def test_chat_client_agent_author_name_is_used_from_response(chat_client_b
|
||||
result = await agent.run("Hello")
|
||||
assert result.text == "test response"
|
||||
assert result.messages[0].author_name == "TestAuthor"
|
||||
|
||||
|
||||
# Mock context provider for testing
|
||||
class MockContextProvider(ContextProvider):
|
||||
context_contents: list[Contents] | None = None
|
||||
thread_created_called: bool = False
|
||||
messages_adding_called: bool = False
|
||||
model_invoking_called: bool = False
|
||||
thread_created_thread_id: str | None = None
|
||||
messages_adding_thread_id: str | None = None
|
||||
new_messages: list[ChatMessage] = []
|
||||
|
||||
def __init__(self, contents: list[Contents] | None = None) -> None:
|
||||
super().__init__()
|
||||
self.context_contents = contents
|
||||
self.thread_created_called = False
|
||||
self.messages_adding_called = False
|
||||
self.model_invoking_called = False
|
||||
self.thread_created_thread_id = None
|
||||
self.messages_adding_thread_id = None
|
||||
self.new_messages = []
|
||||
|
||||
async def thread_created(self, thread_id: str | None) -> None:
|
||||
self.thread_created_called = True
|
||||
self.thread_created_thread_id = thread_id
|
||||
|
||||
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
|
||||
self.messages_adding_called = True
|
||||
self.messages_adding_thread_id = thread_id
|
||||
if isinstance(new_messages, ChatMessage):
|
||||
self.new_messages.append(new_messages)
|
||||
else:
|
||||
self.new_messages.extend(new_messages)
|
||||
|
||||
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
|
||||
self.model_invoking_called = True
|
||||
return Context(contents=self.context_contents)
|
||||
|
||||
|
||||
async def test_chat_agent_context_providers_model_invoking(chat_client: ChatClientProtocol) -> None:
|
||||
"""Test that context providers' model_invoking is called during agent run."""
|
||||
mock_provider = MockContextProvider(contents=[TextContent("Test context instructions")])
|
||||
agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider)
|
||||
|
||||
await agent.run("Hello")
|
||||
|
||||
assert mock_provider.model_invoking_called
|
||||
|
||||
|
||||
async def test_chat_agent_context_providers_thread_created(chat_client_base: ChatClientProtocol) -> None:
|
||||
"""Test that context providers' thread_created is called during agent run."""
|
||||
mock_provider = MockContextProvider()
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("test response")])],
|
||||
conversation_id="test-thread-id",
|
||||
)
|
||||
]
|
||||
|
||||
agent = ChatAgent(chat_client=chat_client_base, context_providers=mock_provider)
|
||||
|
||||
await agent.run("Hello")
|
||||
|
||||
assert mock_provider.thread_created_called
|
||||
assert mock_provider.thread_created_thread_id == "test-thread-id"
|
||||
|
||||
|
||||
async def test_chat_agent_context_providers_messages_adding(chat_client: ChatClientProtocol) -> None:
|
||||
"""Test that context providers' messages_adding is called during agent run."""
|
||||
mock_provider = MockContextProvider()
|
||||
agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider)
|
||||
|
||||
await agent.run("Hello")
|
||||
|
||||
assert mock_provider.messages_adding_called
|
||||
# Should be called with both input and response messages
|
||||
assert len(mock_provider.new_messages) >= 2
|
||||
|
||||
|
||||
async def test_chat_agent_context_instructions_in_messages(chat_client: ChatClientProtocol) -> None:
|
||||
"""Test that AI context instructions are included in messages."""
|
||||
mock_provider = MockContextProvider(contents=[TextContent("Context-specific instructions")])
|
||||
agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_providers=mock_provider)
|
||||
|
||||
# We need to test the _prepare_thread_and_messages method directly
|
||||
context = Context(contents=[TextContent("Context-specific instructions")])
|
||||
_, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
|
||||
thread=None, context=context, input_messages=[ChatMessage(role=Role.USER, text="Hello")]
|
||||
)
|
||||
|
||||
# Should have agent instructions, context instructions, and user message
|
||||
assert len(messages) == 3
|
||||
assert messages[0].role == Role.SYSTEM
|
||||
assert messages[0].text == "Agent instructions"
|
||||
assert messages[1].role == Role.SYSTEM
|
||||
assert messages[1].text == "Context-specific instructions"
|
||||
assert messages[2].role == Role.USER
|
||||
assert messages[2].text == "Hello"
|
||||
|
||||
|
||||
async def test_chat_agent_context_instructions_without_agent_instructions(chat_client: ChatClientProtocol) -> None:
|
||||
"""Test that AI context instructions work when agent has no instructions."""
|
||||
agent = ChatAgent(chat_client=chat_client) # No instructions
|
||||
context = Context(contents=[TextContent("Context-only instructions")])
|
||||
|
||||
_, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
|
||||
thread=None, context=context, input_messages=[ChatMessage(role=Role.USER, text="Hello")]
|
||||
)
|
||||
|
||||
# Should have context instructions and user message only
|
||||
assert len(messages) == 2
|
||||
assert messages[0].role == Role.SYSTEM
|
||||
assert messages[0].text == "Context-only instructions"
|
||||
assert messages[1].role == Role.USER
|
||||
assert messages[1].text == "Hello"
|
||||
|
||||
|
||||
async def test_chat_agent_no_context_instructions(chat_client: ChatClientProtocol) -> None:
|
||||
"""Test behavior when AI context has no instructions."""
|
||||
agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions")
|
||||
context = Context() # No instructions
|
||||
|
||||
_, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
|
||||
thread=None, context=context, input_messages=[ChatMessage(role=Role.USER, text="Hello")]
|
||||
)
|
||||
|
||||
# Should have agent instructions and user message only
|
||||
assert len(messages) == 2
|
||||
assert messages[0].role == Role.SYSTEM
|
||||
assert messages[0].text == "Agent instructions"
|
||||
assert messages[1].role == Role.USER
|
||||
assert messages[1].text == "Hello"
|
||||
|
||||
|
||||
async def test_chat_agent_run_stream_context_providers(chat_client: ChatClientProtocol) -> None:
|
||||
"""Test that context providers work with run_stream method."""
|
||||
mock_provider = MockContextProvider(contents=[TextContent("Stream context instructions")])
|
||||
agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider)
|
||||
|
||||
# Collect all stream updates
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in agent.run_stream("Hello"):
|
||||
updates.append(update)
|
||||
|
||||
# Verify context provider was called
|
||||
assert mock_provider.model_invoking_called
|
||||
assert mock_provider.thread_created_called
|
||||
assert mock_provider.messages_adding_called
|
||||
|
||||
|
||||
async def test_chat_agent_multiple_context_providers(chat_client: ChatClientProtocol) -> None:
|
||||
"""Test that multiple context providers work together."""
|
||||
provider1 = MockContextProvider(contents=[TextContent("First provider instructions")])
|
||||
provider2 = MockContextProvider(contents=[TextContent("Second provider instructions")])
|
||||
|
||||
agent = ChatAgent(chat_client=chat_client, context_providers=[provider1, provider2])
|
||||
|
||||
await agent.run("Hello")
|
||||
|
||||
# Both providers should be called
|
||||
assert provider1.model_invoking_called
|
||||
assert provider1.thread_created_called
|
||||
assert provider1.messages_adding_called
|
||||
|
||||
assert provider2.model_invoking_called
|
||||
assert provider2.thread_created_called
|
||||
assert provider2.messages_adding_called
|
||||
|
||||
|
||||
async def test_chat_agent_aggregate_context_provider_combines_instructions() -> None:
|
||||
"""Test that AggregateContextProvider combines instructions from multiple providers."""
|
||||
provider1 = MockContextProvider(contents=[TextContent("First instruction")])
|
||||
provider2 = MockContextProvider(contents=[TextContent("Second instruction")])
|
||||
|
||||
aggregate = AggregateContextProvider()
|
||||
aggregate.providers.append(provider1)
|
||||
aggregate.providers.append(provider2)
|
||||
|
||||
# Test model_invoking combines instructions
|
||||
result = await aggregate.model_invoking([ChatMessage(role=Role.USER, text="Test")])
|
||||
|
||||
assert result.contents
|
||||
assert isinstance(result.contents[0], TextContent)
|
||||
assert isinstance(result.contents[1], TextContent)
|
||||
assert result.contents[0].text == "First instruction"
|
||||
assert result.contents[1].text == "Second instruction"
|
||||
|
||||
|
||||
async def test_chat_agent_context_providers_with_thread_service_id(chat_client_base: ChatClientProtocol) -> None:
|
||||
"""Test context providers with service-managed thread."""
|
||||
mock_provider = MockContextProvider()
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("test response")])],
|
||||
conversation_id="service-thread-123",
|
||||
)
|
||||
]
|
||||
|
||||
agent = ChatAgent(chat_client=chat_client_base, context_providers=mock_provider)
|
||||
|
||||
# Use existing service-managed thread
|
||||
thread = AgentThread(service_thread_id="existing-thread-id")
|
||||
await agent.run("Hello", thread=thread)
|
||||
|
||||
# messages_adding should be called with the service thread ID from response
|
||||
assert mock_provider.messages_adding_called
|
||||
assert mock_provider.messages_adding_thread_id == "service-thread-123" # Updated thread ID from response
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import MutableSequence, Sequence
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from agent_framework import ChatMessage, Contents, Role, TextContent
|
||||
from agent_framework._memory import AggregateContextProvider, Context, ContextProvider
|
||||
|
||||
|
||||
class MockContextProvider(ContextProvider):
|
||||
"""Mock ContextProvider for testing."""
|
||||
|
||||
context_contents: list[Contents] | None = None
|
||||
thread_created_called: bool = False
|
||||
messages_adding_called: bool = False
|
||||
model_invoking_called: bool = False
|
||||
thread_created_thread_id: str | None = None
|
||||
messages_adding_thread_id: str | None = None
|
||||
messages_adding_new_messages: ChatMessage | Sequence[ChatMessage] | None = None
|
||||
model_invoking_messages: ChatMessage | MutableSequence[ChatMessage] | None = None
|
||||
|
||||
def __init__(self, context_contents: list[Contents] | None = None) -> None:
|
||||
super().__init__()
|
||||
self.context_contents = context_contents
|
||||
self.thread_created_called = False
|
||||
self.messages_adding_called = False
|
||||
self.model_invoking_called = False
|
||||
self.thread_created_thread_id = None
|
||||
self.messages_adding_thread_id = None
|
||||
self.messages_adding_new_messages = None
|
||||
self.model_invoking_messages = None
|
||||
|
||||
async def thread_created(self, thread_id: str | None) -> None:
|
||||
"""Track thread_created calls."""
|
||||
self.thread_created_called = True
|
||||
self.thread_created_thread_id = thread_id
|
||||
|
||||
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
|
||||
"""Track messages_adding calls."""
|
||||
self.messages_adding_called = True
|
||||
self.messages_adding_thread_id = thread_id
|
||||
self.messages_adding_new_messages = new_messages
|
||||
|
||||
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
|
||||
"""Track model_invoking calls and return context."""
|
||||
self.model_invoking_called = True
|
||||
self.model_invoking_messages = messages
|
||||
context = Context()
|
||||
context.contents = self.context_contents
|
||||
return context
|
||||
|
||||
|
||||
class TestAggregateContextProvider:
|
||||
"""Tests for AggregateContextProvider class."""
|
||||
|
||||
def test_init_with_no_providers(self) -> None:
|
||||
"""Test initialization with no providers."""
|
||||
aggregate = AggregateContextProvider()
|
||||
assert aggregate.providers == []
|
||||
|
||||
def test_init_with_none_providers(self) -> None:
|
||||
"""Test initialization with None providers."""
|
||||
aggregate = AggregateContextProvider(None)
|
||||
assert aggregate.providers == []
|
||||
|
||||
def test_init_with_providers(self) -> None:
|
||||
"""Test initialization with providers."""
|
||||
provider1 = MockContextProvider([TextContent("instructions1")])
|
||||
provider2 = MockContextProvider([TextContent("instructions1")])
|
||||
providers = [provider1, provider2]
|
||||
|
||||
aggregate = AggregateContextProvider(providers)
|
||||
assert len(aggregate.providers) == 2
|
||||
assert aggregate.providers[0] is provider1
|
||||
assert aggregate.providers[1] is provider2
|
||||
|
||||
def test_add_provider(self) -> None:
|
||||
"""Test adding a provider."""
|
||||
aggregate = AggregateContextProvider()
|
||||
provider = MockContextProvider([TextContent("instructions")])
|
||||
|
||||
aggregate.add(provider)
|
||||
assert len(aggregate.providers) == 1
|
||||
assert aggregate.providers[0] is provider
|
||||
|
||||
def test_add_multiple_providers(self) -> None:
|
||||
"""Test adding multiple providers."""
|
||||
aggregate = AggregateContextProvider()
|
||||
provider1 = MockContextProvider([TextContent("instructions1")])
|
||||
provider2 = MockContextProvider([TextContent("instructions2")])
|
||||
|
||||
aggregate.add(provider1)
|
||||
aggregate.add(provider2)
|
||||
|
||||
assert len(aggregate.providers) == 2
|
||||
assert aggregate.providers[0] is provider1
|
||||
assert aggregate.providers[1] is provider2
|
||||
|
||||
async def test_thread_created_with_no_providers(self) -> None:
|
||||
"""Test thread_created with no providers."""
|
||||
aggregate = AggregateContextProvider()
|
||||
|
||||
# Should not raise an exception
|
||||
await aggregate.thread_created("thread-123")
|
||||
|
||||
async def test_thread_created_with_providers(self) -> None:
|
||||
"""Test thread_created calls all providers."""
|
||||
provider1 = MockContextProvider([TextContent("instructions1")])
|
||||
provider2 = MockContextProvider([TextContent("instructions2")])
|
||||
aggregate = AggregateContextProvider([provider1, provider2])
|
||||
|
||||
thread_id = "thread-123"
|
||||
await aggregate.thread_created(thread_id)
|
||||
|
||||
assert provider1.thread_created_called
|
||||
assert provider1.thread_created_thread_id == thread_id
|
||||
assert provider2.thread_created_called
|
||||
assert provider2.thread_created_thread_id == thread_id
|
||||
|
||||
async def test_thread_created_with_none_thread_id(self) -> None:
|
||||
"""Test thread_created with None thread_id."""
|
||||
provider = MockContextProvider([TextContent("instructions")])
|
||||
aggregate = AggregateContextProvider([provider])
|
||||
|
||||
await aggregate.thread_created(None)
|
||||
|
||||
assert provider.thread_created_called
|
||||
assert provider.thread_created_thread_id is None
|
||||
|
||||
async def test_messages_adding_with_no_providers(self) -> None:
|
||||
"""Test messages_adding with no providers."""
|
||||
aggregate = AggregateContextProvider()
|
||||
message = ChatMessage(text="Hello", role=Role.USER)
|
||||
|
||||
# Should not raise an exception
|
||||
await aggregate.messages_adding("thread-123", message)
|
||||
|
||||
async def test_messages_adding_with_single_message(self) -> None:
|
||||
"""Test messages_adding with a single message."""
|
||||
provider1 = MockContextProvider([TextContent("instructions1")])
|
||||
provider2 = MockContextProvider([TextContent("instructions2")])
|
||||
aggregate = AggregateContextProvider([provider1, provider2])
|
||||
|
||||
thread_id = "thread-123"
|
||||
message = ChatMessage(text="Hello", role=Role.USER)
|
||||
await aggregate.messages_adding(thread_id, message)
|
||||
|
||||
assert provider1.messages_adding_called
|
||||
assert provider1.messages_adding_thread_id == thread_id
|
||||
assert provider1.messages_adding_new_messages == message
|
||||
assert provider2.messages_adding_called
|
||||
assert provider2.messages_adding_thread_id == thread_id
|
||||
assert provider2.messages_adding_new_messages == message
|
||||
|
||||
async def test_messages_adding_with_message_sequence(self) -> None:
|
||||
"""Test messages_adding with a sequence of messages."""
|
||||
provider = MockContextProvider([TextContent("instructions")])
|
||||
aggregate = AggregateContextProvider([provider])
|
||||
|
||||
thread_id = "thread-123"
|
||||
messages = [
|
||||
ChatMessage(text="Hello", role=Role.USER),
|
||||
ChatMessage(text="Hi there", role=Role.ASSISTANT),
|
||||
]
|
||||
await aggregate.messages_adding(thread_id, messages)
|
||||
|
||||
assert provider.messages_adding_called
|
||||
assert provider.messages_adding_thread_id == thread_id
|
||||
assert provider.messages_adding_new_messages == messages
|
||||
|
||||
async def test_model_invoking_with_no_providers(self) -> None:
|
||||
"""Test model_invoking with no providers."""
|
||||
aggregate = AggregateContextProvider()
|
||||
message = ChatMessage(text="Hello", role=Role.USER)
|
||||
|
||||
context = await aggregate.model_invoking(message)
|
||||
|
||||
assert isinstance(context, Context)
|
||||
assert not context.contents
|
||||
|
||||
async def test_model_invoking_with_single_provider(self) -> None:
|
||||
"""Test model_invoking with a single provider."""
|
||||
provider = MockContextProvider([TextContent("Test instructions")])
|
||||
aggregate = AggregateContextProvider([provider])
|
||||
|
||||
message = ChatMessage(text="Hello", role=Role.USER)
|
||||
context = await aggregate.model_invoking(message)
|
||||
|
||||
assert provider.model_invoking_called
|
||||
assert provider.model_invoking_messages == message
|
||||
assert isinstance(context, Context)
|
||||
|
||||
assert context.contents
|
||||
assert isinstance(context.contents[0], TextContent)
|
||||
assert context.contents[0].text == "Test instructions"
|
||||
|
||||
async def test_model_invoking_with_multiple_providers(self) -> None:
|
||||
"""Test model_invoking combines contexts from multiple providers."""
|
||||
provider1 = MockContextProvider([TextContent("Instructions 1")])
|
||||
provider2 = MockContextProvider([TextContent("Instructions 2")])
|
||||
provider3 = MockContextProvider([TextContent("Instructions 3")])
|
||||
aggregate = AggregateContextProvider([provider1, provider2, provider3])
|
||||
|
||||
messages = [ChatMessage(text="Hello", role=Role.USER)]
|
||||
context = await aggregate.model_invoking(messages)
|
||||
|
||||
assert provider1.model_invoking_called
|
||||
assert provider1.model_invoking_messages == messages
|
||||
assert provider2.model_invoking_called
|
||||
assert provider2.model_invoking_messages == messages
|
||||
assert provider3.model_invoking_called
|
||||
assert provider3.model_invoking_messages == messages
|
||||
|
||||
assert isinstance(context, Context)
|
||||
|
||||
assert context.contents
|
||||
assert isinstance(context.contents[0], TextContent)
|
||||
assert isinstance(context.contents[1], TextContent)
|
||||
assert isinstance(context.contents[2], TextContent)
|
||||
assert context.contents[0].text == "Instructions 1"
|
||||
assert context.contents[1].text == "Instructions 2"
|
||||
assert context.contents[2].text == "Instructions 3"
|
||||
|
||||
async def test_model_invoking_with_none_instructions(self) -> None:
|
||||
"""Test model_invoking filters out None instructions."""
|
||||
provider1 = MockContextProvider([TextContent("Instructions 1")])
|
||||
provider2 = MockContextProvider(None) # None instructions
|
||||
provider3 = MockContextProvider([TextContent("Instructions 3")])
|
||||
aggregate = AggregateContextProvider([provider1, provider2, provider3])
|
||||
|
||||
message = ChatMessage(text="Hello", role=Role.USER)
|
||||
context = await aggregate.model_invoking(message)
|
||||
|
||||
assert isinstance(context, Context)
|
||||
assert context.contents
|
||||
assert isinstance(context.contents[0], TextContent)
|
||||
assert isinstance(context.contents[1], TextContent)
|
||||
assert context.contents[0].text == "Instructions 1"
|
||||
assert context.contents[1].text == "Instructions 3"
|
||||
|
||||
async def test_model_invoking_with_all_none_instructions(self) -> None:
|
||||
"""Test model_invoking when all providers return None instructions."""
|
||||
provider1 = MockContextProvider(None)
|
||||
provider2 = MockContextProvider(None)
|
||||
aggregate = AggregateContextProvider([provider1, provider2])
|
||||
|
||||
message = ChatMessage(text="Hello", role=Role.USER)
|
||||
context = await aggregate.model_invoking(message)
|
||||
|
||||
assert isinstance(context, Context)
|
||||
assert not context.contents
|
||||
|
||||
async def test_model_invoking_with_mutable_sequence(self) -> None:
|
||||
"""Test model_invoking with MutableSequence of messages."""
|
||||
provider = MockContextProvider([TextContent("Test instructions")])
|
||||
aggregate = AggregateContextProvider([provider])
|
||||
|
||||
messages = [ChatMessage(text="Hello", role=Role.USER)]
|
||||
context = await aggregate.model_invoking(messages)
|
||||
|
||||
assert provider.model_invoking_called
|
||||
assert provider.model_invoking_messages == messages
|
||||
assert isinstance(context, Context)
|
||||
assert context.contents
|
||||
assert isinstance(context.contents[0], TextContent)
|
||||
assert context.contents[0].text == "Test instructions"
|
||||
|
||||
async def test_async_methods_concurrent_execution(self) -> None:
|
||||
"""Test that async methods execute providers concurrently."""
|
||||
# Use AsyncMock to verify concurrent execution
|
||||
provider1 = Mock(spec=ContextProvider)
|
||||
provider1.thread_created = AsyncMock()
|
||||
provider1.messages_adding = AsyncMock()
|
||||
provider1.model_invoking = AsyncMock(return_value=Context(contents=[TextContent("Test 1")]))
|
||||
|
||||
provider2 = Mock(spec=ContextProvider)
|
||||
provider2.thread_created = AsyncMock()
|
||||
provider2.messages_adding = AsyncMock()
|
||||
provider2.model_invoking = AsyncMock(return_value=Context(contents=[TextContent("Test 2")]))
|
||||
|
||||
aggregate = AggregateContextProvider([provider1, provider2])
|
||||
|
||||
# Test thread_created
|
||||
await aggregate.thread_created("thread-123")
|
||||
provider1.thread_created.assert_called_once_with("thread-123")
|
||||
provider2.thread_created.assert_called_once_with("thread-123")
|
||||
|
||||
# Test messages_adding
|
||||
message = ChatMessage(text="Hello", role=Role.USER)
|
||||
await aggregate.messages_adding("thread-123", message)
|
||||
provider1.messages_adding.assert_called_once_with("thread-123", message)
|
||||
provider2.messages_adding.assert_called_once_with("thread-123", message)
|
||||
|
||||
# Test model_invoking
|
||||
context = await aggregate.model_invoking(message)
|
||||
provider1.model_invoking.assert_called_once_with(message)
|
||||
provider2.model_invoking.assert_called_once_with(message)
|
||||
assert context.contents
|
||||
assert isinstance(context.contents[0], TextContent)
|
||||
assert isinstance(context.contents[1], TextContent)
|
||||
assert context.contents[0].text == "Test 1"
|
||||
assert context.contents[1].text == "Test 2"
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
||||
@@ -0,0 +1,21 @@
|
||||
# Get Started with Microsoft Agent Framework Mem0
|
||||
|
||||
Please install this package as the extra for `agent-framework`:
|
||||
|
||||
```bash
|
||||
pip install agent-framework[mem0]
|
||||
```
|
||||
|
||||
## Memory Context Provider
|
||||
|
||||
The Mem0 context provider enables persistent memory capabilities for your agents, allowing them to remember user preferences and conversation context across different sessions and threads.
|
||||
|
||||
### Basic Usage Example
|
||||
|
||||
See the [Mem0 basic example](https://github.com/microsoft/agent-framework/tree/main/python/samples/getting_started/context_providers/mem0/mem0_basic.py) which demonstrates:
|
||||
|
||||
- Setting up an agent with Mem0 context provider
|
||||
- Teaching the agent user preferences
|
||||
- Retrieving information using remembered context across new threads
|
||||
- Persistent memory
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from ._provider import Mem0Provider
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
__version__ = "0.0.0" # Fallback for development mode
|
||||
|
||||
__all__ = [
|
||||
"Mem0Provider",
|
||||
"__version__",
|
||||
]
|
||||
@@ -0,0 +1,180 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import sys
|
||||
from collections.abc import MutableSequence, Sequence
|
||||
from typing import Any, Final
|
||||
|
||||
from agent_framework import ChatMessage, Context, ContextProvider, TextContent
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
|
||||
|
||||
DEFAULT_CONTEXT_PROMPT: Final[str] = "## Memories\nConsider the following memories when answering user questions:"
|
||||
|
||||
|
||||
class Mem0Provider(ContextProvider):
|
||||
api_key: str | None = None
|
||||
application_id: str | None = None
|
||||
agent_id: str | None = None
|
||||
thread_id: str | None = None
|
||||
user_id: str | None = None
|
||||
scope_to_per_operation_thread_id: bool = False
|
||||
context_prompt: str = DEFAULT_CONTEXT_PROMPT
|
||||
# Use Any to avoid forward reference issues with AsyncMemoryClient
|
||||
mem0_client: Any = None
|
||||
|
||||
_should_close_client: bool = PrivateAttr(default=False) # Track whether we should close client connection
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
application_id: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
scope_to_per_operation_thread_id: bool = False,
|
||||
context_prompt: str = DEFAULT_CONTEXT_PROMPT,
|
||||
mem0_client: Any = None,
|
||||
) -> None:
|
||||
"""Initializes a new instance of the Mem0Provider class.
|
||||
|
||||
Args:
|
||||
api_key: The API key for authenticating with the Mem0 API. If not
|
||||
provided, it will attempt to use the MEM0_API_KEY environment variable.
|
||||
application_id: The application ID for scoping memories or None.
|
||||
agent_id: The agent ID for scoping memories or None.
|
||||
thread_id: The thread ID for scoping memories or None.
|
||||
user_id: The user ID for scoping memories or None.
|
||||
scope_to_per_operation_thread_id: Whether to scope memories to per-operation thread ID.
|
||||
context_prompt: The prompt to prepend to retrieved memories.
|
||||
mem0_client: A pre-created Mem0 MemoryClient or None to create a default client.
|
||||
"""
|
||||
should_close_client = False
|
||||
if mem0_client is None:
|
||||
from mem0 import AsyncMemoryClient
|
||||
|
||||
mem0_client = AsyncMemoryClient(api_key=api_key)
|
||||
should_close_client = True
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key, # type: ignore[reportCallIssue]
|
||||
application_id=application_id, # type: ignore[reportCallIssue]
|
||||
agent_id=agent_id, # type: ignore[reportCallIssue]
|
||||
thread_id=thread_id, # type: ignore[reportCallIssue]
|
||||
user_id=user_id, # type: ignore[reportCallIssue]
|
||||
scope_to_per_operation_thread_id=scope_to_per_operation_thread_id, # type: ignore[reportCallIssue]
|
||||
context_prompt=context_prompt, # type: ignore[reportCallIssue]
|
||||
mem0_client=mem0_client, # type: ignore[reportCallIssue]
|
||||
)
|
||||
|
||||
self._per_operation_thread_id: str | None = None
|
||||
self._should_close_client = should_close_client
|
||||
|
||||
async def __aenter__(self) -> "Self":
|
||||
"""Async context manager entry."""
|
||||
if self.mem0_client:
|
||||
await self.mem0_client.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
|
||||
"""Async context manager exit."""
|
||||
if self._should_close_client and self.mem0_client:
|
||||
await self.mem0_client.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def thread_created(self, thread_id: str | None = None) -> None:
|
||||
"""Called when a new thread is created.
|
||||
|
||||
Args:
|
||||
thread_id: The ID of the thread or None.
|
||||
"""
|
||||
self._validate_per_operation_thread_id(thread_id)
|
||||
self._per_operation_thread_id = self._per_operation_thread_id or thread_id
|
||||
|
||||
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
|
||||
"""Called when a new message is being added to the thread.
|
||||
|
||||
Args:
|
||||
thread_id: The ID of the thread or None.
|
||||
new_messages: New messages to add.
|
||||
"""
|
||||
self._validate_filters()
|
||||
self._validate_per_operation_thread_id(thread_id)
|
||||
self._per_operation_thread_id = self._per_operation_thread_id or thread_id
|
||||
|
||||
messages_list = [new_messages] if isinstance(new_messages, ChatMessage) else list(new_messages)
|
||||
|
||||
messages: list[dict[str, str]] = [
|
||||
{"role": message.role.value, "content": message.text}
|
||||
for message in messages_list
|
||||
if message.role.value in {"user", "assistant", "system"} and message.text and message.text.strip()
|
||||
]
|
||||
|
||||
if messages:
|
||||
await self.mem0_client.add( # type: ignore[misc]
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
run_id=self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id,
|
||||
metadata={"application_id": self.application_id},
|
||||
)
|
||||
|
||||
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
|
||||
"""Called before invoking the AI model to provide context.
|
||||
|
||||
Args:
|
||||
messages: List of new messages in the thread.
|
||||
|
||||
Returns:
|
||||
Context: Context object containing instructions with memories.
|
||||
"""
|
||||
self._validate_filters()
|
||||
messages_list = [messages] if isinstance(messages, ChatMessage) else list(messages)
|
||||
input_text = "\n".join(msg.text for msg in messages_list if msg and msg.text and msg.text.strip())
|
||||
|
||||
memories = await self.mem0_client.search( # type: ignore[misc]
|
||||
query=input_text,
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
run_id=self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id,
|
||||
)
|
||||
|
||||
line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories)
|
||||
|
||||
content = TextContent(f"{self.context_prompt}\n{line_separated_memories}") if line_separated_memories else None
|
||||
|
||||
return Context(contents=[content] if content else None)
|
||||
|
||||
def _validate_filters(self) -> None:
|
||||
"""Validates that at least one filter is provided.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If no filters are provided.
|
||||
"""
|
||||
if not self.agent_id and not self.user_id and not self.application_id and not self.thread_id:
|
||||
raise ServiceInitializationError(
|
||||
"At least one of the filters: agent_id, user_id, application_id, or thread_id is required."
|
||||
)
|
||||
|
||||
def _validate_per_operation_thread_id(self, thread_id: str | None) -> None:
|
||||
"""Validates that a new thread ID doesn't conflict with an existing one when scoped.
|
||||
|
||||
Args:
|
||||
thread_id: The new thread ID or None.
|
||||
|
||||
Raises:
|
||||
ValueError: If a new thread ID is provided when one already exists.
|
||||
"""
|
||||
if (
|
||||
self.scope_to_per_operation_thread_id
|
||||
and thread_id
|
||||
and self._per_operation_thread_id
|
||||
and thread_id != self._per_operation_thread_id
|
||||
):
|
||||
raise ValueError(
|
||||
"Mem0Provider can only be used with one thread at a time when scope_to_per_operation_thread_id is True."
|
||||
)
|
||||
@@ -0,0 +1,94 @@
|
||||
[project]
|
||||
name = "agent-framework-mem0"
|
||||
description = "Mem0 integration for Microsoft Agent Framework."
|
||||
authors = [{ name = "Microsoft", email = "SK-Support@microsoft.com"}]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
version = "0.1.0b1"
|
||||
license-files = ["LICENSE"]
|
||||
urls.homepage = "https://learn.microsoft.com/en-us/semantic-kernel/overview/"
|
||||
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
|
||||
urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true"
|
||||
urls.issues = "https://github.com/microsoft/agent-framework/issues"
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Framework :: Pydantic :: 2",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
dependencies = [
|
||||
"agent-framework",
|
||||
"mem0ai>=0.1.117",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
prerelease = "if-necessary-or-explicit"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
"sys_platform == 'linux'",
|
||||
"sys_platform == 'win32'"
|
||||
]
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = [
|
||||
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
|
||||
]
|
||||
timeout = 120
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"**/__init__.py"
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
extend = "../../pyproject.toml"
|
||||
exclude = ['tests']
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
strict = true
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
check_untyped_defs = true
|
||||
warn_return_any = true
|
||||
show_error_codes = true
|
||||
warn_unused_ignores = false
|
||||
disallow_incomplete_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_any_unimported = true
|
||||
|
||||
[tool.bandit]
|
||||
targets = ["agent_framework_mem0"]
|
||||
exclude_dirs = ["tests"]
|
||||
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_mem0"
|
||||
test = "pytest --cov=agent_framework_mem0 --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
[tool.uv.build-backend]
|
||||
module-name = "agent_framework_mem0"
|
||||
module-root = ""
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.2,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
|
||||
def test_self_through_main() -> None:
|
||||
try:
|
||||
from agent_framework.mem0 import __version__
|
||||
except ImportError:
|
||||
__version__ = None
|
||||
|
||||
assert __version__ is not None
|
||||
|
||||
|
||||
def test_self() -> None:
|
||||
try:
|
||||
from agent_framework_mem0 import __version__
|
||||
except ImportError:
|
||||
__version__ = None
|
||||
|
||||
assert __version__ is not None
|
||||
|
||||
|
||||
def test_agent_framework() -> None:
|
||||
try:
|
||||
from agent_framework import __version__
|
||||
except ImportError:
|
||||
__version__ = None
|
||||
|
||||
assert __version__ is not None
|
||||
@@ -0,0 +1,481 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# pyright: reportPrivateUsage=false
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import ChatMessage, Context, Role, TextContent
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.mem0 import Mem0Provider
|
||||
|
||||
|
||||
def test_mem0_provider_import():
|
||||
"""Test that Mem0Provider can be imported."""
|
||||
assert Mem0Provider is not None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_client() -> AsyncMock:
|
||||
"""Create a mock Mem0 AsyncMemoryClient."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.add = AsyncMock()
|
||||
mock_client.search = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock()
|
||||
mock_client.async_client = AsyncMock()
|
||||
mock_client.async_client.aclose = AsyncMock()
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_messages() -> list[ChatMessage]:
|
||||
"""Create sample chat messages for testing."""
|
||||
return [
|
||||
ChatMessage(role=Role.USER, text="Hello, how are you?"),
|
||||
ChatMessage(role=Role.ASSISTANT, text="I'm doing well, thank you!"),
|
||||
ChatMessage(role=Role.SYSTEM, text="You are a helpful assistant"),
|
||||
]
|
||||
|
||||
|
||||
class TestMem0ProviderInitialization:
|
||||
"""Test initialization and configuration of Mem0Provider."""
|
||||
|
||||
def test_init_with_all_ids(self, mock_mem0_client: AsyncMock):
|
||||
"""Test initialization with all IDs provided."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
agent_id="agent123",
|
||||
application_id="app123",
|
||||
thread_id="thread123",
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
assert provider.user_id == "user123"
|
||||
assert provider.agent_id == "agent123"
|
||||
assert provider.application_id == "app123"
|
||||
assert provider.thread_id == "thread123"
|
||||
|
||||
def test_init_without_filters_succeeds(self, mock_mem0_client: AsyncMock):
|
||||
"""Test that initialization succeeds even without filters (validation happens during invocation)."""
|
||||
provider = Mem0Provider(mem0_client=mock_mem0_client)
|
||||
assert provider.user_id is None
|
||||
assert provider.agent_id is None
|
||||
assert provider.application_id is None
|
||||
assert provider.thread_id is None
|
||||
|
||||
def test_init_with_custom_context_prompt(self, mock_mem0_client: AsyncMock):
|
||||
"""Test initialization with custom context prompt."""
|
||||
custom_prompt = "## Custom Memories\nConsider these memories:"
|
||||
provider = Mem0Provider(user_id="user123", context_prompt=custom_prompt, mem0_client=mock_mem0_client)
|
||||
assert provider.context_prompt == custom_prompt
|
||||
|
||||
def test_init_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock):
|
||||
"""Test initialization with scope_to_per_operation_thread_id enabled."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
scope_to_per_operation_thread_id=True,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
assert provider.scope_to_per_operation_thread_id is True
|
||||
|
||||
@patch("mem0.AsyncMemoryClient")
|
||||
def test_init_creates_default_client_when_none_provided(self, mock_memory_client_class: AsyncMock):
|
||||
"""Test that a default client is created when none is provided."""
|
||||
mock_client = AsyncMock()
|
||||
mock_memory_client_class.return_value = mock_client
|
||||
|
||||
provider = Mem0Provider(user_id="user123", api_key="test_api_key")
|
||||
|
||||
mock_memory_client_class.assert_called_once_with(api_key="test_api_key")
|
||||
assert provider.mem0_client == mock_client
|
||||
assert provider._should_close_client is True
|
||||
|
||||
def test_init_with_provided_client_should_not_close(self, mock_mem0_client: AsyncMock):
|
||||
"""Test that provided client should not be closed by provider."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
assert provider._should_close_client is False
|
||||
|
||||
|
||||
class TestMem0ProviderAsyncContextManager:
|
||||
"""Test async context manager behavior."""
|
||||
|
||||
async def test_async_context_manager_entry(self, mock_mem0_client: AsyncMock):
|
||||
"""Test async context manager entry returns self."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
async with provider as ctx:
|
||||
assert ctx is provider
|
||||
|
||||
async def test_async_context_manager_exit_closes_client_when_should_close(self):
|
||||
"""Test that async context manager closes client when it should."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock()
|
||||
mock_client.async_client = AsyncMock()
|
||||
mock_client.async_client.aclose = AsyncMock()
|
||||
|
||||
with patch("mem0.AsyncMemoryClient", return_value=mock_client):
|
||||
provider = Mem0Provider(user_id="user123", api_key="test_key")
|
||||
assert provider._should_close_client is True
|
||||
|
||||
async with provider:
|
||||
pass
|
||||
|
||||
mock_client.__aexit__.assert_called_once()
|
||||
|
||||
async def test_async_context_manager_exit_does_not_close_provided_client(self, mock_mem0_client: AsyncMock):
|
||||
"""Test that async context manager does not close provided client."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
assert provider._should_close_client is False
|
||||
|
||||
async with provider:
|
||||
pass
|
||||
|
||||
mock_mem0_client.__aexit__.assert_not_called()
|
||||
|
||||
|
||||
class TestMem0ProviderThreadMethods:
|
||||
"""Test thread lifecycle methods."""
|
||||
|
||||
async def test_thread_created_sets_per_operation_thread_id(self, mock_mem0_client: AsyncMock):
|
||||
"""Test that thread_created sets per-operation thread ID."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
|
||||
await provider.thread_created("thread123")
|
||||
|
||||
assert provider._per_operation_thread_id == "thread123"
|
||||
|
||||
async def test_thread_created_with_existing_thread_id(self, mock_mem0_client: AsyncMock):
|
||||
"""Test thread_created when thread ID already exists."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
provider._per_operation_thread_id = "existing_thread"
|
||||
|
||||
await provider.thread_created("thread123")
|
||||
|
||||
# Should not overwrite existing thread ID
|
||||
assert provider._per_operation_thread_id == "existing_thread"
|
||||
|
||||
async def test_thread_created_validation_with_scope_enabled(self, mock_mem0_client: AsyncMock):
|
||||
"""Test thread_created validation when scope_to_per_operation_thread_id is enabled."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
scope_to_per_operation_thread_id=True,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
provider._per_operation_thread_id = "existing_thread"
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await provider.thread_created("different_thread")
|
||||
|
||||
assert "can only be used with one thread at a time" in str(exc_info.value)
|
||||
|
||||
async def test_messages_adding_sets_per_operation_thread_id(
|
||||
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
||||
):
|
||||
"""Test that messages_adding sets per-operation thread ID."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
|
||||
await provider.messages_adding("thread123", sample_messages)
|
||||
|
||||
assert provider._per_operation_thread_id == "thread123"
|
||||
|
||||
|
||||
class TestMem0ProviderMessagesAdding:
|
||||
"""Test messages_adding method."""
|
||||
|
||||
async def test_messages_adding_fails_without_filters(self, mock_mem0_client: AsyncMock):
|
||||
"""Test that messages_adding fails when no filters are provided."""
|
||||
provider = Mem0Provider(mem0_client=mock_mem0_client)
|
||||
message = ChatMessage(role=Role.USER, text="Hello!")
|
||||
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
await provider.messages_adding("thread123", message)
|
||||
|
||||
assert "At least one of the filters" in str(exc_info.value)
|
||||
|
||||
async def test_messages_adding_single_message(self, mock_mem0_client: AsyncMock):
|
||||
"""Test adding a single message."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
message = ChatMessage(role=Role.USER, text="Hello!")
|
||||
|
||||
await provider.messages_adding("thread123", message)
|
||||
|
||||
mock_mem0_client.add.assert_called_once()
|
||||
call_args = mock_mem0_client.add.call_args
|
||||
assert call_args.kwargs["messages"] == [{"role": "user", "content": "Hello!"}]
|
||||
assert call_args.kwargs["user_id"] == "user123"
|
||||
|
||||
async def test_messages_adding_multiple_messages(
|
||||
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
||||
):
|
||||
"""Test adding multiple messages."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
|
||||
await provider.messages_adding("thread123", sample_messages)
|
||||
|
||||
mock_mem0_client.add.assert_called_once()
|
||||
call_args = mock_mem0_client.add.call_args
|
||||
expected_messages = [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you!"},
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
]
|
||||
assert call_args.kwargs["messages"] == expected_messages
|
||||
|
||||
async def test_messages_adding_with_agent_id(self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]):
|
||||
"""Test adding messages with agent_id."""
|
||||
provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client)
|
||||
|
||||
await provider.messages_adding("thread123", sample_messages)
|
||||
|
||||
call_args = mock_mem0_client.add.call_args
|
||||
assert call_args.kwargs["agent_id"] == "agent123"
|
||||
assert call_args.kwargs["user_id"] is None
|
||||
|
||||
async def test_messages_adding_with_application_id(
|
||||
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
||||
):
|
||||
"""Test adding messages with application_id in metadata."""
|
||||
provider = Mem0Provider(user_id="user123", application_id="app123", mem0_client=mock_mem0_client)
|
||||
|
||||
await provider.messages_adding("thread123", sample_messages)
|
||||
|
||||
call_args = mock_mem0_client.add.call_args
|
||||
assert call_args.kwargs["metadata"] == {"application_id": "app123"}
|
||||
|
||||
async def test_messages_adding_with_scope_to_per_operation_thread_id(
|
||||
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
||||
):
|
||||
"""Test adding messages with scope_to_per_operation_thread_id enabled."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
thread_id="base_thread",
|
||||
scope_to_per_operation_thread_id=True,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
provider._per_operation_thread_id = "operation_thread"
|
||||
|
||||
await provider.messages_adding("operation_thread", sample_messages)
|
||||
|
||||
call_args = mock_mem0_client.add.call_args
|
||||
assert call_args.kwargs["run_id"] == "operation_thread"
|
||||
|
||||
async def test_messages_adding_without_scope_uses_base_thread_id(
|
||||
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
||||
):
|
||||
"""Test adding messages without scope uses base thread_id."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
thread_id="base_thread",
|
||||
scope_to_per_operation_thread_id=False,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
|
||||
await provider.messages_adding("operation_thread", sample_messages)
|
||||
|
||||
call_args = mock_mem0_client.add.call_args
|
||||
assert call_args.kwargs["run_id"] == "base_thread"
|
||||
|
||||
async def test_messages_adding_filters_empty_messages(self, mock_mem0_client: AsyncMock):
|
||||
"""Test that empty or invalid messages are filtered out."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
messages = [
|
||||
ChatMessage(role=Role.USER, text=""), # Empty text
|
||||
ChatMessage(role=Role.USER, text=" "), # Whitespace only
|
||||
ChatMessage(role=Role.USER, text="Valid message"),
|
||||
]
|
||||
|
||||
await provider.messages_adding("thread123", messages)
|
||||
|
||||
call_args = mock_mem0_client.add.call_args
|
||||
# Should only include the valid message
|
||||
assert call_args.kwargs["messages"] == [{"role": "user", "content": "Valid message"}]
|
||||
|
||||
async def test_messages_adding_skips_when_no_valid_messages(self, mock_mem0_client: AsyncMock):
|
||||
"""Test that mem0 client is not called when no valid messages exist."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
messages = [
|
||||
ChatMessage(role=Role.USER, text=""),
|
||||
ChatMessage(role=Role.USER, text=" "),
|
||||
]
|
||||
|
||||
await provider.messages_adding("thread123", messages)
|
||||
|
||||
mock_mem0_client.add.assert_not_called()
|
||||
|
||||
|
||||
class TestMem0ProviderModelInvoking:
|
||||
"""Test model_invoking method."""
|
||||
|
||||
async def test_model_invoking_fails_without_filters(self, mock_mem0_client: AsyncMock):
|
||||
"""Test that model_invoking fails when no filters are provided."""
|
||||
provider = Mem0Provider(mem0_client=mock_mem0_client)
|
||||
message = ChatMessage(role=Role.USER, text="What's the weather?")
|
||||
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
await provider.model_invoking(message)
|
||||
|
||||
assert "At least one of the filters" in str(exc_info.value)
|
||||
|
||||
async def test_model_invoking_single_message(self, mock_mem0_client: AsyncMock):
|
||||
"""Test model_invoking with a single message."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
message = ChatMessage(role=Role.USER, text="What's the weather?")
|
||||
|
||||
# Mock search results
|
||||
mock_mem0_client.search.return_value = [
|
||||
{"memory": "User likes outdoor activities"},
|
||||
{"memory": "User lives in Seattle"},
|
||||
]
|
||||
|
||||
context = await provider.model_invoking(message)
|
||||
|
||||
mock_mem0_client.search.assert_called_once()
|
||||
call_args = mock_mem0_client.search.call_args
|
||||
assert call_args.kwargs["query"] == "What's the weather?"
|
||||
assert call_args.kwargs["user_id"] == "user123"
|
||||
|
||||
assert isinstance(context, Context)
|
||||
expected_instructions = (
|
||||
"## Memories\nConsider the following memories when answering user questions:\n"
|
||||
"User likes outdoor activities\nUser lives in Seattle"
|
||||
)
|
||||
|
||||
assert context.contents
|
||||
assert isinstance(context.contents[0], TextContent)
|
||||
assert context.contents[0].text == expected_instructions
|
||||
|
||||
async def test_model_invoking_multiple_messages(
|
||||
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
||||
):
|
||||
"""Test model_invoking with multiple messages."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
|
||||
mock_mem0_client.search.return_value = [{"memory": "Previous conversation context"}]
|
||||
|
||||
await provider.model_invoking(sample_messages)
|
||||
|
||||
call_args = mock_mem0_client.search.call_args
|
||||
expected_query = "Hello, how are you?\nI'm doing well, thank you!\nYou are a helpful assistant"
|
||||
assert call_args.kwargs["query"] == expected_query
|
||||
|
||||
async def test_model_invoking_with_agent_id(self, mock_mem0_client: AsyncMock):
|
||||
"""Test model_invoking with agent_id."""
|
||||
provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client)
|
||||
message = ChatMessage(role=Role.USER, text="Hello")
|
||||
|
||||
mock_mem0_client.search.return_value = []
|
||||
|
||||
await provider.model_invoking(message)
|
||||
|
||||
call_args = mock_mem0_client.search.call_args
|
||||
assert call_args.kwargs["agent_id"] == "agent123"
|
||||
assert call_args.kwargs["user_id"] is None
|
||||
|
||||
async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock):
|
||||
"""Test model_invoking with scope_to_per_operation_thread_id enabled."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
thread_id="base_thread",
|
||||
scope_to_per_operation_thread_id=True,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
provider._per_operation_thread_id = "operation_thread"
|
||||
message = ChatMessage(role=Role.USER, text="Hello")
|
||||
|
||||
mock_mem0_client.search.return_value = []
|
||||
|
||||
await provider.model_invoking(message)
|
||||
|
||||
call_args = mock_mem0_client.search.call_args
|
||||
assert call_args.kwargs["run_id"] == "operation_thread"
|
||||
|
||||
async def test_model_invoking_no_memories_returns_none_instructions(self, mock_mem0_client: AsyncMock):
|
||||
"""Test that no memories returns context with None instructions."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
message = ChatMessage(role=Role.USER, text="Hello")
|
||||
|
||||
mock_mem0_client.search.return_value = []
|
||||
|
||||
context = await provider.model_invoking(message)
|
||||
|
||||
assert isinstance(context, Context)
|
||||
assert not context.contents
|
||||
|
||||
async def test_model_invoking_filters_empty_message_text(self, mock_mem0_client: AsyncMock):
|
||||
"""Test that empty message text is filtered out from query."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
messages = [
|
||||
ChatMessage(role=Role.USER, text=""),
|
||||
ChatMessage(role=Role.USER, text="Valid message"),
|
||||
ChatMessage(role=Role.USER, text=" "),
|
||||
]
|
||||
|
||||
mock_mem0_client.search.return_value = []
|
||||
|
||||
await provider.model_invoking(messages)
|
||||
|
||||
call_args = mock_mem0_client.search.call_args
|
||||
assert call_args.kwargs["query"] == "Valid message"
|
||||
|
||||
async def test_model_invoking_custom_context_prompt(self, mock_mem0_client: AsyncMock):
|
||||
"""Test model_invoking with custom context prompt."""
|
||||
custom_prompt = "## Custom Context\nRemember these details:"
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
context_prompt=custom_prompt,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
message = ChatMessage(role=Role.USER, text="Hello")
|
||||
|
||||
mock_mem0_client.search.return_value = [{"memory": "Test memory"}]
|
||||
|
||||
context = await provider.model_invoking(message)
|
||||
|
||||
expected_instructions = "## Custom Context\nRemember these details:\nTest memory"
|
||||
assert context.contents
|
||||
assert isinstance(context.contents[0], TextContent)
|
||||
assert context.contents[0].text == expected_instructions
|
||||
|
||||
|
||||
class TestMem0ProviderValidation:
|
||||
"""Test validation methods."""
|
||||
|
||||
def test_validate_per_operation_thread_id_success(self, mock_mem0_client: AsyncMock):
|
||||
"""Test successful validation of per-operation thread ID."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
scope_to_per_operation_thread_id=True,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
provider._per_operation_thread_id = "thread123"
|
||||
|
||||
# Should not raise exception for same thread ID
|
||||
provider._validate_per_operation_thread_id("thread123")
|
||||
|
||||
# Should not raise exception for None
|
||||
provider._validate_per_operation_thread_id(None)
|
||||
|
||||
def test_validate_per_operation_thread_id_failure(self, mock_mem0_client: AsyncMock):
|
||||
"""Test validation failure for conflicting thread IDs."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
scope_to_per_operation_thread_id=True,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
provider._per_operation_thread_id = "thread123"
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
provider._validate_per_operation_thread_id("different_thread")
|
||||
|
||||
assert "can only be used with one thread at a time" in str(exc_info.value)
|
||||
|
||||
def test_validate_per_operation_thread_id_disabled_scope(self, mock_mem0_client: AsyncMock):
|
||||
"""Test that validation is skipped when scope is disabled."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
scope_to_per_operation_thread_id=False,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
provider._per_operation_thread_id = "thread123"
|
||||
|
||||
# Should not raise exception even with different thread ID
|
||||
provider._validate_per_operation_thread_id("different_thread")
|
||||
@@ -7,6 +7,7 @@ dependencies = [
|
||||
"agent-framework",
|
||||
"agent-framework-azure",
|
||||
"agent-framework-foundry",
|
||||
"agent-framework-mem0",
|
||||
"agent-framework-workflow",
|
||||
]
|
||||
|
||||
@@ -63,8 +64,9 @@ exclude = [ "packages/agent_framework_project.egg-info" ]
|
||||
agent-framework = { workspace = true }
|
||||
agent-framework-azure = { workspace = true }
|
||||
agent-framework-foundry = { workspace = true }
|
||||
agent-framework-workflow = { workspace = true }
|
||||
agent-framework-mem0 = { workspace = true }
|
||||
agent-framework-runtime = { workspace = true }
|
||||
agent-framework-workflow = { workspace = true }
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
@@ -191,7 +193,7 @@ build = "python run_tasks_in_packages_if_exists.py build"
|
||||
# combined checks
|
||||
check = ["fmt", "lint", "pyright", "mypy", "test", "markdown-code-lint", "samples-code-check"]
|
||||
pre-commit-check = ["fmt", "lint", "pyright", "markdown-code-lint", "samples-code-check"]
|
||||
all-tests = "pytest --cov=agent_framework --cov=agent_framework_azure --cov=agent_framework_foundry --cov=agent_framework_workflow --cov-report=term-missing:skip-covered packages/**/tests"
|
||||
all-tests = "pytest --import-mode=importlib --cov=agent_framework --cov=agent_framework_azure --cov=agent_framework_foundry --cov=agent_framework_mem0 --cov=agent_framework_workflow --cov-report=term-missing:skip-covered packages/azure/tests packages/foundry/tests packages/main/tests packages/mem0/tests packages/workflow/tests"
|
||||
|
||||
[tool.poe.tasks.venv]
|
||||
cmd = "uv venv --clear --python $python"
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
# Mem0 Context Provider Examples
|
||||
|
||||
[Mem0](https://mem0.ai/) is a self-improving memory layer for Large Language Models that enables applications to have long-term memory capabilities. The Agent Framework's Mem0 context provider integrates with Mem0's API to provide persistent memory across conversation sessions.
|
||||
|
||||
This folder contains examples demonstrating how to use the Mem0 context provider with the Agent Framework for persistent memory and context management across conversations.
|
||||
|
||||
## Examples
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| [`mem0_basic.py`](mem0_basic.py) | Basic example of using Mem0 context provider to store and retrieve user preferences across different conversation threads. |
|
||||
| [`mem0_threads.py`](mem0_threads.py) | Advanced example demonstrating different thread scoping strategies with Mem0. Covers global thread scope (memories shared across all operations), per-operation thread scope (memories isolated per thread), and multiple agents with different memory configurations for personal vs. work contexts. |
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Required Resources
|
||||
|
||||
1. [Mem0 API Key](https://app.mem0.ai/) - Sign up for a Mem0 account and get your API key
|
||||
2. Azure AI Foundry project endpoint (used in these examples)
|
||||
3. Azure CLI authentication (run `az login`)
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Set the following environment variables:
|
||||
|
||||
**For Mem0:**
|
||||
- `MEM0_API_KEY`: Your Mem0 API key (alternatively, pass it as `api_key` parameter to `Mem0Provider`)
|
||||
|
||||
**For Azure AI Foundry:**
|
||||
- `FOUNDRY_PROJECT_ENDPOINT`: Your Azure AI Foundry project endpoint
|
||||
- `FOUNDRY_MODEL_DEPLOYMENT_NAME`: The name of your model deployment
|
||||
|
||||
## Key Concepts
|
||||
|
||||
### Memory Scoping
|
||||
|
||||
The Mem0 context provider supports different scoping strategies:
|
||||
|
||||
- **Global Scope** (`scope_to_per_operation_thread_id=False`): Memories are shared across all conversation threads
|
||||
- **Thread Scope** (`scope_to_per_operation_thread_id=True`): Memories are isolated per conversation thread
|
||||
|
||||
### Memory Association
|
||||
|
||||
Mem0 records can be associated with different identifiers:
|
||||
|
||||
- `user_id`: Associate memories with a specific user
|
||||
- `agent_id`: Associate memories with a specific agent
|
||||
- `thread_id`: Associate memories with a specific conversation thread
|
||||
- `application_id`: Associate memories with an application context
|
||||
@@ -0,0 +1,72 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework.mem0 import Mem0Provider
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
|
||||
|
||||
def retrieve_company_report(company_code: str, detailed: bool) -> str:
|
||||
if company_code != "CNTS":
|
||||
raise ValueError("Company code not found")
|
||||
if not detailed:
|
||||
return "CNTS is a company that specializes in technology."
|
||||
return (
|
||||
"CNTS is a company that specializes in technology. "
|
||||
"It had a revenue of $10 million in 2022. It has 100 employees."
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Example of memory usage with Mem0 context provider."""
|
||||
print("=== Mem0 Context Provider Example ===")
|
||||
|
||||
# Each record in Mem0 should be associated with agent_id or user_id or application_id or thread_id.
|
||||
# In this example, we associate Mem0 records with user_id.
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
# For Azure authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
|
||||
# authentication option.
|
||||
# For Mem0 authentication, set Mem0 API key via "api_key" parameter or MEM0_API_KEY environment variable.
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
FoundryChatClient(async_credential=credential).create_agent(
|
||||
name="FriendlyAssistant",
|
||||
instructions="You are a friendly assistant.",
|
||||
tools=retrieve_company_report,
|
||||
context_providers=Mem0Provider(user_id=user_id),
|
||||
) as agent,
|
||||
):
|
||||
# First ask the agent to retrieve a company report with no previous context.
|
||||
# The agent will not be able to invoke the tool, since it doesn't know
|
||||
# the company code or the report format, so it should ask for clarification.
|
||||
query = "Please retrieve my company report"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
# Now tell the agent the company code and the report format that you want to use
|
||||
# and it should be able to invoke the tool and return the report.
|
||||
query = "I always work with CNTS and I always want a detailed report format. Please remember and retrieve it."
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
print("\nRequest within a new thread:")
|
||||
# Create a new thread for the agent.
|
||||
# The new thread has no context of the previous conversation.
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# Since we have the mem0 component in the thread, the agent should be able to
|
||||
# retrieve the company report without asking for clarification, as it will
|
||||
# be able to remember the user preferences from Mem0 component.
|
||||
query = "Please retrieve my company report"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query, thread=thread)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,164 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework.mem0 import Mem0Provider
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
|
||||
|
||||
def get_user_preferences(user_id: str) -> str:
|
||||
"""Mock function to get user preferences."""
|
||||
preferences = {
|
||||
"user123": "Prefers concise responses and technical details",
|
||||
"user456": "Likes detailed explanations with examples",
|
||||
}
|
||||
return preferences.get(user_id, "No specific preferences found")
|
||||
|
||||
|
||||
async def example_global_thread_scope() -> None:
|
||||
"""Example 1: Global thread_id scope (memories shared across all operations)."""
|
||||
print("1. Global Thread Scope Example:")
|
||||
print("-" * 40)
|
||||
|
||||
global_thread_id = str(uuid.uuid4())
|
||||
user_id = "user123"
|
||||
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
FoundryChatClient(async_credential=credential).create_agent(
|
||||
name="GlobalMemoryAssistant",
|
||||
instructions="You are an assistant that remembers user preferences across conversations.",
|
||||
tools=get_user_preferences,
|
||||
context_providers=Mem0Provider(
|
||||
user_id=user_id,
|
||||
thread_id=global_thread_id,
|
||||
scope_to_per_operation_thread_id=False, # Share memories across all threads
|
||||
),
|
||||
) as global_agent,
|
||||
):
|
||||
# Store some preferences in the global scope
|
||||
query = "Remember that I prefer technical responses with code examples when discussing programming."
|
||||
print(f"User: {query}")
|
||||
result = await global_agent.run(query)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
# Create a new thread - but memories should still be accessible due to global scope
|
||||
new_thread = global_agent.get_new_thread()
|
||||
query = "What do you know about my preferences?"
|
||||
print(f"User (new thread): {query}")
|
||||
result = await global_agent.run(query, thread=new_thread)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
|
||||
async def example_per_operation_thread_scope() -> None:
|
||||
"""Example 2: Per-operation thread scope (memories isolated per thread).
|
||||
|
||||
Note: When scope_to_per_operation_thread_id=True, the provider is bound to a single thread
|
||||
throughout its lifetime. Use the same thread object for all operations with that provider.
|
||||
"""
|
||||
print("2. Per-Operation Thread Scope Example:")
|
||||
print("-" * 40)
|
||||
|
||||
user_id = "user123"
|
||||
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
FoundryChatClient(async_credential=credential).create_agent(
|
||||
name="ScopedMemoryAssistant",
|
||||
instructions="You are an assistant with thread-scoped memory.",
|
||||
tools=get_user_preferences,
|
||||
context_providers=Mem0Provider(
|
||||
user_id=user_id,
|
||||
scope_to_per_operation_thread_id=True, # Isolate memories per thread
|
||||
),
|
||||
) as scoped_agent,
|
||||
):
|
||||
# Create a specific thread for this scoped provider
|
||||
dedicated_thread = scoped_agent.get_new_thread()
|
||||
|
||||
# Store some information in the dedicated thread
|
||||
query = "Remember that for this conversation, I'm working on a Python project about data analysis."
|
||||
print(f"User (dedicated thread): {query}")
|
||||
result = await scoped_agent.run(query, thread=dedicated_thread)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
# Test memory retrieval in the same dedicated thread
|
||||
query = "What project am I working on?"
|
||||
print(f"User (same dedicated thread): {query}")
|
||||
result = await scoped_agent.run(query, thread=dedicated_thread)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
# Store more information in the same thread
|
||||
query = "Also remember that I prefer using pandas and matplotlib for this project."
|
||||
print(f"User (same dedicated thread): {query}")
|
||||
result = await scoped_agent.run(query, thread=dedicated_thread)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
# Test comprehensive memory retrieval
|
||||
query = "What do you know about my current project and preferences?"
|
||||
print(f"User (same dedicated thread): {query}")
|
||||
result = await scoped_agent.run(query, thread=dedicated_thread)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
|
||||
async def example_multiple_agents() -> None:
|
||||
"""Example 3: Multiple agents with different thread configurations."""
|
||||
print("3. Multiple Agents with Different Thread Configurations:")
|
||||
print("-" * 40)
|
||||
|
||||
agent_id_1 = "agent_personal"
|
||||
agent_id_2 = "agent_work"
|
||||
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
FoundryChatClient(async_credential=credential).create_agent(
|
||||
name="PersonalAssistant",
|
||||
instructions="You are a personal assistant that helps with personal tasks.",
|
||||
context_providers=Mem0Provider(
|
||||
agent_id=agent_id_1,
|
||||
),
|
||||
) as personal_agent,
|
||||
FoundryChatClient(async_credential=credential).create_agent(
|
||||
name="WorkAssistant",
|
||||
instructions="You are a work assistant that helps with professional tasks.",
|
||||
context_providers=Mem0Provider(
|
||||
agent_id=agent_id_2,
|
||||
),
|
||||
) as work_agent,
|
||||
):
|
||||
# Store personal information
|
||||
query = "Remember that I like to exercise at 6 AM and prefer outdoor activities."
|
||||
print(f"User to Personal Agent: {query}")
|
||||
result = await personal_agent.run(query)
|
||||
print(f"Personal Agent: {result}\n")
|
||||
|
||||
# Store work information
|
||||
query = "Remember that I have team meetings every Tuesday at 2 PM."
|
||||
print(f"User to Work Agent: {query}")
|
||||
result = await work_agent.run(query)
|
||||
print(f"Work Agent: {result}\n")
|
||||
|
||||
# Test memory isolation
|
||||
query = "What do you know about my schedule?"
|
||||
print(f"User to Personal Agent: {query}")
|
||||
result = await personal_agent.run(query)
|
||||
print(f"Personal Agent: {result}\n")
|
||||
|
||||
print(f"User to Work Agent: {query}")
|
||||
result = await work_agent.run(query)
|
||||
print(f"Work Agent: {result}\n")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Run all Mem0 thread management examples."""
|
||||
print("=== Mem0 Thread Management Example ===\n")
|
||||
|
||||
await example_global_thread_scope()
|
||||
await example_per_operation_thread_scope()
|
||||
await example_multiple_agents()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Generated
+2237
-1902
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user