mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Update Mem0Provider to use v2 search API filters parameter (#2766)
* short fix to move id parameters to filters object * added tests * small fix * mem0 dependency update
This commit is contained in:
committed by
GitHub
Unverified
parent
754dfb2c9d
commit
54f482df73
@@ -63,7 +63,7 @@ def _check_openai_version_for_callable_api_key() -> None:
|
||||
raise ServiceInitializationError(
|
||||
f"Callable API keys require OpenAI SDK >= 1.106.0, but you have {openai.__version__}. "
|
||||
f"Please upgrade with 'pip install openai>=1.106.0' or provide a string API key instead. "
|
||||
f"Note: If you're using mem0ai, you may need to upgrade to mem0ai>=0.1.118 "
|
||||
f"Note: If you're using mem0ai, you may need to upgrade to mem0ai>=1.0.0 "
|
||||
f"to allow newer OpenAI versions."
|
||||
)
|
||||
except ServiceInitializationError:
|
||||
|
||||
@@ -150,11 +150,12 @@ class Mem0Provider(ContextProvider):
|
||||
if not input_text.strip():
|
||||
return Context(messages=None)
|
||||
|
||||
# Build filters from init parameters
|
||||
filters = self._build_filters()
|
||||
|
||||
search_response: MemorySearchResponse_v1_1 | MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
|
||||
query=input_text,
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
run_id=self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
# Depending on the API version, the response schema varies slightly
|
||||
@@ -185,6 +186,29 @@ class Mem0Provider(ContextProvider):
|
||||
"At least one of the filters: agent_id, user_id, application_id, or thread_id is required."
|
||||
)
|
||||
|
||||
def _build_filters(self) -> dict[str, Any]:
|
||||
"""Build search filters from initialization parameters.
|
||||
|
||||
Returns:
|
||||
Filter dictionary for mem0 v2 search API containing initialization parameters.
|
||||
In the v2 API, filters holds the user_id, agent_id, run_id (thread_id), and app_id
|
||||
(application_id) which are required for scoping memory search operations.
|
||||
"""
|
||||
filters: dict[str, Any] = {}
|
||||
|
||||
if self.user_id:
|
||||
filters["user_id"] = self.user_id
|
||||
if self.agent_id:
|
||||
filters["agent_id"] = self.agent_id
|
||||
if self.scope_to_per_operation_thread_id and self._per_operation_thread_id:
|
||||
filters["run_id"] = self._per_operation_thread_id
|
||||
elif self.thread_id:
|
||||
filters["run_id"] = self.thread_id
|
||||
if self.application_id:
|
||||
filters["app_id"] = self.application_id
|
||||
|
||||
return filters
|
||||
|
||||
def _validate_per_operation_thread_id(self, thread_id: str | None) -> None:
|
||||
"""Validates that a new thread ID doesn't conflict with an existing one when scoped.
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"agent-framework-core",
|
||||
"mem0ai>=0.1.117",
|
||||
"mem0ai>=1.0.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
||||
@@ -338,7 +338,7 @@ class TestMem0ProviderModelInvoking:
|
||||
mock_mem0_client.search.assert_called_once()
|
||||
call_args = mock_mem0_client.search.call_args
|
||||
assert call_args.kwargs["query"] == "What's the weather?"
|
||||
assert call_args.kwargs["user_id"] == "user123"
|
||||
assert call_args.kwargs["filters"] == {"user_id": "user123"}
|
||||
|
||||
assert isinstance(context, Context)
|
||||
expected_instructions = (
|
||||
@@ -373,8 +373,7 @@ class TestMem0ProviderModelInvoking:
|
||||
await provider.invoking(message)
|
||||
|
||||
call_args = mock_mem0_client.search.call_args
|
||||
assert call_args.kwargs["agent_id"] == "agent123"
|
||||
assert call_args.kwargs["user_id"] is None
|
||||
assert call_args.kwargs["filters"] == {"agent_id": "agent123"}
|
||||
|
||||
async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Test invoking with scope_to_per_operation_thread_id enabled."""
|
||||
@@ -392,7 +391,7 @@ class TestMem0ProviderModelInvoking:
|
||||
await provider.invoking(message)
|
||||
|
||||
call_args = mock_mem0_client.search.call_args
|
||||
assert call_args.kwargs["run_id"] == "operation_thread"
|
||||
assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "operation_thread"}
|
||||
|
||||
async def test_model_invoking_no_memories_returns_none_instructions(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Test that no memories returns context with None instructions."""
|
||||
@@ -510,3 +509,87 @@ class TestMem0ProviderValidation:
|
||||
|
||||
# Should not raise exception even with different thread ID
|
||||
provider._validate_per_operation_thread_id("different_thread")
|
||||
|
||||
|
||||
class TestMem0ProviderBuildFilters:
|
||||
"""Test the _build_filters method."""
|
||||
|
||||
def test_build_filters_with_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Test building filters with only user_id."""
|
||||
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
||||
|
||||
filters = provider._build_filters()
|
||||
assert filters == {"user_id": "user123"}
|
||||
|
||||
def test_build_filters_with_all_parameters(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Test building filters with all initialization parameters."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
agent_id="agent456",
|
||||
thread_id="thread789",
|
||||
application_id="app999",
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
|
||||
filters = provider._build_filters()
|
||||
assert filters == {
|
||||
"user_id": "user123",
|
||||
"agent_id": "agent456",
|
||||
"run_id": "thread789",
|
||||
"app_id": "app999",
|
||||
}
|
||||
|
||||
def test_build_filters_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Test that None values are excluded from filters."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
agent_id=None,
|
||||
thread_id=None,
|
||||
application_id=None,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
|
||||
filters = provider._build_filters()
|
||||
assert filters == {"user_id": "user123"}
|
||||
assert "agent_id" not in filters
|
||||
assert "run_id" not in filters
|
||||
assert "app_id" not in filters
|
||||
|
||||
def test_build_filters_with_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Test that per-operation thread ID takes precedence over base thread_id."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
thread_id="base_thread",
|
||||
scope_to_per_operation_thread_id=True,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
provider._per_operation_thread_id = "operation_thread"
|
||||
|
||||
filters = provider._build_filters()
|
||||
assert filters == {
|
||||
"user_id": "user123",
|
||||
"run_id": "operation_thread", # Per-operation thread, not base_thread
|
||||
}
|
||||
|
||||
def test_build_filters_uses_base_thread_when_no_per_operation(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Test that base thread_id is used when per-operation thread is not set."""
|
||||
provider = Mem0Provider(
|
||||
user_id="user123",
|
||||
thread_id="base_thread",
|
||||
scope_to_per_operation_thread_id=True,
|
||||
mem0_client=mock_mem0_client,
|
||||
)
|
||||
# _per_operation_thread_id is None
|
||||
|
||||
filters = provider._build_filters()
|
||||
assert filters == {
|
||||
"user_id": "user123",
|
||||
"run_id": "base_thread", # Falls back to base thread_id
|
||||
}
|
||||
|
||||
def test_build_filters_returns_empty_dict_when_no_parameters(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Test that _build_filters returns an empty dict when no parameters are set."""
|
||||
provider = Mem0Provider(mem0_client=mock_mem0_client)
|
||||
|
||||
filters = provider._build_filters()
|
||||
assert filters == {}
|
||||
|
||||
@@ -54,6 +54,13 @@ async def main() -> None:
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
# Mem0 processes and indexes memories asynchronously.
|
||||
# Wait for memories to be indexed before querying in a new thread.
|
||||
# In production, consider implementing retry logic or using Mem0's
|
||||
# eventual consistency handling instead of a fixed delay.
|
||||
print("Waiting for memories to be processed...")
|
||||
await asyncio.sleep(12) # Empirically determined delay for Mem0 indexing
|
||||
|
||||
print("\nRequest within a new thread:")
|
||||
# Create a new thread for the agent.
|
||||
# The new thread has no context of the previous conversation.
|
||||
|
||||
Generated
+6
-6
@@ -530,7 +530,7 @@ dependencies = [
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "agent-framework-core", editable = "packages/core" },
|
||||
{ name = "mem0ai", specifier = ">=0.1.117" },
|
||||
{ name = "mem0ai", specifier = ">=1.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1341,7 +1341,7 @@ name = "clr-loader"
|
||||
version = "0.2.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cffi", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "cffi", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/54/c2/da52aaf19424e3f0abec003d08dd1ccae52c88a3b41e31151a03bed18488/clr_loader-0.2.9.tar.gz", hash = "sha256:6af3d582c3de55ce9e9e676d2b3dbf6bc680c4ea8f76c58786739a5bdcf6b52d", size = 84829, upload-time = "2025-12-05T16:57:12.466Z" }
|
||||
wheels = [
|
||||
@@ -1820,7 +1820,7 @@ name = "exceptiongroup"
|
||||
version = "1.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions", marker = "(python_full_version < '3.13' and sys_platform == 'darwin') or (python_full_version < '3.13' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform == 'win32')" },
|
||||
{ name = "typing-extensions", marker = "(python_full_version < '3.11' and sys_platform == 'darwin') or (python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'win32')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
|
||||
wheels = [
|
||||
@@ -4473,8 +4473,8 @@ name = "powerfx"
|
||||
version = "0.0.33"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cffi", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "pythonnet", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "cffi", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" },
|
||||
{ name = "pythonnet", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5e/41/8f95f72f4f3b7ea54357c449bf5bd94813b6321dec31db9ffcbf578e2fa3/powerfx-0.0.33.tar.gz", hash = "sha256:85e8330bef8a7a207c3e010aa232df0ae38825e94d590c73daf3a3f44115cb09", size = 3236647, upload-time = "2025-11-20T19:31:09.414Z" }
|
||||
wheels = [
|
||||
@@ -5143,7 +5143,7 @@ name = "pythonnet"
|
||||
version = "3.0.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "clr-loader", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "clr-loader", marker = "(python_full_version < '3.14' and sys_platform == 'darwin') or (python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9a/d6/1afd75edd932306ae9bd2c2d961d603dc2b52fcec51b04afea464f1f6646/pythonnet-3.0.5.tar.gz", hash = "sha256:48e43ca463941b3608b32b4e236db92d8d40db4c58a75ace902985f76dac21cf", size = 239212, upload-time = "2024-12-13T08:30:44.393Z" }
|
||||
wheels = [
|
||||
|
||||
Reference in New Issue
Block a user