Python: Update Mem0Provider to use v2 search API filters parameter (#2766)

* short fix to move id parameters to filters object

* added tests

* small fix

* mem0 dependency update
This commit is contained in:
Giles Odigwe
2025-12-16 10:23:37 -08:00
committed by GitHub
Unverified
parent 754dfb2c9d
commit 54f482df73
6 changed files with 129 additions and 15 deletions
@@ -63,7 +63,7 @@ def _check_openai_version_for_callable_api_key() -> None:
raise ServiceInitializationError(
f"Callable API keys require OpenAI SDK >= 1.106.0, but you have {openai.__version__}. "
f"Please upgrade with 'pip install openai>=1.106.0' or provide a string API key instead. "
f"Note: If you're using mem0ai, you may need to upgrade to mem0ai>=0.1.118 "
f"Note: If you're using mem0ai, you may need to upgrade to mem0ai>=1.0.0 "
f"to allow newer OpenAI versions."
)
except ServiceInitializationError:
@@ -150,11 +150,12 @@ class Mem0Provider(ContextProvider):
if not input_text.strip():
return Context(messages=None)
# Build filters from init parameters
filters = self._build_filters()
search_response: MemorySearchResponse_v1_1 | MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
query=input_text,
user_id=self.user_id,
agent_id=self.agent_id,
run_id=self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id,
filters=filters,
)
# Depending on the API version, the response schema varies slightly
@@ -185,6 +186,29 @@ class Mem0Provider(ContextProvider):
"At least one of the filters: agent_id, user_id, application_id, or thread_id is required."
)
def _build_filters(self) -> dict[str, Any]:
"""Build search filters from initialization parameters.
Returns:
Filter dictionary for mem0 v2 search API containing initialization parameters.
In the v2 API, filters holds the user_id, agent_id, run_id (thread_id), and app_id
(application_id) which are required for scoping memory search operations.
"""
filters: dict[str, Any] = {}
if self.user_id:
filters["user_id"] = self.user_id
if self.agent_id:
filters["agent_id"] = self.agent_id
if self.scope_to_per_operation_thread_id and self._per_operation_thread_id:
filters["run_id"] = self._per_operation_thread_id
elif self.thread_id:
filters["run_id"] = self.thread_id
if self.application_id:
filters["app_id"] = self.application_id
return filters
def _validate_per_operation_thread_id(self, thread_id: str | None) -> None:
"""Validates that a new thread ID doesn't conflict with an existing one when scoped.
+1 -1
View File
@@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"agent-framework-core",
"mem0ai>=0.1.117",
"mem0ai>=1.0.0",
]
[tool.uv]
@@ -338,7 +338,7 @@ class TestMem0ProviderModelInvoking:
mock_mem0_client.search.assert_called_once()
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["query"] == "What's the weather?"
assert call_args.kwargs["user_id"] == "user123"
assert call_args.kwargs["filters"] == {"user_id": "user123"}
assert isinstance(context, Context)
expected_instructions = (
@@ -373,8 +373,7 @@ class TestMem0ProviderModelInvoking:
await provider.invoking(message)
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["agent_id"] == "agent123"
assert call_args.kwargs["user_id"] is None
assert call_args.kwargs["filters"] == {"agent_id": "agent123"}
async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
"""Test invoking with scope_to_per_operation_thread_id enabled."""
@@ -392,7 +391,7 @@ class TestMem0ProviderModelInvoking:
await provider.invoking(message)
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["run_id"] == "operation_thread"
assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "operation_thread"}
async def test_model_invoking_no_memories_returns_none_instructions(self, mock_mem0_client: AsyncMock) -> None:
"""Test that no memories returns context with None instructions."""
@@ -510,3 +509,87 @@ class TestMem0ProviderValidation:
# Should not raise exception even with different thread ID
provider._validate_per_operation_thread_id("different_thread")
class TestMem0ProviderBuildFilters:
"""Test the _build_filters method."""
def test_build_filters_with_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
"""Test building filters with only user_id."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
filters = provider._build_filters()
assert filters == {"user_id": "user123"}
def test_build_filters_with_all_parameters(self, mock_mem0_client: AsyncMock) -> None:
"""Test building filters with all initialization parameters."""
provider = Mem0Provider(
user_id="user123",
agent_id="agent456",
thread_id="thread789",
application_id="app999",
mem0_client=mock_mem0_client,
)
filters = provider._build_filters()
assert filters == {
"user_id": "user123",
"agent_id": "agent456",
"run_id": "thread789",
"app_id": "app999",
}
def test_build_filters_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
"""Test that None values are excluded from filters."""
provider = Mem0Provider(
user_id="user123",
agent_id=None,
thread_id=None,
application_id=None,
mem0_client=mock_mem0_client,
)
filters = provider._build_filters()
assert filters == {"user_id": "user123"}
assert "agent_id" not in filters
assert "run_id" not in filters
assert "app_id" not in filters
def test_build_filters_with_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
"""Test that per-operation thread ID takes precedence over base thread_id."""
provider = Mem0Provider(
user_id="user123",
thread_id="base_thread",
scope_to_per_operation_thread_id=True,
mem0_client=mock_mem0_client,
)
provider._per_operation_thread_id = "operation_thread"
filters = provider._build_filters()
assert filters == {
"user_id": "user123",
"run_id": "operation_thread", # Per-operation thread, not base_thread
}
def test_build_filters_uses_base_thread_when_no_per_operation(self, mock_mem0_client: AsyncMock) -> None:
"""Test that base thread_id is used when per-operation thread is not set."""
provider = Mem0Provider(
user_id="user123",
thread_id="base_thread",
scope_to_per_operation_thread_id=True,
mem0_client=mock_mem0_client,
)
# _per_operation_thread_id is None
filters = provider._build_filters()
assert filters == {
"user_id": "user123",
"run_id": "base_thread", # Falls back to base thread_id
}
def test_build_filters_returns_empty_dict_when_no_parameters(self, mock_mem0_client: AsyncMock) -> None:
"""Test that _build_filters returns an empty dict when no parameters are set."""
provider = Mem0Provider(mem0_client=mock_mem0_client)
filters = provider._build_filters()
assert filters == {}
@@ -54,6 +54,13 @@ async def main() -> None:
result = await agent.run(query)
print(f"Agent: {result}\n")
# Mem0 processes and indexes memories asynchronously.
# Wait for memories to be indexed before querying in a new thread.
# In production, consider implementing retry logic or using Mem0's
# eventual consistency handling instead of a fixed delay.
print("Waiting for memories to be processed...")
await asyncio.sleep(12) # Empirically determined delay for Mem0 indexing
print("\nRequest within a new thread:")
# Create a new thread for the agent.
# The new thread has no context of the previous conversation.
+6 -6
View File
@@ -530,7 +530,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "agent-framework-core", editable = "packages/core" },
{ name = "mem0ai", specifier = ">=0.1.117" },
{ name = "mem0ai", specifier = ">=1.0.0" },
]
[[package]]
@@ -1341,7 +1341,7 @@ name = "clr-loader"
version = "0.2.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cffi", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "cffi", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/54/c2/da52aaf19424e3f0abec003d08dd1ccae52c88a3b41e31151a03bed18488/clr_loader-0.2.9.tar.gz", hash = "sha256:6af3d582c3de55ce9e9e676d2b3dbf6bc680c4ea8f76c58786739a5bdcf6b52d", size = 84829, upload-time = "2025-12-05T16:57:12.466Z" }
wheels = [
@@ -1820,7 +1820,7 @@ name = "exceptiongroup"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "(python_full_version < '3.13' and sys_platform == 'darwin') or (python_full_version < '3.13' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform == 'win32')" },
{ name = "typing-extensions", marker = "(python_full_version < '3.11' and sys_platform == 'darwin') or (python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'win32')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
wheels = [
@@ -4473,8 +4473,8 @@ name = "powerfx"
version = "0.0.33"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cffi", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "pythonnet", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "cffi", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" },
{ name = "pythonnet", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5e/41/8f95f72f4f3b7ea54357c449bf5bd94813b6321dec31db9ffcbf578e2fa3/powerfx-0.0.33.tar.gz", hash = "sha256:85e8330bef8a7a207c3e010aa232df0ae38825e94d590c73daf3a3f44115cb09", size = 3236647, upload-time = "2025-11-20T19:31:09.414Z" }
wheels = [
@@ -5143,7 +5143,7 @@ name = "pythonnet"
version = "3.0.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "clr-loader", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "clr-loader", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9a/d6/1afd75edd932306ae9bd2c2d961d603dc2b52fcec51b04afea464f1f6646/pythonnet-3.0.5.tar.gz", hash = "sha256:48e43ca463941b3608b32b4e236db92d8d40db4c58a75ace902985f76dac21cf", size = 239212, upload-time = "2024-12-13T08:30:44.393Z" }
wheels = [