mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: fix(mem0): isolate entity retrieval and correct app_id payload (#6242)
* fix(mem0): parallel memory retrieval logic and strict type compliance
* fix(mem0): align parallel retrieval types for pyright and mypy
* fix(mem0): handle asyncio.CancelledError in search response and update test description
* fix(mem0): improve error handling for asyncio.CancelledError and update test names for clarity
* fix(mem0): improve retrieval response handling
This commit is contained in:
committed by
GitHub
Unverified
parent
331201294b
commit
6169df04cb
@@ -8,29 +8,34 @@ This module provides ``Mem0ContextProvider``, built on the new
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Awaitable
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypedDict
|
||||
|
||||
from agent_framework import Message
|
||||
from agent_framework._sessions import AgentSession, ContextProvider, SessionContext
|
||||
from mem0 import AsyncMemory, AsyncMemoryClient
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import NotRequired, Self, TypedDict # pragma: no cover
|
||||
from typing import Self # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework._agents import SupportsAgentRun
|
||||
|
||||
|
||||
class _MemorySearchResponse_v1_1(TypedDict):
|
||||
results: list[dict[str, Any]]
|
||||
relations: NotRequired[list[dict[str, Any]]]
|
||||
logger = logging.getLogger(__name__)
|
||||
MemoryRecord: TypeAlias = dict[str, object]
|
||||
|
||||
|
||||
_MemorySearchResponse_v2 = list[dict[str, Any]]
|
||||
class SearchResults(TypedDict):
|
||||
results: list[MemoryRecord]
|
||||
|
||||
|
||||
SearchResponse: TypeAlias = list[MemoryRecord] | SearchResults
|
||||
|
||||
|
||||
class Mem0ContextProvider(ContextProvider):
|
||||
@@ -106,28 +111,85 @@ class Mem0ContextProvider(ContextProvider):
|
||||
if not input_text.strip():
|
||||
return
|
||||
|
||||
filters = self._build_filters()
|
||||
# Query entity partitions independently to bypass strict logical AND limitations
|
||||
# Mem0 OSS and Platform SDKs expose inconsistent search typings.
|
||||
search_tasks: list[Awaitable[Any]] = []
|
||||
|
||||
# AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs
|
||||
# AsyncMemoryClient (Platform) expects them in a filters dict
|
||||
search_kwargs: dict[str, Any] = {"query": input_text}
|
||||
if isinstance(self.mem0_client, AsyncMemory):
|
||||
search_kwargs.update(filters)
|
||||
else:
|
||||
search_kwargs["filters"] = filters
|
||||
# 1. Query User partition independently
|
||||
if self.user_id:
|
||||
user_kwargs = self._build_search_kwargs(input_text, "user_id", self.user_id)
|
||||
search_tasks.append(self.mem0_client.search(**user_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]
|
||||
|
||||
search_response: _MemorySearchResponse_v1_1 | _MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
|
||||
**search_kwargs,
|
||||
)
|
||||
# 2. Query Agent partition independently
|
||||
if self.agent_id:
|
||||
agent_kwargs = self._build_search_kwargs(input_text, "agent_id", self.agent_id)
|
||||
search_tasks.append(self.mem0_client.search(**agent_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]
|
||||
|
||||
if isinstance(search_response, list):
|
||||
memories = search_response
|
||||
elif isinstance(search_response, dict) and "results" in search_response:
|
||||
memories = search_response["results"]
|
||||
else:
|
||||
memories = [search_response]
|
||||
# Fall back to an app-scoped search when only application_id is configured
|
||||
if not search_tasks and self.application_id:
|
||||
app_kwargs: dict[str, Any] = {"query": input_text}
|
||||
if isinstance(self.mem0_client, AsyncMemory):
|
||||
app_kwargs["app_id"] = self.application_id
|
||||
else:
|
||||
app_kwargs["filters"] = {"app_id": self.application_id}
|
||||
search_tasks.append(self.mem0_client.search(**app_kwargs)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
|
||||
if not search_tasks:
|
||||
return
|
||||
|
||||
line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories)
|
||||
results: list[SearchResponse | BaseException] = await asyncio.gather(*search_tasks, return_exceptions=True)
|
||||
|
||||
# Merge and deduplicate results
|
||||
memories: list[MemoryRecord] = []
|
||||
seen_memory_ids: set[str] = set()
|
||||
failed_tasks_count: int = 0
|
||||
|
||||
for search_response in results:
|
||||
if isinstance(search_response, asyncio.CancelledError):
|
||||
raise search_response
|
||||
|
||||
if isinstance(search_response, BaseException):
|
||||
failed_tasks_count += 1
|
||||
logger.error(
|
||||
"Mem0 partition search task failed: %s",
|
||||
search_response,
|
||||
exc_info=(type(search_response), search_response, search_response.__traceback__),
|
||||
)
|
||||
continue
|
||||
|
||||
current_memories: list[MemoryRecord] = []
|
||||
if isinstance(search_response, list):
|
||||
current_memories = [mem for mem in search_response if isinstance(mem, dict)]
|
||||
elif isinstance(search_response, dict):
|
||||
results_field = search_response.get("results")
|
||||
if isinstance(results_field, list):
|
||||
current_memories = [
|
||||
item
|
||||
for item in results_field
|
||||
if isinstance(item, dict) # pyright: ignore[reportUnknownVariableType]
|
||||
]
|
||||
else:
|
||||
logger.warning(
|
||||
"Unexpected Mem0 search response format: %s",
|
||||
type(results_field).__name__,
|
||||
)
|
||||
|
||||
for mem in current_memories:
|
||||
mem_id = mem.get("id")
|
||||
if mem_id is not None and not isinstance(mem_id, str):
|
||||
mem_id = str(mem_id)
|
||||
|
||||
if mem_id is not None and mem_id in seen_memory_ids:
|
||||
continue
|
||||
|
||||
if mem_id is not None:
|
||||
seen_memory_ids.add(mem_id)
|
||||
|
||||
memories.append(mem)
|
||||
|
||||
if failed_tasks_count == len(search_tasks):
|
||||
logger.error("All Mem0 retrieval tasks failed. Context provider is unable to verify memory state.")
|
||||
|
||||
line_separated_memories = "\n".join(str(memory.get("memory", "")) for memory in memories)
|
||||
if line_separated_memories:
|
||||
context.extend_messages(
|
||||
self.source_id,
|
||||
@@ -159,12 +221,21 @@ class Mem0ContextProvider(ContextProvider):
|
||||
]
|
||||
|
||||
if messages:
|
||||
await self.mem0_client.add( # type: ignore[misc]
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
metadata={"application_id": self.application_id},
|
||||
)
|
||||
add_kwargs: dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"user_id": self.user_id,
|
||||
"agent_id": self.agent_id,
|
||||
}
|
||||
|
||||
# Inject the application scope using the matching signature format for each SDK variant
|
||||
if isinstance(self.mem0_client, AsyncMemory):
|
||||
if self.application_id:
|
||||
add_kwargs["app_id"] = self.application_id
|
||||
else:
|
||||
if self.application_id:
|
||||
add_kwargs["filters"] = {"app_id": self.application_id}
|
||||
|
||||
await self.mem0_client.add(**add_kwargs) # type: ignore[misc, call-arg]
|
||||
|
||||
# -- Internal methods ------------------------------------------------------
|
||||
|
||||
@@ -173,15 +244,21 @@ class Mem0ContextProvider(ContextProvider):
|
||||
if not self.agent_id and not self.user_id and not self.application_id:
|
||||
raise ValueError("At least one of the filters: agent_id, user_id, or application_id is required.")
|
||||
|
||||
def _build_filters(self) -> dict[str, Any]:
|
||||
"""Build search filters from initialization parameters."""
|
||||
filters: dict[str, Any] = {}
|
||||
if self.user_id:
|
||||
filters["user_id"] = self.user_id
|
||||
if self.agent_id:
|
||||
filters["agent_id"] = self.agent_id
|
||||
if self.application_id:
|
||||
filters["app_id"] = self.application_id
|
||||
def _build_search_kwargs(self, input_text: str, entity_key: str, entity_value: str) -> dict[str, Any]:
    """Assemble the keyword arguments for one entity-scoped ``search`` call.

    Args:
        input_text: Text sent as the ``query`` argument.
        entity_key: Name of the scoping field (callers in this module pass
            ``"user_id"`` or ``"agent_id"``).
        entity_value: Value to match for ``entity_key``.

    Returns:
        A dict ready to be splatted into ``self.mem0_client.search``. For an
        ``AsyncMemory`` client the scoping fields sit next to ``query``; for
        any other client they are nested under a ``"filters"`` key.
    """
    # The scope is always the requested entity, narrowed by app_id when one is configured.
    scope: dict[str, Any] = {entity_key: entity_value}
    if self.application_id:
        scope["app_id"] = self.application_id

    if isinstance(self.mem0_client, AsyncMemory):
        # AsyncMemory (OSS) takes the scoping fields as direct keyword arguments.
        return {"query": input_text, **scope}

    # AsyncMemoryClient (Platform) takes the scoping fields inside a filters dict.
    return {"query": input_text, "filters": scope}
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentResponse, Message
|
||||
@@ -193,39 +193,59 @@ class TestBeforeRun:
|
||||
assert call_kwargs["user_id"] == "u1"
|
||||
assert "filters" not in call_kwargs
|
||||
|
||||
async def test_oss_client_all_scoping_params(self, mock_oss_mem0_client: AsyncMock) -> None:
|
||||
"""OSS client with all scoping parameters passes them as direct kwargs."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_oss_client_all_scoping_params_except_app_id(self, mock_oss_mem0_client: AsyncMock) -> None:
|
||||
"""OSS client with all scoping parameters passes them as isolated concurrent kwargs."""
|
||||
mock_oss_mem0_client.search.return_value = []
|
||||
|
||||
provider = Mem0ContextProvider(
|
||||
source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1", application_id="app1"
|
||||
source_id="mem0",
|
||||
mem0_client=mock_oss_mem0_client,
|
||||
user_id="u1",
|
||||
agent_id="a1"
|
||||
)
|
||||
session = AgentSession(session_id="test-session")
|
||||
ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1")
|
||||
|
||||
mock_context = MagicMock(spec=SessionContext)
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.text = "hello"
|
||||
mock_context.input_messages = [mock_msg]
|
||||
mock_context.response = None
|
||||
|
||||
await provider.before_run(
|
||||
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
|
||||
) # type: ignore[arg-type]
|
||||
agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
|
||||
)
|
||||
|
||||
call_kwargs = mock_oss_mem0_client.search.call_args.kwargs
|
||||
assert call_kwargs["user_id"] == "u1"
|
||||
assert call_kwargs["agent_id"] == "a1"
|
||||
assert "filters" not in call_kwargs
|
||||
# Re-aligned assertion: We expect 2 separate concurrent calls instead of 1 combined call
|
||||
assert mock_oss_mem0_client.search.call_count == 2
|
||||
mock_oss_mem0_client.search.assert_any_call(query="hello", user_id="u1")
|
||||
mock_oss_mem0_client.search.assert_any_call(query="hello", agent_id="a1")
|
||||
|
||||
async def test_platform_client_passes_filters_dict(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Platform AsyncMemoryClient should receive scoping params in a filters dict."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_client_passes_filters_dict_except_app_id(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Platform client passes scoping parameters concurrently inside the nested filters dictionary."""
|
||||
mock_mem0_client.search.return_value = []
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
|
||||
session = AgentSession(session_id="test-session")
|
||||
ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1")
|
||||
|
||||
provider = Mem0ContextProvider(
|
||||
source_id="mem0",
|
||||
mem0_client=mock_mem0_client,
|
||||
user_id="u1",
|
||||
agent_id="a1",
|
||||
)
|
||||
|
||||
mock_context = MagicMock(spec=SessionContext)
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.text = "hello"
|
||||
mock_context.input_messages = [mock_msg]
|
||||
mock_context.response = None
|
||||
|
||||
await provider.before_run(
|
||||
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
|
||||
) # type: ignore[arg-type]
|
||||
agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
|
||||
)
|
||||
|
||||
call_kwargs = mock_mem0_client.search.call_args.kwargs
|
||||
assert call_kwargs["query"] == "Hello"
|
||||
assert "filters" in call_kwargs
|
||||
assert call_kwargs["filters"]["user_id"] == "u1"
|
||||
# Re-aligned assertion: Platform client isolates filters per call to bypass AND limitations
|
||||
assert mock_mem0_client.search.call_count == 2
|
||||
mock_mem0_client.search.assert_any_call(query="hello", filters={"user_id": "u1"})
|
||||
mock_mem0_client.search.assert_any_call(query="hello", filters={"agent_id": "a1"})
|
||||
|
||||
|
||||
# -- after_run tests -----------------------------------------------------------
|
||||
@@ -318,8 +338,8 @@ class TestAfterRun:
|
||||
with pytest.raises(ValueError, match="At least one of the filters"):
|
||||
await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
|
||||
|
||||
async def test_stores_with_application_id_metadata(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""application_id is passed in metadata."""
|
||||
async def test_stores_with_application_id_filters(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""application_id is passed in filters."""
|
||||
provider = Mem0ContextProvider(
|
||||
source_id="mem0", mem0_client=mock_mem0_client, user_id="u1", application_id="app1"
|
||||
)
|
||||
@@ -331,7 +351,7 @@ class TestAfterRun:
|
||||
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert mock_mem0_client.add.call_args.kwargs["metadata"] == {"application_id": "app1"}
|
||||
assert mock_mem0_client.add.call_args.kwargs["filters"] == {"app_id": "app1"}
|
||||
|
||||
|
||||
# -- _validate_filters tests --------------------------------------------------
|
||||
@@ -358,15 +378,20 @@ class TestValidateFilters:
|
||||
provider._validate_filters()
|
||||
|
||||
|
||||
# -- _build_filters tests -----------------------------------------------------
|
||||
# -- _build_search_kwargs tests -----------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildFilters:
|
||||
"""Test _build_filters method."""
|
||||
class TestBuildSearchKwargs:
|
||||
"""Test _build_search_kwargs method."""
|
||||
|
||||
def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
|
||||
assert provider._build_filters() == {"user_id": "u1"}
|
||||
|
||||
# Pass the 3 required arguments
|
||||
result = provider._build_search_kwargs("test query", "user_id", "u1")
|
||||
|
||||
# AsyncMock triggers the Platform client nested 'filters' structure
|
||||
assert result == {"query": "test query", "filters": {"user_id": "u1"}}
|
||||
|
||||
def test_all_params(self, mock_mem0_client: AsyncMock) -> None:
|
||||
provider = Mem0ContextProvider(
|
||||
@@ -376,28 +401,66 @@ class TestBuildFilters:
|
||||
agent_id="a1",
|
||||
application_id="app1",
|
||||
)
|
||||
assert provider._build_filters() == {
|
||||
"user_id": "u1",
|
||||
"agent_id": "a1",
|
||||
"app_id": "app1",
|
||||
|
||||
# Test that app_id correctly merges with the isolated target entity
|
||||
result = provider._build_search_kwargs("test query", "agent_id", "a1")
|
||||
|
||||
assert result == {
|
||||
"query": "test query",
|
||||
"filters": {
|
||||
"agent_id": "a1",
|
||||
"app_id": "app1",
|
||||
},
|
||||
}
|
||||
|
||||
def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
|
||||
filters = provider._build_filters()
|
||||
assert "agent_id" not in filters
|
||||
assert "run_id" not in filters
|
||||
assert "app_id" not in filters
|
||||
|
||||
# application_id is None by default, it should not appear in the dictionary
|
||||
result = provider._build_search_kwargs("test query", "user_id", "u1")
|
||||
|
||||
assert "app_id" not in result.get("filters", {})
|
||||
|
||||
def test_no_run_id_in_search_filters(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""run_id is excluded from search filters so memories work across sessions."""
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
|
||||
filters = provider._build_filters()
|
||||
assert "run_id" not in filters
|
||||
|
||||
result = provider._build_search_kwargs("test query", "user_id", "u1")
|
||||
|
||||
assert "run_id" not in result.get("filters", {})
|
||||
assert "run_id" not in result
|
||||
|
||||
def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None:
|
||||
# Validates base query payload generation
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
|
||||
assert provider._build_filters() == {}
|
||||
|
||||
result = provider._build_search_kwargs("test query", "custom_key", "custom_val")
|
||||
|
||||
assert result == {"query": "test query", "filters": {"custom_key": "custom_val"}}
|
||||
|
||||
@pytest.mark.asyncio
async def test_before_run_application_only_fallback(self, mock_mem0_client: AsyncMock) -> None:
    """With only application_id configured, before_run issues one app-scoped search."""
    query_text = "Retrieve systemic fallback memory traces"
    provider = Mem0ContextProvider(
        source_id="mem0", mem0_client=mock_mem0_client, application_id="app_fallback_test"
    )

    # Minimal stand-in for the session context: a single input message and no response yet.
    mock_context = MagicMock(spec=SessionContext)
    mock_msg = MagicMock()
    mock_msg.text = query_text
    mock_context.input_messages = [mock_msg]
    mock_context.response = None

    mock_mem0_client.search = AsyncMock(return_value=[{"id": "m1", "memory": "System configuration template"}])

    await provider.before_run(
        agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
    )

    # Checking the call count alone would also pass for an unscoped search, so pin the
    # payload: exactly one call, with app_id carried in the nested filters dict that the
    # non-AsyncMemory (Platform-style) client path uses.
    mock_mem0_client.search.assert_called_once_with(
        query=query_text,
        filters={"app_id": "app_fallback_test"},
    )
    # The retrieved memory must be surfaced to the session context.
    mock_context.extend_messages.assert_called_once()
|
||||
|
||||
|
||||
# -- Context manager tests -----------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user