mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add support for Mem0 Open Memory (#876)
* Add support for Mem0 Open Memory * Linting fixes * Linting fixes * Add sample and documentation * Small fixes * Update sample code imports/class names for new package structure * Improved typing --------- Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
0f2f2263c5
commit
ef9c072eab
@@ -2,21 +2,31 @@
|
||||
|
||||
import sys
|
||||
from collections.abc import MutableSequence, Sequence
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import ChatMessage, Context, ContextProvider, TextContent
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from mem0 import AsyncMemoryClient
|
||||
from mem0 import AsyncMemory, AsyncMemoryClient
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
from typing import NotRequired, Self, TypedDict # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover
|
||||
|
||||
|
||||
# Type aliases for Mem0 search response formats (v1.1 and v2; v1 is deprecated, but matches the type definition for v2)
|
||||
class MemorySearchResponse_v1_1(TypedDict):
|
||||
results: list[dict[str, Any]]
|
||||
relations: NotRequired[list[dict[str, Any]]]
|
||||
|
||||
|
||||
MemorySearchResponse_v2 = list[dict[str, Any]]
|
||||
|
||||
|
||||
class Mem0Provider(ContextProvider):
|
||||
mem0_client: AsyncMemoryClient
|
||||
mem0_client: AsyncMemory | AsyncMemoryClient
|
||||
api_key: str | None = None
|
||||
application_id: str | None = None
|
||||
agent_id: str | None = None
|
||||
@@ -36,7 +46,7 @@ class Mem0Provider(ContextProvider):
|
||||
user_id: str | None = None,
|
||||
scope_to_per_operation_thread_id: bool = False,
|
||||
context_prompt: str = ContextProvider.DEFAULT_CONTEXT_PROMPT,
|
||||
mem0_client: AsyncMemoryClient | None = None,
|
||||
mem0_client: AsyncMemory | AsyncMemoryClient | None = None,
|
||||
) -> None:
|
||||
"""Initializes a new instance of the Mem0Provider class.
|
||||
|
||||
@@ -72,14 +82,14 @@ class Mem0Provider(ContextProvider):
|
||||
|
||||
async def __aenter__(self) -> "Self":
|
||||
"""Async context manager entry."""
|
||||
if self.mem0_client:
|
||||
if self.mem0_client and isinstance(self.mem0_client, AbstractAsyncContextManager):
|
||||
await self.mem0_client.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
|
||||
"""Async context manager exit."""
|
||||
if self._should_close_client and self.mem0_client:
|
||||
await self.mem0_client.__aexit__(exc_type, exc_val, exc_tb) # type: ignore
|
||||
if self._should_close_client and self.mem0_client and isinstance(self.mem0_client, AbstractAsyncContextManager):
|
||||
await self.mem0_client.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def thread_created(self, thread_id: str | None = None) -> None:
|
||||
"""Called when a new thread is created.
|
||||
@@ -131,13 +141,22 @@ class Mem0Provider(ContextProvider):
|
||||
messages_list = [messages] if isinstance(messages, ChatMessage) else list(messages)
|
||||
input_text = "\n".join(msg.text for msg in messages_list if msg and msg.text and msg.text.strip())
|
||||
|
||||
memories = await self.mem0_client.search( # type: ignore[misc]
|
||||
search_response: MemorySearchResponse_v1_1 | MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
|
||||
query=input_text,
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
run_id=self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id,
|
||||
)
|
||||
|
||||
# Depending on the API version, the response schema varies slightly
|
||||
if isinstance(search_response, list):
|
||||
memories = search_response
|
||||
elif isinstance(search_response, dict) and "results" in search_response:
|
||||
memories = search_response["results"]
|
||||
else:
|
||||
# Fallback for unexpected schema - return response as text as-is
|
||||
memories = [search_response]
|
||||
|
||||
line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories)
|
||||
|
||||
content = TextContent(f"{self.context_prompt}\n{line_separated_memories}") if line_separated_memories else None
|
||||
|
||||
@@ -10,12 +10,13 @@ This folder contains examples demonstrating how to use the Mem0 context provider
|
||||
|------|-------------|
|
||||
| [`mem0_basic.py`](mem0_basic.py) | Basic example of using Mem0 context provider to store and retrieve user preferences across different conversation threads. |
|
||||
| [`mem0_threads.py`](mem0_threads.py) | Advanced example demonstrating different thread scoping strategies with Mem0. Covers global thread scope (memories shared across all operations), per-operation thread scope (memories isolated per thread), and multiple agents with different memory configurations for personal vs. work contexts. |
|
||||
| [`mem0_oss.py`](mem0_oss.py) | Example of using the Mem0 Open Source self-hosted version as the context provider. Demonstrates setup and configuration for local deployment. |
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Required Resources
|
||||
|
||||
1. [Mem0 API Key](https://app.mem0.ai/) - Sign up for a Mem0 account and get your API key
|
||||
1. [Mem0 API Key](https://app.mem0.ai/) - Sign up for a Mem0 account and get your API key - _or_ self-host [Mem0 Open Source](https://docs.mem0.ai/open-source/overview)
|
||||
2. Azure AI project endpoint (used in these examples)
|
||||
3. Azure CLI authentication (run `az login`)
|
||||
|
||||
@@ -25,8 +26,11 @@ This folder contains examples demonstrating how to use the Mem0 context provider
|
||||
|
||||
Set the following environment variables:
|
||||
|
||||
**For Mem0:**
|
||||
- `MEM0_API_KEY`: Your Mem0 API key (alternatively, pass it as `api_key` parameter to `Mem0Provider`)
|
||||
**For Mem0 Platform:**
|
||||
- `MEM0_API_KEY`: Your Mem0 API key (alternatively, pass it as `api_key` parameter to `Mem0Provider`). Not required if you are self-hosting [Mem0 Open Source](https://docs.mem0.ai/open-source/overview)
|
||||
|
||||
**For Mem0 Open Source:**
|
||||
- `OPENAI_API_KEY`: Your OpenAI API key (used by Mem0 OSS for embedding generation and automatic memory extraction)
|
||||
|
||||
**For Azure AI:**
|
||||
- `AZURE_AI_PROJECT_ENDPOINT`: Your Azure AI project endpoint
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
|
||||
from agent_framework.azure import AzureAIAgentClient
|
||||
from agent_framework.mem0 import Mem0Provider
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from mem0 import AsyncMemory
|
||||
|
||||
|
||||
def retrieve_company_report(company_code: str, detailed: bool) -> str:
|
||||
if company_code != "CNTS":
|
||||
raise ValueError("Company code not found")
|
||||
if not detailed:
|
||||
return "CNTS is a company that specializes in technology."
|
||||
return (
|
||||
"CNTS is a company that specializes in technology. "
|
||||
"It had a revenue of $10 million in 2022. It has 100 employees."
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Example of memory usage with local Mem0 OSS context provider."""
|
||||
print("=== Mem0 Context Provider Example ===")
|
||||
|
||||
# Each record in Mem0 should be associated with agent_id or user_id or application_id or thread_id.
|
||||
# In this example, we associate Mem0 records with user_id.
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
# For Azure authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
|
||||
# authentication option.
|
||||
# By default, local Mem0 authenticates to your OpenAI using the OPENAI_API_KEY environment variable.
|
||||
# See the Mem0 documentation for other LLM providers and authentication options.
|
||||
local_mem0_client = AsyncMemory()
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
AzureAIAgentClient(async_credential=credential).create_agent(
|
||||
name="FriendlyAssistant",
|
||||
instructions="You are a friendly assistant.",
|
||||
tools=retrieve_company_report,
|
||||
context_providers=Mem0Provider(user_id=user_id, mem0_client=local_mem0_client),
|
||||
) as agent,
|
||||
):
|
||||
# First ask the agent to retrieve a company report with no previous context.
|
||||
# The agent will not be able to invoke the tool, since it doesn't know
|
||||
# the company code or the report format, so it should ask for clarification.
|
||||
query = "Please retrieve my company report"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
# Now tell the agent the company code and the report format that you want to use
|
||||
# and it should be able to invoke the tool and return the report.
|
||||
query = "I always work with CNTS and I always want a detailed report format. Please remember and retrieve it."
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
print("\nRequest within a new thread:")
|
||||
|
||||
# Create a new thread for the agent.
|
||||
# The new thread has no context of the previous conversation.
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# Since we have the mem0 component in the thread, the agent should be able to
|
||||
# retrieve the company report without asking for clarification, as it will
|
||||
# be able to remember the user preferences from Mem0 component.
|
||||
query = "Please retrieve my company report"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query, thread=thread)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user