Python: fix(mem0): isolate entity retrieval and correct app_id payload (#6242)

* fix(mem0): parallel memory retrieval logic and strict type compliance

* fix(mem0): align parallel retrieval types for pyright and mypy

* fix(mem0): handle asyncio.CancelledError in search response and update test description

* fix(mem0): improve error handling for asyncio.CancelledError and update test names for clarity

* fix(mem0): improve retrieval response handling
This commit is contained in:
Vedant Sonani
2026-06-08 19:20:23 +05:30
committed by GitHub
Unverified
parent 331201294b
commit 6169df04cb
2 changed files with 222 additions and 82 deletions
@@ -8,29 +8,34 @@ This module provides ``Mem0ContextProvider``, built on the new
from __future__ import annotations
import asyncio
import logging
import sys
from collections.abc import Awaitable
from contextlib import AbstractAsyncContextManager
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypedDict
from agent_framework import Message
from agent_framework._sessions import AgentSession, ContextProvider, SessionContext
from mem0 import AsyncMemory, AsyncMemoryClient
if sys.version_info >= (3, 11):
from typing import NotRequired, Self, TypedDict # pragma: no cover
from typing import Self # pragma: no cover
else:
from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover
from typing_extensions import Self # pragma: no cover
if TYPE_CHECKING:
from agent_framework._agents import SupportsAgentRun
class _MemorySearchResponse_v1_1(TypedDict):
results: list[dict[str, Any]]
relations: NotRequired[list[dict[str, Any]]]
logger = logging.getLogger(__name__)
MemoryRecord: TypeAlias = dict[str, object]
_MemorySearchResponse_v2 = list[dict[str, Any]]
class SearchResults(TypedDict):
results: list[MemoryRecord]
SearchResponse: TypeAlias = list[MemoryRecord] | SearchResults
class Mem0ContextProvider(ContextProvider):
@@ -106,28 +111,85 @@ class Mem0ContextProvider(ContextProvider):
if not input_text.strip():
return
filters = self._build_filters()
# Query entity partitions independently to bypass strict logical AND limitations
# Mem0 OSS and Platform SDKs expose inconsistent search typings.
search_tasks: list[Awaitable[Any]] = []
# AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs
# AsyncMemoryClient (Platform) expects them in a filters dict
search_kwargs: dict[str, Any] = {"query": input_text}
if isinstance(self.mem0_client, AsyncMemory):
search_kwargs.update(filters)
else:
search_kwargs["filters"] = filters
# 1. Query User partition independently
if self.user_id:
user_kwargs = self._build_search_kwargs(input_text, "user_id", self.user_id)
search_tasks.append(self.mem0_client.search(**user_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]
search_response: _MemorySearchResponse_v1_1 | _MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
**search_kwargs,
)
# 2. Query Agent partition independently
if self.agent_id:
agent_kwargs = self._build_search_kwargs(input_text, "agent_id", self.agent_id)
search_tasks.append(self.mem0_client.search(**agent_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]
if isinstance(search_response, list):
memories = search_response
elif isinstance(search_response, dict) and "results" in search_response:
memories = search_response["results"]
else:
memories = [search_response]
# Fall back to an app-scoped search when only application_id is configured
if not search_tasks and self.application_id:
app_kwargs: dict[str, Any] = {"query": input_text}
if isinstance(self.mem0_client, AsyncMemory):
app_kwargs["app_id"] = self.application_id
else:
app_kwargs["filters"] = {"app_id": self.application_id}
search_tasks.append(self.mem0_client.search(**app_kwargs)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
if not search_tasks:
return
line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories)
results: list[SearchResponse | BaseException] = await asyncio.gather(*search_tasks, return_exceptions=True)
# Merge and deduplicate results
memories: list[MemoryRecord] = []
seen_memory_ids: set[str] = set()
failed_tasks_count: int = 0
for search_response in results:
if isinstance(search_response, asyncio.CancelledError):
raise search_response
if isinstance(search_response, BaseException):
failed_tasks_count += 1
logger.error(
"Mem0 partition search task failed: %s",
search_response,
exc_info=(type(search_response), search_response, search_response.__traceback__),
)
continue
current_memories: list[MemoryRecord] = []
if isinstance(search_response, list):
current_memories = [mem for mem in search_response if isinstance(mem, dict)]
elif isinstance(search_response, dict):
results_field = search_response.get("results")
if isinstance(results_field, list):
current_memories = [
item
for item in results_field
if isinstance(item, dict) # pyright: ignore[reportUnknownVariableType]
]
else:
logger.warning(
"Unexpected Mem0 search response format: %s",
type(results_field).__name__,
)
for mem in current_memories:
mem_id = mem.get("id")
if mem_id is not None and not isinstance(mem_id, str):
mem_id = str(mem_id)
if mem_id is not None and mem_id in seen_memory_ids:
continue
if mem_id is not None:
seen_memory_ids.add(mem_id)
memories.append(mem)
if failed_tasks_count == len(search_tasks):
logger.error("All Mem0 retrieval tasks failed. Context provider is unable to verify memory state.")
line_separated_memories = "\n".join(str(memory.get("memory", "")) for memory in memories)
if line_separated_memories:
context.extend_messages(
self.source_id,
@@ -159,12 +221,21 @@ class Mem0ContextProvider(ContextProvider):
]
if messages:
await self.mem0_client.add( # type: ignore[misc]
messages=messages,
user_id=self.user_id,
agent_id=self.agent_id,
metadata={"application_id": self.application_id},
)
add_kwargs: dict[str, Any] = {
"messages": messages,
"user_id": self.user_id,
"agent_id": self.agent_id,
}
# Inject the application scope using the matching signature format for each SDK variant
if isinstance(self.mem0_client, AsyncMemory):
if self.application_id:
add_kwargs["app_id"] = self.application_id
else:
if self.application_id:
add_kwargs["filters"] = {"app_id": self.application_id}
await self.mem0_client.add(**add_kwargs) # type: ignore[misc, call-arg]
# -- Internal methods ------------------------------------------------------
@@ -173,15 +244,21 @@ class Mem0ContextProvider(ContextProvider):
if not self.agent_id and not self.user_id and not self.application_id:
raise ValueError("At least one of the filters: agent_id, user_id, or application_id is required.")
def _build_filters(self) -> dict[str, Any]:
"""Build search filters from initialization parameters."""
filters: dict[str, Any] = {}
if self.user_id:
filters["user_id"] = self.user_id
if self.agent_id:
filters["agent_id"] = self.agent_id
if self.application_id:
filters["app_id"] = self.application_id
def _build_search_kwargs(self, input_text: str, entity_key: str, entity_value: str) -> dict[str, Any]:
"""Build search keyword arguments formatted for OSS vs Platform clients."""
filters: dict[str, Any] = {"query": input_text}
if isinstance(self.mem0_client, AsyncMemory):
# AsyncMemory (OSS) expects direct kwargs
filters[entity_key] = entity_value
if self.application_id:
filters["app_id"] = self.application_id
else:
# AsyncMemoryClient (Platform) expects a filters dict
filters["filters"] = {entity_key: entity_value}
if self.application_id:
filters["filters"]["app_id"] = self.application_id
return filters
@@ -3,7 +3,7 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import AgentResponse, Message
@@ -193,39 +193,59 @@ class TestBeforeRun:
assert call_kwargs["user_id"] == "u1"
assert "filters" not in call_kwargs
async def test_oss_client_all_scoping_params(self, mock_oss_mem0_client: AsyncMock) -> None:
"""OSS client with all scoping parameters passes them as direct kwargs."""
@pytest.mark.asyncio
async def test_oss_client_all_scoping_params_except_app_id(self, mock_oss_mem0_client: AsyncMock) -> None:
"""OSS client with all scoping parameters passes them as isolated concurrent kwargs."""
mock_oss_mem0_client.search.return_value = []
provider = Mem0ContextProvider(
source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1", application_id="app1"
source_id="mem0",
mem0_client=mock_oss_mem0_client,
user_id="u1",
agent_id="a1"
)
session = AgentSession(session_id="test-session")
ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1")
mock_context = MagicMock(spec=SessionContext)
mock_msg = MagicMock()
mock_msg.text = "hello"
mock_context.input_messages = [mock_msg]
mock_context.response = None
await provider.before_run(
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
) # type: ignore[arg-type]
agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
)
call_kwargs = mock_oss_mem0_client.search.call_args.kwargs
assert call_kwargs["user_id"] == "u1"
assert call_kwargs["agent_id"] == "a1"
assert "filters" not in call_kwargs
# Re-aligned assertion: We expect 2 separate concurrent calls instead of 1 combined call
assert mock_oss_mem0_client.search.call_count == 2
mock_oss_mem0_client.search.assert_any_call(query="hello", user_id="u1")
mock_oss_mem0_client.search.assert_any_call(query="hello", agent_id="a1")
async def test_platform_client_passes_filters_dict(self, mock_mem0_client: AsyncMock) -> None:
"""Platform AsyncMemoryClient should receive scoping params in a filters dict."""
@pytest.mark.asyncio
async def test_platform_client_passes_filters_dict_except_app_id(self, mock_mem0_client: AsyncMock) -> None:
"""Platform client passes scoping parameters concurrently inside the nested filters dictionary."""
mock_mem0_client.search.return_value = []
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
session = AgentSession(session_id="test-session")
ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1")
provider = Mem0ContextProvider(
source_id="mem0",
mem0_client=mock_mem0_client,
user_id="u1",
agent_id="a1",
)
mock_context = MagicMock(spec=SessionContext)
mock_msg = MagicMock()
mock_msg.text = "hello"
mock_context.input_messages = [mock_msg]
mock_context.response = None
await provider.before_run(
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
) # type: ignore[arg-type]
agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
)
call_kwargs = mock_mem0_client.search.call_args.kwargs
assert call_kwargs["query"] == "Hello"
assert "filters" in call_kwargs
assert call_kwargs["filters"]["user_id"] == "u1"
# Re-aligned assertion: Platform client isolates filters per call to bypass AND limitations
assert mock_mem0_client.search.call_count == 2
mock_mem0_client.search.assert_any_call(query="hello", filters={"user_id": "u1"})
mock_mem0_client.search.assert_any_call(query="hello", filters={"agent_id": "a1"})
# -- after_run tests -----------------------------------------------------------
@@ -318,8 +338,8 @@ class TestAfterRun:
with pytest.raises(ValueError, match="At least one of the filters"):
await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
async def test_stores_with_application_id_metadata(self, mock_mem0_client: AsyncMock) -> None:
"""application_id is passed in metadata."""
async def test_stores_with_application_id_filters(self, mock_mem0_client: AsyncMock) -> None:
"""application_id is passed in filters."""
provider = Mem0ContextProvider(
source_id="mem0", mem0_client=mock_mem0_client, user_id="u1", application_id="app1"
)
@@ -331,7 +351,7 @@ class TestAfterRun:
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
) # type: ignore[arg-type]
assert mock_mem0_client.add.call_args.kwargs["metadata"] == {"application_id": "app1"}
assert mock_mem0_client.add.call_args.kwargs["filters"] == {"app_id": "app1"}
# -- _validate_filters tests --------------------------------------------------
@@ -358,15 +378,20 @@ class TestValidateFilters:
provider._validate_filters()
# -- _build_filters tests -----------------------------------------------------
# -- _build_search_kwargs tests -----------------------------------------------------
class TestBuildFilters:
"""Test _build_filters method."""
class TestBuildSearchKwargs:
"""Test _build_search_kwargs method."""
def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
assert provider._build_filters() == {"user_id": "u1"}
# Pass the 3 required arguments
result = provider._build_search_kwargs("test query", "user_id", "u1")
# AsyncMock triggers the Platform client nested 'filters' structure
assert result == {"query": "test query", "filters": {"user_id": "u1"}}
def test_all_params(self, mock_mem0_client: AsyncMock) -> None:
provider = Mem0ContextProvider(
@@ -376,28 +401,66 @@ class TestBuildFilters:
agent_id="a1",
application_id="app1",
)
assert provider._build_filters() == {
"user_id": "u1",
"agent_id": "a1",
"app_id": "app1",
# Test that app_id correctly merges with the isolated target entity
result = provider._build_search_kwargs("test query", "agent_id", "a1")
assert result == {
"query": "test query",
"filters": {
"agent_id": "a1",
"app_id": "app1",
},
}
def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
filters = provider._build_filters()
assert "agent_id" not in filters
assert "run_id" not in filters
assert "app_id" not in filters
# application_id is None by default, it should not appear in the dictionary
result = provider._build_search_kwargs("test query", "user_id", "u1")
assert "app_id" not in result.get("filters", {})
def test_no_run_id_in_search_filters(self, mock_mem0_client: AsyncMock) -> None:
"""run_id is excluded from search filters so memories work across sessions."""
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
filters = provider._build_filters()
assert "run_id" not in filters
result = provider._build_search_kwargs("test query", "user_id", "u1")
assert "run_id" not in result.get("filters", {})
assert "run_id" not in result
def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None:
# Validates base query payload generation
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
assert provider._build_filters() == {}
result = provider._build_search_kwargs("test query", "custom_key", "custom_val")
assert result == {"query": "test query", "filters": {"custom_key": "custom_val"}}
@pytest.mark.asyncio
async def test_before_run_application_only_fallback(self, mock_mem0_client: AsyncMock) -> None:
provider = Mem0ContextProvider(
source_id="mem0", mem0_client=mock_mem0_client, application_id="app_fallback_test"
)
# Mock a valid message list and session container setup
mock_context = MagicMock(spec=SessionContext)
mock_msg = MagicMock()
mock_msg.text = "Retrieve systemic fallback memory traces"
mock_context.input_messages = [mock_msg]
mock_context.response = None
mock_mem0_client.search = AsyncMock(return_value=[{"id": "m1", "memory": "System configuration template"}])
await provider.before_run(
agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
)
# Verify that an application-scoped search task executed successfully
assert mock_mem0_client.search.call_count == 1
mock_context.extend_messages.assert_called_once()
# -- Context manager tests -----------------------------------------------------