Python: Context providers abstraction and Mem0 implementation (#631)

* Added context provider abstractions

* Added mem0 implementation

* Example and small fixes

* Added unit tests for agent

* Added unit tests for mem0 provider

* Updated README

* Small doc updates

* Update python/packages/mem0/agent_framework_mem0/_provider.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Small fixes in tests

* Renaming based on PR feedback

* Small fixes

* Added tests for AggregateContextProvider

* Small improvements

* More improvements based on PR feedback

* Small constant update

* Added more examples

* Added README for Mem0 examples

* Small updates to API

* Updated initialization logic

* Updates for context manager

* Updated Context class

* Dependency update

* Revert changes

* Fixed tests

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Chris <66376200+crickman@users.noreply.github.com>
This commit is contained in:
Dmytro Struk
2025-09-10 14:11:42 -07:00
committed by GitHub
Unverified
parent 89c8418705
commit 57d09afe04
25 changed files with 4166 additions and 1915 deletions
+2
View File
@@ -15,3 +15,5 @@ AGENT_FRAMEWORK_OTLP_ENDPOINT="http://localhost:4317/"
AGENT_FRAMEWORK_ENABLE_OTEL=true
AGENT_FRAMEWORK_ENABLE_SENSITIVE_DATA=true
AGENT_FRAMEWORK_WORKFLOW_ENABLE_OTEL=true
# Mem0
MEM0_API_KEY=""
@@ -12,6 +12,7 @@ from ._agents import * # noqa: F403
from ._clients import * # noqa: F403
from ._logging import * # noqa: F403
from ._mcp import * # noqa: F403
from ._memory import * # noqa: F403
from ._threads import * # noqa: F403
from ._tools import * # noqa: F403
from ._types import * # noqa: F403
@@ -3,7 +3,6 @@
import sys
from collections.abc import AsyncIterable, Callable, MutableMapping, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from itertools import chain
from typing import Any, ClassVar, Literal, Protocol, TypeVar, runtime_checkable
from uuid import uuid4
@@ -12,6 +11,7 @@ from pydantic import BaseModel, Field, PrivateAttr
from ._clients import BaseChatClient, ChatClientProtocol
from ._logging import get_logger
from ._mcp import MCPTool
from ._memory import AggregateContextProvider, Context, ContextProvider
from ._pydantic import AFBaseModel
from ._threads import AgentThread, ChatMessageStore, deserialize_thread_state, thread_on_new_messages
from ._tools import FUNCTION_INVOKING_CHAT_CLIENT_MARKER, ToolProtocol
@@ -137,12 +137,13 @@ class BaseAgent(AFBaseModel):
name: The name of the agent, can be None.
description: The description of the agent.
display_name: The display name of the agent, which is either the name or id.
context_providers: The collection of multiple context providers to include during agent invocation.
"""
id: str = Field(default_factory=lambda: str(uuid4()))
name: str | None = None
description: str | None = None
context_providers: AggregateContextProvider | None = None
async def _notify_thread_of_new_messages(
self, thread: AgentThread, new_messages: ChatMessage | Sequence[ChatMessage]
@@ -214,6 +215,7 @@ class ChatAgent(BaseAgent):
user: str | None = None,
additional_properties: dict[str, Any] | None = None,
chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
**kwargs: Any,
) -> None:
"""Create a ChatAgent.
@@ -248,6 +250,7 @@ class ChatAgent(BaseAgent):
additional_properties: additional properties to include in the request.
chat_message_store_factory: factory function to create an instance of ChatMessageStore. If not provided,
the default in-memory store will be used.
context_providers: The collection of multiple context providers to include during agent invocation.
kwargs: any additional keyword arguments.
Unused, can be used by subclasses of this Agent.
"""
@@ -258,6 +261,8 @@ class ChatAgent(BaseAgent):
kwargs.update(additional_properties or {})
aggregate_context_providers = self._prepare_context_providers(context_providers)
# We ignore the MCP Servers here and store them separately,
# we add their functions to the tools list at runtime
normalized_tools = [] if tools is None else tools if isinstance(tools, list) else [tools]
@@ -266,6 +271,7 @@ class ChatAgent(BaseAgent):
args: dict[str, Any] = {
"chat_client": chat_client,
"chat_message_store_factory": chat_message_store_factory,
"context_providers": aggregate_context_providers,
"chat_options": ChatOptions(
ai_model_id=model,
frequency_penalty=frequency_penalty,
@@ -301,12 +307,16 @@ class ChatAgent(BaseAgent):
async def __aenter__(self) -> "Self":
"""Async context manager entry.
If either the chat_client or the local_mcp_tools are context managers,
If any of the chat_client, local_mcp_tools, or context_providers are context managers,
they will be entered into the async exit stack to ensure proper cleanup.
This list might be extended in the future.
"""
for context_manager in chain([self.chat_client], self._local_mcp_tools):
context_managers = [self.chat_client, *self._local_mcp_tools]
if self.context_providers:
context_managers.append(self.context_providers)
for context_manager in context_managers:
if isinstance(context_manager, AbstractAsyncContextManager):
await self._async_exit_stack.enter_async_context(context_manager)
return self
@@ -388,7 +398,10 @@ class ChatAgent(BaseAgent):
will only be passed to functions that are called.
"""
input_messages = self._normalize_messages(messages)
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
context = await self.context_providers.model_invoking(input_messages) if self.context_providers else None
thread, thread_messages = await self._prepare_thread_and_messages(
thread=thread, context=context, input_messages=input_messages
)
agent_name = self._get_agent_name()
# Resolve final tool list (runtime provided tools + local MCP server tools)
@@ -441,6 +454,10 @@ class ChatAgent(BaseAgent):
await self._notify_thread_of_new_messages(thread, input_messages)
await self._notify_thread_of_new_messages(thread, response.messages)
if self.context_providers:
await self.context_providers.thread_created(response.conversation_id)
await self.context_providers.messages_adding(thread.service_thread_id, input_messages + response.messages)
return AgentRunResponse(
messages=response.messages,
response_id=response.response_id,
@@ -509,7 +526,10 @@ class ChatAgent(BaseAgent):
"""
input_messages = self._normalize_messages(messages)
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
context = await self.context_providers.model_invoking(input_messages) if self.context_providers else None
thread, thread_messages = await self._prepare_thread_and_messages(
thread=thread, context=context, input_messages=input_messages
)
agent_name = self._get_agent_name()
response_updates: list[ChatResponseUpdate] = []
@@ -575,6 +595,10 @@ class ChatAgent(BaseAgent):
await self._notify_thread_of_new_messages(thread, input_messages)
await self._notify_thread_of_new_messages(thread, response.messages)
if self.context_providers:
await self.context_providers.thread_created(response.conversation_id)
await self.context_providers.messages_adding(thread.service_thread_id, input_messages + response.messages)
def get_new_thread(self) -> AgentThread:
message_store: ChatMessageStore | None = None
@@ -617,12 +641,14 @@ class ChatAgent(BaseAgent):
self,
*,
thread: AgentThread | None,
context: Context | None,
input_messages: list[ChatMessage] | None = None,
) -> tuple[AgentThread, list[ChatMessage]]:
"""Prepare the messages for agent execution.
Args:
thread: The conversation thread.
context: Context to include in messages.
input_messages: Messages to process.
Returns:
@@ -636,6 +662,8 @@ class ChatAgent(BaseAgent):
messages: list[ChatMessage] = []
if self.instructions:
messages.append(ChatMessage(role=Role.SYSTEM, text=self.instructions))
if context and context.contents:
messages.append(ChatMessage(role=Role.SYSTEM, contents=context.contents))
if thread.message_store:
messages.extend(await thread.message_store.list_messages() or [])
messages.extend(input_messages or [])
@@ -658,3 +686,18 @@ class ChatAgent(BaseAgent):
def _get_agent_name(self) -> str:
return self.name or "UnnamedAgent"
def _prepare_context_providers(
self,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
) -> AggregateContextProvider | None:
if not context_providers:
return None
if isinstance(context_providers, AggregateContextProvider):
return context_providers
if isinstance(context_providers, ContextProvider):
return AggregateContextProvider([context_providers])
return AggregateContextProvider(context_providers)
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field
from ._logging import get_logger
from ._mcp import MCPTool
from ._memory import AggregateContextProvider, ContextProvider
from ._pydantic import AFBaseModel
from ._threads import ChatMessageStore
from ._tools import ToolProtocol
@@ -463,6 +464,7 @@ class BaseChatClient(AFBaseModel, ABC):
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None = None,
chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
**kwargs: Any,
) -> "ChatAgent":
"""Create an agent with the given name and instructions.
@@ -473,6 +475,7 @@ class BaseChatClient(AFBaseModel, ABC):
tools: Optional list of tools to associate with the agent.
chat_message_store_factory: Factory function to create an instance of ChatMessageStore. If not provided,
the default in-memory store will be used.
context_providers: Context providers to include during agent invocation.
**kwargs: Additional keyword arguments to pass to the agent.
See ChatAgent for all the available options.
@@ -487,6 +490,7 @@ class BaseChatClient(AFBaseModel, ABC):
instructions=instructions,
tools=tools,
chat_message_store_factory=chat_message_store_factory,
context_providers=context_providers,
**kwargs,
)
@@ -0,0 +1,188 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import sys
from abc import ABC, abstractmethod
from collections.abc import MutableSequence, Sequence
from contextlib import AsyncExitStack
from types import TracebackType
from ._pydantic import AFBaseModel
from ._types import ChatMessage, Contents
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
# region Context
class Context(AFBaseModel):
"""A class containing any context that should be provided to the AI model as supplied by an ContextProvider.
Each ContextProvider has the ability to provide its own context for each invocation.
The Context class contains the additional context supplied by the ContextProvider.
This context will be combined with context supplied by other providers before being passed to the AI model.
This context is per invocation, and will not be stored as part of the chat history.
"""
contents: list[Contents] | None = None
"""
Any content to pass to the AI model in addition to any other prompts
that it may already have (in the case of an agent), or chat history that may already exist.
"""
# region ContextProvider
class ContextProvider(AFBaseModel, ABC):
"""Base class for all context providers.
A context provider is a component that can be used to enhance the AI's context management.
It can listen to changes in the conversation and provide additional context to the AI model
just before invocation.
"""
async def thread_created(self, thread_id: str | None) -> None:
"""Called just after a new thread is created.
Implementers can use this method to do any operations required at the creation of a new thread.
For example, checking long term storage for any data that is relevant
to the current session based on the input text.
Args:
thread_id: The ID of the new thread.
"""
pass
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
"""Called just before messages are added to the chat by any participant.
Inheritors can use this method to update their context based on new messages.
Args:
thread_id: The ID of the thread for the new message.
new_messages: New messages to add.
"""
pass
@abstractmethod
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
"""Called just before the Model/Agent/etc. is invoked.
Implementers can load any additional context required at this time,
and they should return any context that should be passed to the agent.
Args:
messages: The most recent messages that the agent is being invoked with.
"""
pass
async def __aenter__(self) -> "Self":
"""Async context manager entry.
Override this method to perform any setup operations when the context provider is entered.
Returns:
Self for chaining.
"""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Async context manager exit.
Override this method to perform any cleanup operations when the context provider is exited.
Args:
exc_type: Exception type if an exception occurred, None otherwise.
exc_val: Exception value if an exception occurred, None otherwise.
exc_tb: Exception traceback if an exception occurred, None otherwise.
"""
pass
# region AggregateContextProvider
class AggregateContextProvider(ContextProvider):
"""A ContextProvider that contains multiple context providers.
It delegates events to multiple context providers and aggregates responses from those events before returning.
"""
providers: list[ContextProvider]
"""List of registered context providers."""
def __init__(self, context_providers: Sequence[ContextProvider] | None = None) -> None:
"""Initialize AggregateContextProvider with context providers.
Args:
context_providers: Context providers to add.
"""
super().__init__(providers=list(context_providers or [])) # type: ignore
self._exit_stack: AsyncExitStack | None = None
def add(self, context_provider: ContextProvider) -> None:
"""Adds new context provider.
Args:
context_provider: Context provider to add.
"""
self.providers.append(context_provider)
async def thread_created(self, thread_id: str | None = None) -> None:
await asyncio.gather(*[x.thread_created(thread_id) for x in self.providers])
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
await asyncio.gather(*[x.messages_adding(thread_id, new_messages) for x in self.providers])
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
sub_contexts = await asyncio.gather(*[x.model_invoking(messages) for x in self.providers])
combined_context = Context()
# Flatten the list of lists and filter out None values
all_contents = []
for ctx in sub_contexts:
if ctx.contents:
all_contents.extend(ctx.contents)
combined_context.contents = all_contents if all_contents else None
return combined_context
async def __aenter__(self) -> "Self":
"""Enter async context manager and set up all providers.
Returns:
Self for chaining.
"""
self._exit_stack = AsyncExitStack()
await self._exit_stack.__aenter__()
# Enter all context providers
for provider in self.providers:
await self._exit_stack.enter_async_context(provider)
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit async context manager and clean up all providers.
Args:
exc_type: Exception type if an exception occurred, None otherwise.
exc_val: Exception value if an exception occurred, None otherwise.
exc_tb: Exception traceback if an exception occurred, None otherwise.
"""
if self._exit_stack is not None:
await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
self._exit_stack = None
@@ -0,0 +1,24 @@
# Copyright (c) Microsoft. All rights reserved.
import importlib
from typing import Any
PACKAGE_NAME = "agent_framework_mem0"
PACKAGE_EXTRA = "mem0"
_IMPORTS = ["__version__", "Mem0Provider"]
def __getattr__(name: str) -> Any:
if name in _IMPORTS:
try:
return getattr(importlib.import_module(PACKAGE_NAME), name)
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
f"The '{PACKAGE_EXTRA}' extra is not installed, "
f"please do `pip install agent-framework[{PACKAGE_EXTRA}]`"
) from exc
raise AttributeError(f"Module {PACKAGE_NAME} has no attribute {name}.")
def __dir__() -> list[str]:
return _IMPORTS
@@ -0,0 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.
from agent_framework_mem0 import Mem0Provider, __version__
__all__ = ["Mem0Provider", "__version__"]
@@ -455,7 +455,7 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
)
response_tools.append(
WebSearchToolParam(
type="web_search",
type="web_search_preview",
user_location=WebSearchUserLocation(
type="approximate",
city=location.get("city", None),
+1 -1
View File
@@ -23,7 +23,7 @@ classifiers = [
"Typing :: Typed",
]
dependencies = [
"openai>=1.103.0",
"openai>=1.99.0",
"pydantic>=2.11.7",
"pydantic-settings>=2.10.1",
"typing-extensions>=4.14.0",
+10 -2
View File
@@ -151,11 +151,19 @@ class MockBaseChatClient(BaseChatClient):
logger.debug(f"Running base chat client inner, with: {messages=}, {chat_options=}, {kwargs=}")
if not self.run_responses:
return ChatResponse(messages=ChatMessage(role="assistant", text=f"test response - {messages[0].text}"))
response = self.run_responses.pop(0)
if chat_options.tool_choice == "none":
return ChatResponse(
messages=ChatMessage(role="assistant", text="I broke out of the function invocation loop...")
messages=ChatMessage(
role="assistant",
text="I broke out of the function invocation loop...",
),
conversation_id=response.conversation_id,
)
return self.run_responses.pop(0)
return response
@override
async def _inner_get_streaming_response(
+211 -1
View File
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable
from collections.abc import AsyncIterable, MutableSequence, Sequence
from uuid import uuid4
from pytest import raises
@@ -15,10 +15,12 @@ from agent_framework import (
ChatMessage,
ChatMessageList,
ChatResponse,
Contents,
HostedCodeInterpreterTool,
Role,
TextContent,
)
from agent_framework._memory import AggregateContextProvider, Context, ContextProvider
from agent_framework.exceptions import AgentExecutionException
@@ -100,6 +102,7 @@ async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatCl
_, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=thread,
context=Context(),
input_messages=[ChatMessage(role=Role.USER, text="Test")],
)
@@ -184,3 +187,210 @@ async def test_chat_client_agent_author_name_is_used_from_response(chat_client_b
result = await agent.run("Hello")
assert result.text == "test response"
assert result.messages[0].author_name == "TestAuthor"
# Mock context provider for testing
class MockContextProvider(ContextProvider):
context_contents: list[Contents] | None = None
thread_created_called: bool = False
messages_adding_called: bool = False
model_invoking_called: bool = False
thread_created_thread_id: str | None = None
messages_adding_thread_id: str | None = None
new_messages: list[ChatMessage] = []
def __init__(self, contents: list[Contents] | None = None) -> None:
super().__init__()
self.context_contents = contents
self.thread_created_called = False
self.messages_adding_called = False
self.model_invoking_called = False
self.thread_created_thread_id = None
self.messages_adding_thread_id = None
self.new_messages = []
async def thread_created(self, thread_id: str | None) -> None:
self.thread_created_called = True
self.thread_created_thread_id = thread_id
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
self.messages_adding_called = True
self.messages_adding_thread_id = thread_id
if isinstance(new_messages, ChatMessage):
self.new_messages.append(new_messages)
else:
self.new_messages.extend(new_messages)
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
self.model_invoking_called = True
return Context(contents=self.context_contents)
async def test_chat_agent_context_providers_model_invoking(chat_client: ChatClientProtocol) -> None:
"""Test that context providers' model_invoking is called during agent run."""
mock_provider = MockContextProvider(contents=[TextContent("Test context instructions")])
agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider)
await agent.run("Hello")
assert mock_provider.model_invoking_called
async def test_chat_agent_context_providers_thread_created(chat_client_base: ChatClientProtocol) -> None:
"""Test that context providers' thread_created is called during agent run."""
mock_provider = MockContextProvider()
chat_client_base.run_responses = [
ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("test response")])],
conversation_id="test-thread-id",
)
]
agent = ChatAgent(chat_client=chat_client_base, context_providers=mock_provider)
await agent.run("Hello")
assert mock_provider.thread_created_called
assert mock_provider.thread_created_thread_id == "test-thread-id"
async def test_chat_agent_context_providers_messages_adding(chat_client: ChatClientProtocol) -> None:
"""Test that context providers' messages_adding is called during agent run."""
mock_provider = MockContextProvider()
agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider)
await agent.run("Hello")
assert mock_provider.messages_adding_called
# Should be called with both input and response messages
assert len(mock_provider.new_messages) >= 2
async def test_chat_agent_context_instructions_in_messages(chat_client: ChatClientProtocol) -> None:
"""Test that AI context instructions are included in messages."""
mock_provider = MockContextProvider(contents=[TextContent("Context-specific instructions")])
agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_providers=mock_provider)
# We need to test the _prepare_thread_and_messages method directly
context = Context(contents=[TextContent("Context-specific instructions")])
_, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=None, context=context, input_messages=[ChatMessage(role=Role.USER, text="Hello")]
)
# Should have agent instructions, context instructions, and user message
assert len(messages) == 3
assert messages[0].role == Role.SYSTEM
assert messages[0].text == "Agent instructions"
assert messages[1].role == Role.SYSTEM
assert messages[1].text == "Context-specific instructions"
assert messages[2].role == Role.USER
assert messages[2].text == "Hello"
async def test_chat_agent_context_instructions_without_agent_instructions(chat_client: ChatClientProtocol) -> None:
"""Test that AI context instructions work when agent has no instructions."""
agent = ChatAgent(chat_client=chat_client) # No instructions
context = Context(contents=[TextContent("Context-only instructions")])
_, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=None, context=context, input_messages=[ChatMessage(role=Role.USER, text="Hello")]
)
# Should have context instructions and user message only
assert len(messages) == 2
assert messages[0].role == Role.SYSTEM
assert messages[0].text == "Context-only instructions"
assert messages[1].role == Role.USER
assert messages[1].text == "Hello"
async def test_chat_agent_no_context_instructions(chat_client: ChatClientProtocol) -> None:
"""Test behavior when AI context has no instructions."""
agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions")
context = Context() # No instructions
_, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=None, context=context, input_messages=[ChatMessage(role=Role.USER, text="Hello")]
)
# Should have agent instructions and user message only
assert len(messages) == 2
assert messages[0].role == Role.SYSTEM
assert messages[0].text == "Agent instructions"
assert messages[1].role == Role.USER
assert messages[1].text == "Hello"
async def test_chat_agent_run_stream_context_providers(chat_client: ChatClientProtocol) -> None:
"""Test that context providers work with run_stream method."""
mock_provider = MockContextProvider(contents=[TextContent("Stream context instructions")])
agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider)
# Collect all stream updates
updates: list[AgentRunResponseUpdate] = []
async for update in agent.run_stream("Hello"):
updates.append(update)
# Verify context provider was called
assert mock_provider.model_invoking_called
assert mock_provider.thread_created_called
assert mock_provider.messages_adding_called
async def test_chat_agent_multiple_context_providers(chat_client: ChatClientProtocol) -> None:
"""Test that multiple context providers work together."""
provider1 = MockContextProvider(contents=[TextContent("First provider instructions")])
provider2 = MockContextProvider(contents=[TextContent("Second provider instructions")])
agent = ChatAgent(chat_client=chat_client, context_providers=[provider1, provider2])
await agent.run("Hello")
# Both providers should be called
assert provider1.model_invoking_called
assert provider1.thread_created_called
assert provider1.messages_adding_called
assert provider2.model_invoking_called
assert provider2.thread_created_called
assert provider2.messages_adding_called
async def test_chat_agent_aggregate_context_provider_combines_instructions() -> None:
"""Test that AggregateContextProvider combines instructions from multiple providers."""
provider1 = MockContextProvider(contents=[TextContent("First instruction")])
provider2 = MockContextProvider(contents=[TextContent("Second instruction")])
aggregate = AggregateContextProvider()
aggregate.providers.append(provider1)
aggregate.providers.append(provider2)
# Test model_invoking combines instructions
result = await aggregate.model_invoking([ChatMessage(role=Role.USER, text="Test")])
assert result.contents
assert isinstance(result.contents[0], TextContent)
assert isinstance(result.contents[1], TextContent)
assert result.contents[0].text == "First instruction"
assert result.contents[1].text == "Second instruction"
async def test_chat_agent_context_providers_with_thread_service_id(chat_client_base: ChatClientProtocol) -> None:
"""Test context providers with service-managed thread."""
mock_provider = MockContextProvider()
chat_client_base.run_responses = [
ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("test response")])],
conversation_id="service-thread-123",
)
]
agent = ChatAgent(chat_client=chat_client_base, context_providers=mock_provider)
# Use existing service-managed thread
thread = AgentThread(service_thread_id="existing-thread-id")
await agent.run("Hello", thread=thread)
# messages_adding should be called with the service thread ID from response
assert mock_provider.messages_adding_called
assert mock_provider.messages_adding_thread_id == "service-thread-123" # Updated thread ID from response
@@ -0,0 +1,302 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import MutableSequence, Sequence
from unittest.mock import AsyncMock, Mock
from agent_framework import ChatMessage, Contents, Role, TextContent
from agent_framework._memory import AggregateContextProvider, Context, ContextProvider
class MockContextProvider(ContextProvider):
"""Mock ContextProvider for testing."""
context_contents: list[Contents] | None = None
thread_created_called: bool = False
messages_adding_called: bool = False
model_invoking_called: bool = False
thread_created_thread_id: str | None = None
messages_adding_thread_id: str | None = None
messages_adding_new_messages: ChatMessage | Sequence[ChatMessage] | None = None
model_invoking_messages: ChatMessage | MutableSequence[ChatMessage] | None = None
def __init__(self, context_contents: list[Contents] | None = None) -> None:
super().__init__()
self.context_contents = context_contents
self.thread_created_called = False
self.messages_adding_called = False
self.model_invoking_called = False
self.thread_created_thread_id = None
self.messages_adding_thread_id = None
self.messages_adding_new_messages = None
self.model_invoking_messages = None
async def thread_created(self, thread_id: str | None) -> None:
"""Track thread_created calls."""
self.thread_created_called = True
self.thread_created_thread_id = thread_id
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
"""Track messages_adding calls."""
self.messages_adding_called = True
self.messages_adding_thread_id = thread_id
self.messages_adding_new_messages = new_messages
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
"""Track model_invoking calls and return context."""
self.model_invoking_called = True
self.model_invoking_messages = messages
context = Context()
context.contents = self.context_contents
return context
class TestAggregateContextProvider:
"""Tests for AggregateContextProvider class."""
def test_init_with_no_providers(self) -> None:
"""Test initialization with no providers."""
aggregate = AggregateContextProvider()
assert aggregate.providers == []
def test_init_with_none_providers(self) -> None:
"""Test initialization with None providers."""
aggregate = AggregateContextProvider(None)
assert aggregate.providers == []
def test_init_with_providers(self) -> None:
"""Test initialization with providers."""
provider1 = MockContextProvider([TextContent("instructions1")])
provider2 = MockContextProvider([TextContent("instructions1")])
providers = [provider1, provider2]
aggregate = AggregateContextProvider(providers)
assert len(aggregate.providers) == 2
assert aggregate.providers[0] is provider1
assert aggregate.providers[1] is provider2
def test_add_provider(self) -> None:
"""Test adding a provider."""
aggregate = AggregateContextProvider()
provider = MockContextProvider([TextContent("instructions")])
aggregate.add(provider)
assert len(aggregate.providers) == 1
assert aggregate.providers[0] is provider
def test_add_multiple_providers(self) -> None:
"""Test adding multiple providers."""
aggregate = AggregateContextProvider()
provider1 = MockContextProvider([TextContent("instructions1")])
provider2 = MockContextProvider([TextContent("instructions2")])
aggregate.add(provider1)
aggregate.add(provider2)
assert len(aggregate.providers) == 2
assert aggregate.providers[0] is provider1
assert aggregate.providers[1] is provider2
async def test_thread_created_with_no_providers(self) -> None:
"""Test thread_created with no providers."""
aggregate = AggregateContextProvider()
# Should not raise an exception
await aggregate.thread_created("thread-123")
async def test_thread_created_with_providers(self) -> None:
"""Test thread_created calls all providers."""
provider1 = MockContextProvider([TextContent("instructions1")])
provider2 = MockContextProvider([TextContent("instructions2")])
aggregate = AggregateContextProvider([provider1, provider2])
thread_id = "thread-123"
await aggregate.thread_created(thread_id)
assert provider1.thread_created_called
assert provider1.thread_created_thread_id == thread_id
assert provider2.thread_created_called
assert provider2.thread_created_thread_id == thread_id
async def test_thread_created_with_none_thread_id(self) -> None:
"""Test thread_created with None thread_id."""
provider = MockContextProvider([TextContent("instructions")])
aggregate = AggregateContextProvider([provider])
await aggregate.thread_created(None)
assert provider.thread_created_called
assert provider.thread_created_thread_id is None
async def test_messages_adding_with_no_providers(self) -> None:
"""Test messages_adding with no providers."""
aggregate = AggregateContextProvider()
message = ChatMessage(text="Hello", role=Role.USER)
# Should not raise an exception
await aggregate.messages_adding("thread-123", message)
async def test_messages_adding_with_single_message(self) -> None:
"""Test messages_adding with a single message."""
provider1 = MockContextProvider([TextContent("instructions1")])
provider2 = MockContextProvider([TextContent("instructions2")])
aggregate = AggregateContextProvider([provider1, provider2])
thread_id = "thread-123"
message = ChatMessage(text="Hello", role=Role.USER)
await aggregate.messages_adding(thread_id, message)
assert provider1.messages_adding_called
assert provider1.messages_adding_thread_id == thread_id
assert provider1.messages_adding_new_messages == message
assert provider2.messages_adding_called
assert provider2.messages_adding_thread_id == thread_id
assert provider2.messages_adding_new_messages == message
async def test_messages_adding_with_message_sequence(self) -> None:
"""Test messages_adding with a sequence of messages."""
provider = MockContextProvider([TextContent("instructions")])
aggregate = AggregateContextProvider([provider])
thread_id = "thread-123"
messages = [
ChatMessage(text="Hello", role=Role.USER),
ChatMessage(text="Hi there", role=Role.ASSISTANT),
]
await aggregate.messages_adding(thread_id, messages)
assert provider.messages_adding_called
assert provider.messages_adding_thread_id == thread_id
assert provider.messages_adding_new_messages == messages
async def test_model_invoking_with_no_providers(self) -> None:
"""Test model_invoking with no providers."""
aggregate = AggregateContextProvider()
message = ChatMessage(text="Hello", role=Role.USER)
context = await aggregate.model_invoking(message)
assert isinstance(context, Context)
assert not context.contents
async def test_model_invoking_with_single_provider(self) -> None:
"""Test model_invoking with a single provider."""
provider = MockContextProvider([TextContent("Test instructions")])
aggregate = AggregateContextProvider([provider])
message = ChatMessage(text="Hello", role=Role.USER)
context = await aggregate.model_invoking(message)
assert provider.model_invoking_called
assert provider.model_invoking_messages == message
assert isinstance(context, Context)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert context.contents[0].text == "Test instructions"
async def test_model_invoking_with_multiple_providers(self) -> None:
"""Test model_invoking combines contexts from multiple providers."""
provider1 = MockContextProvider([TextContent("Instructions 1")])
provider2 = MockContextProvider([TextContent("Instructions 2")])
provider3 = MockContextProvider([TextContent("Instructions 3")])
aggregate = AggregateContextProvider([provider1, provider2, provider3])
messages = [ChatMessage(text="Hello", role=Role.USER)]
context = await aggregate.model_invoking(messages)
assert provider1.model_invoking_called
assert provider1.model_invoking_messages == messages
assert provider2.model_invoking_called
assert provider2.model_invoking_messages == messages
assert provider3.model_invoking_called
assert provider3.model_invoking_messages == messages
assert isinstance(context, Context)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert isinstance(context.contents[1], TextContent)
assert isinstance(context.contents[2], TextContent)
assert context.contents[0].text == "Instructions 1"
assert context.contents[1].text == "Instructions 2"
assert context.contents[2].text == "Instructions 3"
async def test_model_invoking_with_none_instructions(self) -> None:
"""Test model_invoking filters out None instructions."""
provider1 = MockContextProvider([TextContent("Instructions 1")])
provider2 = MockContextProvider(None) # None instructions
provider3 = MockContextProvider([TextContent("Instructions 3")])
aggregate = AggregateContextProvider([provider1, provider2, provider3])
message = ChatMessage(text="Hello", role=Role.USER)
context = await aggregate.model_invoking(message)
assert isinstance(context, Context)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert isinstance(context.contents[1], TextContent)
assert context.contents[0].text == "Instructions 1"
assert context.contents[1].text == "Instructions 3"
async def test_model_invoking_with_all_none_instructions(self) -> None:
"""Test model_invoking when all providers return None instructions."""
provider1 = MockContextProvider(None)
provider2 = MockContextProvider(None)
aggregate = AggregateContextProvider([provider1, provider2])
message = ChatMessage(text="Hello", role=Role.USER)
context = await aggregate.model_invoking(message)
assert isinstance(context, Context)
assert not context.contents
async def test_model_invoking_with_mutable_sequence(self) -> None:
"""Test model_invoking with MutableSequence of messages."""
provider = MockContextProvider([TextContent("Test instructions")])
aggregate = AggregateContextProvider([provider])
messages = [ChatMessage(text="Hello", role=Role.USER)]
context = await aggregate.model_invoking(messages)
assert provider.model_invoking_called
assert provider.model_invoking_messages == messages
assert isinstance(context, Context)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert context.contents[0].text == "Test instructions"
async def test_async_methods_concurrent_execution(self) -> None:
"""Test that async methods execute providers concurrently."""
# Use AsyncMock to verify concurrent execution
provider1 = Mock(spec=ContextProvider)
provider1.thread_created = AsyncMock()
provider1.messages_adding = AsyncMock()
provider1.model_invoking = AsyncMock(return_value=Context(contents=[TextContent("Test 1")]))
provider2 = Mock(spec=ContextProvider)
provider2.thread_created = AsyncMock()
provider2.messages_adding = AsyncMock()
provider2.model_invoking = AsyncMock(return_value=Context(contents=[TextContent("Test 2")]))
aggregate = AggregateContextProvider([provider1, provider2])
# Test thread_created
await aggregate.thread_created("thread-123")
provider1.thread_created.assert_called_once_with("thread-123")
provider2.thread_created.assert_called_once_with("thread-123")
# Test messages_adding
message = ChatMessage(text="Hello", role=Role.USER)
await aggregate.messages_adding("thread-123", message)
provider1.messages_adding.assert_called_once_with("thread-123", message)
provider2.messages_adding.assert_called_once_with("thread-123", message)
# Test model_invoking
context = await aggregate.model_invoking(message)
provider1.model_invoking.assert_called_once_with(message)
provider2.model_invoking.assert_called_once_with(message)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert isinstance(context.contents[1], TextContent)
assert context.contents[0].text == "Test 1"
assert context.contents[1].text == "Test 2"
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
+21
View File
@@ -0,0 +1,21 @@
# Get Started with Microsoft Agent Framework Mem0
Please install this package as the extra for `agent-framework`:
```bash
pip install agent-framework[mem0]
```
## Memory Context Provider
The Mem0 context provider enables persistent memory capabilities for your agents, allowing them to remember user preferences and conversation context across different sessions and threads.
### Basic Usage Example
See the [Mem0 basic example](https://github.com/microsoft/agent-framework/tree/main/python/samples/getting_started/context_providers/mem0/mem0_basic.py) which demonstrates:
- Setting up an agent with Mem0 context provider
- Teaching the agent user preferences
- Retrieving information using remembered context across new threads
- Persistent memory
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft. All rights reserved.
import importlib.metadata
from ._provider import Mem0Provider
try:
__version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0" # Fallback for development mode
__all__ = [
"Mem0Provider",
"__version__",
]
@@ -0,0 +1,180 @@
# Copyright (c) Microsoft. All rights reserved.
import sys
from collections.abc import MutableSequence, Sequence
from typing import Any, Final
from agent_framework import ChatMessage, Context, ContextProvider, TextContent
from agent_framework.exceptions import ServiceInitializationError
from pydantic import PrivateAttr
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
DEFAULT_CONTEXT_PROMPT: Final[str] = "## Memories\nConsider the following memories when answering user questions:"
class Mem0Provider(ContextProvider):
api_key: str | None = None
application_id: str | None = None
agent_id: str | None = None
thread_id: str | None = None
user_id: str | None = None
scope_to_per_operation_thread_id: bool = False
context_prompt: str = DEFAULT_CONTEXT_PROMPT
# Use Any to avoid forward reference issues with AsyncMemoryClient
mem0_client: Any = None
_should_close_client: bool = PrivateAttr(default=False) # Track whether we should close client connection
def __init__(
self,
api_key: str | None = None,
application_id: str | None = None,
agent_id: str | None = None,
thread_id: str | None = None,
user_id: str | None = None,
scope_to_per_operation_thread_id: bool = False,
context_prompt: str = DEFAULT_CONTEXT_PROMPT,
mem0_client: Any = None,
) -> None:
"""Initializes a new instance of the Mem0Provider class.
Args:
api_key: The API key for authenticating with the Mem0 API. If not
provided, it will attempt to use the MEM0_API_KEY environment variable.
application_id: The application ID for scoping memories or None.
agent_id: The agent ID for scoping memories or None.
thread_id: The thread ID for scoping memories or None.
user_id: The user ID for scoping memories or None.
scope_to_per_operation_thread_id: Whether to scope memories to per-operation thread ID.
context_prompt: The prompt to prepend to retrieved memories.
mem0_client: A pre-created Mem0 MemoryClient or None to create a default client.
"""
should_close_client = False
if mem0_client is None:
from mem0 import AsyncMemoryClient
mem0_client = AsyncMemoryClient(api_key=api_key)
should_close_client = True
super().__init__(
api_key=api_key, # type: ignore[reportCallIssue]
application_id=application_id, # type: ignore[reportCallIssue]
agent_id=agent_id, # type: ignore[reportCallIssue]
thread_id=thread_id, # type: ignore[reportCallIssue]
user_id=user_id, # type: ignore[reportCallIssue]
scope_to_per_operation_thread_id=scope_to_per_operation_thread_id, # type: ignore[reportCallIssue]
context_prompt=context_prompt, # type: ignore[reportCallIssue]
mem0_client=mem0_client, # type: ignore[reportCallIssue]
)
self._per_operation_thread_id: str | None = None
self._should_close_client = should_close_client
async def __aenter__(self) -> "Self":
"""Async context manager entry."""
if self.mem0_client:
await self.mem0_client.__aenter__()
return self
async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
"""Async context manager exit."""
if self._should_close_client and self.mem0_client:
await self.mem0_client.__aexit__(exc_type, exc_val, exc_tb)
async def thread_created(self, thread_id: str | None = None) -> None:
"""Called when a new thread is created.
Args:
thread_id: The ID of the thread or None.
"""
self._validate_per_operation_thread_id(thread_id)
self._per_operation_thread_id = self._per_operation_thread_id or thread_id
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
"""Called when a new message is being added to the thread.
Args:
thread_id: The ID of the thread or None.
new_messages: New messages to add.
"""
self._validate_filters()
self._validate_per_operation_thread_id(thread_id)
self._per_operation_thread_id = self._per_operation_thread_id or thread_id
messages_list = [new_messages] if isinstance(new_messages, ChatMessage) else list(new_messages)
messages: list[dict[str, str]] = [
{"role": message.role.value, "content": message.text}
for message in messages_list
if message.role.value in {"user", "assistant", "system"} and message.text and message.text.strip()
]
if messages:
await self.mem0_client.add( # type: ignore[misc]
messages=messages,
user_id=self.user_id,
agent_id=self.agent_id,
run_id=self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id,
metadata={"application_id": self.application_id},
)
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
"""Called before invoking the AI model to provide context.
Args:
messages: List of new messages in the thread.
Returns:
Context: Context object containing instructions with memories.
"""
self._validate_filters()
messages_list = [messages] if isinstance(messages, ChatMessage) else list(messages)
input_text = "\n".join(msg.text for msg in messages_list if msg and msg.text and msg.text.strip())
memories = await self.mem0_client.search( # type: ignore[misc]
query=input_text,
user_id=self.user_id,
agent_id=self.agent_id,
run_id=self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id,
)
line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories)
content = TextContent(f"{self.context_prompt}\n{line_separated_memories}") if line_separated_memories else None
return Context(contents=[content] if content else None)
def _validate_filters(self) -> None:
"""Validates that at least one filter is provided.
Raises:
ServiceInitializationError: If no filters are provided.
"""
if not self.agent_id and not self.user_id and not self.application_id and not self.thread_id:
raise ServiceInitializationError(
"At least one of the filters: agent_id, user_id, application_id, or thread_id is required."
)
def _validate_per_operation_thread_id(self, thread_id: str | None) -> None:
"""Validates that a new thread ID doesn't conflict with an existing one when scoped.
Args:
thread_id: The new thread ID or None.
Raises:
ValueError: If a new thread ID is provided when one already exists.
"""
if (
self.scope_to_per_operation_thread_id
and thread_id
and self._per_operation_thread_id
and thread_id != self._per_operation_thread_id
):
raise ValueError(
"Mem0Provider can only be used with one thread at a time when scope_to_per_operation_thread_id is True."
)
+94
View File
@@ -0,0 +1,94 @@
[project]
name = "agent-framework-mem0"
description = "Mem0 integration for Microsoft Agent Framework."
authors = [{ name = "Microsoft", email = "SK-Support@microsoft.com"}]
readme = "README.md"
requires-python = ">=3.10"
version = "0.1.0b1"
license-files = ["LICENSE"]
urls.homepage = "https://learn.microsoft.com/en-us/semantic-kernel/overview/"
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true"
urls.issues = "https://github.com/microsoft/agent-framework/issues"
classifiers = [
"License :: OSI Approved :: MIT License",
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Framework :: Pydantic :: 2",
"Typing :: Typed",
]
dependencies = [
"agent-framework",
"mem0ai>=0.1.117",
]
[tool.uv]
prerelease = "if-necessary-or-explicit"
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",
"sys_platform == 'win32'"
]
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = [
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
]
timeout = 120
[tool.ruff]
extend = "../../pyproject.toml"
[tool.coverage.run]
omit = [
"**/__init__.py"
]
[tool.pyright]
extend = "../../pyproject.toml"
exclude = ['tests']
[tool.mypy]
plugins = ['pydantic.mypy']
strict = true
python_version = "3.10"
ignore_missing_imports = true
disallow_untyped_defs = true
no_implicit_optional = true
check_untyped_defs = true
warn_return_any = true
show_error_codes = true
warn_unused_ignores = false
disallow_incomplete_defs = true
disallow_untyped_decorators = true
disallow_any_unimported = true
[tool.bandit]
targets = ["agent_framework_mem0"]
exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_mem0"
test = "pytest --cov=agent_framework_mem0 --cov-report=term-missing:skip-covered tests"
[tool.uv.build-backend]
module-name = "agent_framework_mem0"
module-root = ""
[build-system]
requires = ["uv_build>=0.8.2,<0.9.0"]
build-backend = "uv_build"
@@ -0,0 +1,28 @@
# Copyright (c) Microsoft. All rights reserved.
def test_self_through_main() -> None:
try:
from agent_framework.mem0 import __version__
except ImportError:
__version__ = None
assert __version__ is not None
def test_self() -> None:
try:
from agent_framework_mem0 import __version__
except ImportError:
__version__ = None
assert __version__ is not None
def test_agent_framework() -> None:
try:
from agent_framework import __version__
except ImportError:
__version__ = None
assert __version__ is not None
+481
View File
@@ -0,0 +1,481 @@
# Copyright (c) Microsoft. All rights reserved.
# pyright: reportPrivateUsage=false
from unittest.mock import AsyncMock, patch
import pytest
from agent_framework import ChatMessage, Context, Role, TextContent
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.mem0 import Mem0Provider
def test_mem0_provider_import():
"""Test that Mem0Provider can be imported."""
assert Mem0Provider is not None
@pytest.fixture
def mock_mem0_client() -> AsyncMock:
"""Create a mock Mem0 AsyncMemoryClient."""
mock_client = AsyncMock()
mock_client.add = AsyncMock()
mock_client.search = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock()
mock_client.async_client = AsyncMock()
mock_client.async_client.aclose = AsyncMock()
return mock_client
@pytest.fixture
def sample_messages() -> list[ChatMessage]:
"""Create sample chat messages for testing."""
return [
ChatMessage(role=Role.USER, text="Hello, how are you?"),
ChatMessage(role=Role.ASSISTANT, text="I'm doing well, thank you!"),
ChatMessage(role=Role.SYSTEM, text="You are a helpful assistant"),
]
class TestMem0ProviderInitialization:
"""Test initialization and configuration of Mem0Provider."""
def test_init_with_all_ids(self, mock_mem0_client: AsyncMock):
"""Test initialization with all IDs provided."""
provider = Mem0Provider(
user_id="user123",
agent_id="agent123",
application_id="app123",
thread_id="thread123",
mem0_client=mock_mem0_client,
)
assert provider.user_id == "user123"
assert provider.agent_id == "agent123"
assert provider.application_id == "app123"
assert provider.thread_id == "thread123"
def test_init_without_filters_succeeds(self, mock_mem0_client: AsyncMock):
"""Test that initialization succeeds even without filters (validation happens during invocation)."""
provider = Mem0Provider(mem0_client=mock_mem0_client)
assert provider.user_id is None
assert provider.agent_id is None
assert provider.application_id is None
assert provider.thread_id is None
def test_init_with_custom_context_prompt(self, mock_mem0_client: AsyncMock):
"""Test initialization with custom context prompt."""
custom_prompt = "## Custom Memories\nConsider these memories:"
provider = Mem0Provider(user_id="user123", context_prompt=custom_prompt, mem0_client=mock_mem0_client)
assert provider.context_prompt == custom_prompt
def test_init_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock):
"""Test initialization with scope_to_per_operation_thread_id enabled."""
provider = Mem0Provider(
user_id="user123",
scope_to_per_operation_thread_id=True,
mem0_client=mock_mem0_client,
)
assert provider.scope_to_per_operation_thread_id is True
@patch("mem0.AsyncMemoryClient")
def test_init_creates_default_client_when_none_provided(self, mock_memory_client_class: AsyncMock):
"""Test that a default client is created when none is provided."""
mock_client = AsyncMock()
mock_memory_client_class.return_value = mock_client
provider = Mem0Provider(user_id="user123", api_key="test_api_key")
mock_memory_client_class.assert_called_once_with(api_key="test_api_key")
assert provider.mem0_client == mock_client
assert provider._should_close_client is True
def test_init_with_provided_client_should_not_close(self, mock_mem0_client: AsyncMock):
"""Test that provided client should not be closed by provider."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
assert provider._should_close_client is False
class TestMem0ProviderAsyncContextManager:
"""Test async context manager behavior."""
async def test_async_context_manager_entry(self, mock_mem0_client: AsyncMock):
"""Test async context manager entry returns self."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
async with provider as ctx:
assert ctx is provider
async def test_async_context_manager_exit_closes_client_when_should_close(self):
"""Test that async context manager closes client when it should."""
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock()
mock_client.async_client = AsyncMock()
mock_client.async_client.aclose = AsyncMock()
with patch("mem0.AsyncMemoryClient", return_value=mock_client):
provider = Mem0Provider(user_id="user123", api_key="test_key")
assert provider._should_close_client is True
async with provider:
pass
mock_client.__aexit__.assert_called_once()
async def test_async_context_manager_exit_does_not_close_provided_client(self, mock_mem0_client: AsyncMock):
"""Test that async context manager does not close provided client."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
assert provider._should_close_client is False
async with provider:
pass
mock_mem0_client.__aexit__.assert_not_called()
class TestMem0ProviderThreadMethods:
"""Test thread lifecycle methods."""
async def test_thread_created_sets_per_operation_thread_id(self, mock_mem0_client: AsyncMock):
"""Test that thread_created sets per-operation thread ID."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
await provider.thread_created("thread123")
assert provider._per_operation_thread_id == "thread123"
async def test_thread_created_with_existing_thread_id(self, mock_mem0_client: AsyncMock):
"""Test thread_created when thread ID already exists."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
provider._per_operation_thread_id = "existing_thread"
await provider.thread_created("thread123")
# Should not overwrite existing thread ID
assert provider._per_operation_thread_id == "existing_thread"
async def test_thread_created_validation_with_scope_enabled(self, mock_mem0_client: AsyncMock):
"""Test thread_created validation when scope_to_per_operation_thread_id is enabled."""
provider = Mem0Provider(
user_id="user123",
scope_to_per_operation_thread_id=True,
mem0_client=mock_mem0_client,
)
provider._per_operation_thread_id = "existing_thread"
with pytest.raises(ValueError) as exc_info:
await provider.thread_created("different_thread")
assert "can only be used with one thread at a time" in str(exc_info.value)
async def test_messages_adding_sets_per_operation_thread_id(
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
):
"""Test that messages_adding sets per-operation thread ID."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
await provider.messages_adding("thread123", sample_messages)
assert provider._per_operation_thread_id == "thread123"
class TestMem0ProviderMessagesAdding:
"""Test messages_adding method."""
async def test_messages_adding_fails_without_filters(self, mock_mem0_client: AsyncMock):
"""Test that messages_adding fails when no filters are provided."""
provider = Mem0Provider(mem0_client=mock_mem0_client)
message = ChatMessage(role=Role.USER, text="Hello!")
with pytest.raises(ServiceInitializationError) as exc_info:
await provider.messages_adding("thread123", message)
assert "At least one of the filters" in str(exc_info.value)
async def test_messages_adding_single_message(self, mock_mem0_client: AsyncMock):
"""Test adding a single message."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
message = ChatMessage(role=Role.USER, text="Hello!")
await provider.messages_adding("thread123", message)
mock_mem0_client.add.assert_called_once()
call_args = mock_mem0_client.add.call_args
assert call_args.kwargs["messages"] == [{"role": "user", "content": "Hello!"}]
assert call_args.kwargs["user_id"] == "user123"
async def test_messages_adding_multiple_messages(
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
):
"""Test adding multiple messages."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
await provider.messages_adding("thread123", sample_messages)
mock_mem0_client.add.assert_called_once()
call_args = mock_mem0_client.add.call_args
expected_messages = [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you!"},
{"role": "system", "content": "You are a helpful assistant"},
]
assert call_args.kwargs["messages"] == expected_messages
async def test_messages_adding_with_agent_id(self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]):
"""Test adding messages with agent_id."""
provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client)
await provider.messages_adding("thread123", sample_messages)
call_args = mock_mem0_client.add.call_args
assert call_args.kwargs["agent_id"] == "agent123"
assert call_args.kwargs["user_id"] is None
async def test_messages_adding_with_application_id(
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
):
"""Test adding messages with application_id in metadata."""
provider = Mem0Provider(user_id="user123", application_id="app123", mem0_client=mock_mem0_client)
await provider.messages_adding("thread123", sample_messages)
call_args = mock_mem0_client.add.call_args
assert call_args.kwargs["metadata"] == {"application_id": "app123"}
async def test_messages_adding_with_scope_to_per_operation_thread_id(
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
):
"""Test adding messages with scope_to_per_operation_thread_id enabled."""
provider = Mem0Provider(
user_id="user123",
thread_id="base_thread",
scope_to_per_operation_thread_id=True,
mem0_client=mock_mem0_client,
)
provider._per_operation_thread_id = "operation_thread"
await provider.messages_adding("operation_thread", sample_messages)
call_args = mock_mem0_client.add.call_args
assert call_args.kwargs["run_id"] == "operation_thread"
async def test_messages_adding_without_scope_uses_base_thread_id(
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
):
"""Test adding messages without scope uses base thread_id."""
provider = Mem0Provider(
user_id="user123",
thread_id="base_thread",
scope_to_per_operation_thread_id=False,
mem0_client=mock_mem0_client,
)
await provider.messages_adding("operation_thread", sample_messages)
call_args = mock_mem0_client.add.call_args
assert call_args.kwargs["run_id"] == "base_thread"
async def test_messages_adding_filters_empty_messages(self, mock_mem0_client: AsyncMock):
"""Test that empty or invalid messages are filtered out."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
messages = [
ChatMessage(role=Role.USER, text=""), # Empty text
ChatMessage(role=Role.USER, text=" "), # Whitespace only
ChatMessage(role=Role.USER, text="Valid message"),
]
await provider.messages_adding("thread123", messages)
call_args = mock_mem0_client.add.call_args
# Should only include the valid message
assert call_args.kwargs["messages"] == [{"role": "user", "content": "Valid message"}]
async def test_messages_adding_skips_when_no_valid_messages(self, mock_mem0_client: AsyncMock):
"""Test that mem0 client is not called when no valid messages exist."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
messages = [
ChatMessage(role=Role.USER, text=""),
ChatMessage(role=Role.USER, text=" "),
]
await provider.messages_adding("thread123", messages)
mock_mem0_client.add.assert_not_called()
class TestMem0ProviderModelInvoking:
"""Test model_invoking method."""
async def test_model_invoking_fails_without_filters(self, mock_mem0_client: AsyncMock):
"""Test that model_invoking fails when no filters are provided."""
provider = Mem0Provider(mem0_client=mock_mem0_client)
message = ChatMessage(role=Role.USER, text="What's the weather?")
with pytest.raises(ServiceInitializationError) as exc_info:
await provider.model_invoking(message)
assert "At least one of the filters" in str(exc_info.value)
async def test_model_invoking_single_message(self, mock_mem0_client: AsyncMock):
"""Test model_invoking with a single message."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
message = ChatMessage(role=Role.USER, text="What's the weather?")
# Mock search results
mock_mem0_client.search.return_value = [
{"memory": "User likes outdoor activities"},
{"memory": "User lives in Seattle"},
]
context = await provider.model_invoking(message)
mock_mem0_client.search.assert_called_once()
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["query"] == "What's the weather?"
assert call_args.kwargs["user_id"] == "user123"
assert isinstance(context, Context)
expected_instructions = (
"## Memories\nConsider the following memories when answering user questions:\n"
"User likes outdoor activities\nUser lives in Seattle"
)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert context.contents[0].text == expected_instructions
async def test_model_invoking_multiple_messages(
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
):
"""Test model_invoking with multiple messages."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
mock_mem0_client.search.return_value = [{"memory": "Previous conversation context"}]
await provider.model_invoking(sample_messages)
call_args = mock_mem0_client.search.call_args
expected_query = "Hello, how are you?\nI'm doing well, thank you!\nYou are a helpful assistant"
assert call_args.kwargs["query"] == expected_query
async def test_model_invoking_with_agent_id(self, mock_mem0_client: AsyncMock):
"""Test model_invoking with agent_id."""
provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client)
message = ChatMessage(role=Role.USER, text="Hello")
mock_mem0_client.search.return_value = []
await provider.model_invoking(message)
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["agent_id"] == "agent123"
assert call_args.kwargs["user_id"] is None
async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock):
"""Test model_invoking with scope_to_per_operation_thread_id enabled."""
provider = Mem0Provider(
user_id="user123",
thread_id="base_thread",
scope_to_per_operation_thread_id=True,
mem0_client=mock_mem0_client,
)
provider._per_operation_thread_id = "operation_thread"
message = ChatMessage(role=Role.USER, text="Hello")
mock_mem0_client.search.return_value = []
await provider.model_invoking(message)
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["run_id"] == "operation_thread"
async def test_model_invoking_no_memories_returns_none_instructions(self, mock_mem0_client: AsyncMock):
"""Test that no memories returns context with None instructions."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
message = ChatMessage(role=Role.USER, text="Hello")
mock_mem0_client.search.return_value = []
context = await provider.model_invoking(message)
assert isinstance(context, Context)
assert not context.contents
async def test_model_invoking_filters_empty_message_text(self, mock_mem0_client: AsyncMock):
"""Test that empty message text is filtered out from query."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
messages = [
ChatMessage(role=Role.USER, text=""),
ChatMessage(role=Role.USER, text="Valid message"),
ChatMessage(role=Role.USER, text=" "),
]
mock_mem0_client.search.return_value = []
await provider.model_invoking(messages)
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["query"] == "Valid message"
async def test_model_invoking_custom_context_prompt(self, mock_mem0_client: AsyncMock):
"""Test model_invoking with custom context prompt."""
custom_prompt = "## Custom Context\nRemember these details:"
provider = Mem0Provider(
user_id="user123",
context_prompt=custom_prompt,
mem0_client=mock_mem0_client,
)
message = ChatMessage(role=Role.USER, text="Hello")
mock_mem0_client.search.return_value = [{"memory": "Test memory"}]
context = await provider.model_invoking(message)
expected_instructions = "## Custom Context\nRemember these details:\nTest memory"
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert context.contents[0].text == expected_instructions
class TestMem0ProviderValidation:
"""Test validation methods."""
def test_validate_per_operation_thread_id_success(self, mock_mem0_client: AsyncMock):
"""Test successful validation of per-operation thread ID."""
provider = Mem0Provider(
user_id="user123",
scope_to_per_operation_thread_id=True,
mem0_client=mock_mem0_client,
)
provider._per_operation_thread_id = "thread123"
# Should not raise exception for same thread ID
provider._validate_per_operation_thread_id("thread123")
# Should not raise exception for None
provider._validate_per_operation_thread_id(None)
def test_validate_per_operation_thread_id_failure(self, mock_mem0_client: AsyncMock):
"""Test validation failure for conflicting thread IDs."""
provider = Mem0Provider(
user_id="user123",
scope_to_per_operation_thread_id=True,
mem0_client=mock_mem0_client,
)
provider._per_operation_thread_id = "thread123"
with pytest.raises(ValueError) as exc_info:
provider._validate_per_operation_thread_id("different_thread")
assert "can only be used with one thread at a time" in str(exc_info.value)
def test_validate_per_operation_thread_id_disabled_scope(self, mock_mem0_client: AsyncMock):
"""Test that validation is skipped when scope is disabled."""
provider = Mem0Provider(
user_id="user123",
scope_to_per_operation_thread_id=False,
mem0_client=mock_mem0_client,
)
provider._per_operation_thread_id = "thread123"
# Should not raise exception even with different thread ID
provider._validate_per_operation_thread_id("different_thread")
+4 -2
View File
@@ -7,6 +7,7 @@ dependencies = [
"agent-framework",
"agent-framework-azure",
"agent-framework-foundry",
"agent-framework-mem0",
"agent-framework-workflow",
]
@@ -63,8 +64,9 @@ exclude = [ "packages/agent_framework_project.egg-info" ]
agent-framework = { workspace = true }
agent-framework-azure = { workspace = true }
agent-framework-foundry = { workspace = true }
agent-framework-workflow = { workspace = true }
agent-framework-mem0 = { workspace = true }
agent-framework-runtime = { workspace = true }
agent-framework-workflow = { workspace = true }
[tool.ruff]
line-length = 120
@@ -191,7 +193,7 @@ build = "python run_tasks_in_packages_if_exists.py build"
# combined checks
check = ["fmt", "lint", "pyright", "mypy", "test", "markdown-code-lint", "samples-code-check"]
pre-commit-check = ["fmt", "lint", "pyright", "markdown-code-lint", "samples-code-check"]
all-tests = "pytest --cov=agent_framework --cov=agent_framework_azure --cov=agent_framework_foundry --cov=agent_framework_workflow --cov-report=term-missing:skip-covered packages/**/tests"
all-tests = "pytest --import-mode=importlib --cov=agent_framework --cov=agent_framework_azure --cov=agent_framework_foundry --cov=agent_framework_mem0 --cov=agent_framework_workflow --cov-report=term-missing:skip-covered packages/azure/tests packages/foundry/tests packages/main/tests packages/mem0/tests packages/workflow/tests"
[tool.poe.tasks.venv]
cmd = "uv venv --clear --python $python"
@@ -0,0 +1,51 @@
# Mem0 Context Provider Examples
[Mem0](https://mem0.ai/) is a self-improving memory layer for Large Language Models that enables applications to have long-term memory capabilities. The Agent Framework's Mem0 context provider integrates with Mem0's API to provide persistent memory across conversation sessions.
This folder contains examples demonstrating how to use the Mem0 context provider with the Agent Framework for persistent memory and context management across conversations.
## Examples
| File | Description |
|------|-------------|
| [`mem0_basic.py`](mem0_basic.py) | Basic example of using Mem0 context provider to store and retrieve user preferences across different conversation threads. |
| [`mem0_threads.py`](mem0_threads.py) | Advanced example demonstrating different thread scoping strategies with Mem0. Covers global thread scope (memories shared across all operations), per-operation thread scope (memories isolated per thread), and multiple agents with different memory configurations for personal vs. work contexts. |
## Prerequisites
### Required Resources
1. [Mem0 API Key](https://app.mem0.ai/) - Sign up for a Mem0 account and get your API key
2. Azure AI Foundry project endpoint (used in these examples)
3. Azure CLI authentication (run `az login`)
## Configuration
### Environment Variables
Set the following environment variables:
**For Mem0:**
- `MEM0_API_KEY`: Your Mem0 API key (alternatively, pass it as `api_key` parameter to `Mem0Provider`)
**For Azure AI Foundry:**
- `FOUNDRY_PROJECT_ENDPOINT`: Your Azure AI Foundry project endpoint
- `FOUNDRY_MODEL_DEPLOYMENT_NAME`: The name of your model deployment
## Key Concepts
### Memory Scoping
The Mem0 context provider supports different scoping strategies:
- **Global Scope** (`scope_to_per_operation_thread_id=False`): Memories are shared across all conversation threads
- **Thread Scope** (`scope_to_per_operation_thread_id=True`): Memories are isolated per conversation thread
### Memory Association
Mem0 records can be associated with different identifiers:
- `user_id`: Associate memories with a specific user
- `agent_id`: Associate memories with a specific agent
- `thread_id`: Associate memories with a specific conversation thread
- `application_id`: Associate memories with an application context
@@ -0,0 +1,72 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import uuid
from agent_framework.foundry import FoundryChatClient
from agent_framework.mem0 import Mem0Provider
from azure.identity.aio import AzureCliCredential
def retrieve_company_report(company_code: str, detailed: bool) -> str:
if company_code != "CNTS":
raise ValueError("Company code not found")
if not detailed:
return "CNTS is a company that specializes in technology."
return (
"CNTS is a company that specializes in technology. "
"It had a revenue of $10 million in 2022. It has 100 employees."
)
async def main() -> None:
"""Example of memory usage with Mem0 context provider."""
print("=== Mem0 Context Provider Example ===")
# Each record in Mem0 should be associated with agent_id or user_id or application_id or thread_id.
# In this example, we associate Mem0 records with user_id.
user_id = str(uuid.uuid4())
# For Azure authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
# For Mem0 authentication, set Mem0 API key via "api_key" parameter or MEM0_API_KEY environment variable.
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="FriendlyAssistant",
instructions="You are a friendly assistant.",
tools=retrieve_company_report,
context_providers=Mem0Provider(user_id=user_id),
) as agent,
):
# First ask the agent to retrieve a company report with no previous context.
# The agent will not be able to invoke the tool, since it doesn't know
# the company code or the report format, so it should ask for clarification.
query = "Please retrieve my company report"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result}\n")
# Now tell the agent the company code and the report format that you want to use
# and it should be able to invoke the tool and return the report.
query = "I always work with CNTS and I always want a detailed report format. Please remember and retrieve it."
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result}\n")
print("\nRequest within a new thread:")
# Create a new thread for the agent.
# The new thread has no context of the previous conversation.
thread = agent.get_new_thread()
# Since we have the mem0 component in the thread, the agent should be able to
# retrieve the company report without asking for clarification, as it will
# be able to remember the user preferences from Mem0 component.
query = "Please retrieve my company report"
print(f"User: {query}")
result = await agent.run(query, thread=thread)
print(f"Agent: {result}\n")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,164 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import uuid
from agent_framework.foundry import FoundryChatClient
from agent_framework.mem0 import Mem0Provider
from azure.identity.aio import AzureCliCredential
def get_user_preferences(user_id: str) -> str:
"""Mock function to get user preferences."""
preferences = {
"user123": "Prefers concise responses and technical details",
"user456": "Likes detailed explanations with examples",
}
return preferences.get(user_id, "No specific preferences found")
async def example_global_thread_scope() -> None:
"""Example 1: Global thread_id scope (memories shared across all operations)."""
print("1. Global Thread Scope Example:")
print("-" * 40)
global_thread_id = str(uuid.uuid4())
user_id = "user123"
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="GlobalMemoryAssistant",
instructions="You are an assistant that remembers user preferences across conversations.",
tools=get_user_preferences,
context_providers=Mem0Provider(
user_id=user_id,
thread_id=global_thread_id,
scope_to_per_operation_thread_id=False, # Share memories across all threads
),
) as global_agent,
):
# Store some preferences in the global scope
query = "Remember that I prefer technical responses with code examples when discussing programming."
print(f"User: {query}")
result = await global_agent.run(query)
print(f"Agent: {result}\n")
# Create a new thread - but memories should still be accessible due to global scope
new_thread = global_agent.get_new_thread()
query = "What do you know about my preferences?"
print(f"User (new thread): {query}")
result = await global_agent.run(query, thread=new_thread)
print(f"Agent: {result}\n")
async def example_per_operation_thread_scope() -> None:
"""Example 2: Per-operation thread scope (memories isolated per thread).
Note: When scope_to_per_operation_thread_id=True, the provider is bound to a single thread
throughout its lifetime. Use the same thread object for all operations with that provider.
"""
print("2. Per-Operation Thread Scope Example:")
print("-" * 40)
user_id = "user123"
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="ScopedMemoryAssistant",
instructions="You are an assistant with thread-scoped memory.",
tools=get_user_preferences,
context_providers=Mem0Provider(
user_id=user_id,
scope_to_per_operation_thread_id=True, # Isolate memories per thread
),
) as scoped_agent,
):
# Create a specific thread for this scoped provider
dedicated_thread = scoped_agent.get_new_thread()
# Store some information in the dedicated thread
query = "Remember that for this conversation, I'm working on a Python project about data analysis."
print(f"User (dedicated thread): {query}")
result = await scoped_agent.run(query, thread=dedicated_thread)
print(f"Agent: {result}\n")
# Test memory retrieval in the same dedicated thread
query = "What project am I working on?"
print(f"User (same dedicated thread): {query}")
result = await scoped_agent.run(query, thread=dedicated_thread)
print(f"Agent: {result}\n")
# Store more information in the same thread
query = "Also remember that I prefer using pandas and matplotlib for this project."
print(f"User (same dedicated thread): {query}")
result = await scoped_agent.run(query, thread=dedicated_thread)
print(f"Agent: {result}\n")
# Test comprehensive memory retrieval
query = "What do you know about my current project and preferences?"
print(f"User (same dedicated thread): {query}")
result = await scoped_agent.run(query, thread=dedicated_thread)
print(f"Agent: {result}\n")
async def example_multiple_agents() -> None:
"""Example 3: Multiple agents with different thread configurations."""
print("3. Multiple Agents with Different Thread Configurations:")
print("-" * 40)
agent_id_1 = "agent_personal"
agent_id_2 = "agent_work"
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="PersonalAssistant",
instructions="You are a personal assistant that helps with personal tasks.",
context_providers=Mem0Provider(
agent_id=agent_id_1,
),
) as personal_agent,
FoundryChatClient(async_credential=credential).create_agent(
name="WorkAssistant",
instructions="You are a work assistant that helps with professional tasks.",
context_providers=Mem0Provider(
agent_id=agent_id_2,
),
) as work_agent,
):
# Store personal information
query = "Remember that I like to exercise at 6 AM and prefer outdoor activities."
print(f"User to Personal Agent: {query}")
result = await personal_agent.run(query)
print(f"Personal Agent: {result}\n")
# Store work information
query = "Remember that I have team meetings every Tuesday at 2 PM."
print(f"User to Work Agent: {query}")
result = await work_agent.run(query)
print(f"Work Agent: {result}\n")
# Test memory isolation
query = "What do you know about my schedule?"
print(f"User to Personal Agent: {query}")
result = await personal_agent.run(query)
print(f"Personal Agent: {result}\n")
print(f"User to Work Agent: {query}")
result = await work_agent.run(query)
print(f"Work Agent: {result}\n")
async def main() -> None:
"""Run all Mem0 thread management examples."""
print("=== Mem0 Thread Management Example ===\n")
await example_global_thread_scope()
await example_per_operation_thread_scope()
await example_multiple_agents()
if __name__ == "__main__":
asyncio.run(main())
+2237 -1902
View File
File diff suppressed because it is too large Load Diff