mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
83e6229c11
* ported Content to a new model * fixed linting * fixes * fixed data format handling * fix for 3.10 mypy * fix * fix int test
595 lines
24 KiB
Python
595 lines
24 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
# pyright: reportPrivateUsage=false
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from agent_framework import ChatMessage, Content, Context, Role
|
|
from agent_framework.exceptions import ServiceInitializationError
|
|
from agent_framework.mem0 import Mem0Provider
|
|
|
|
|
|
def test_mem0_provider_import() -> None:
|
|
"""Test that Mem0Provider can be imported."""
|
|
assert Mem0Provider is not None
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_mem0_client() -> AsyncMock:
|
|
"""Create a mock Mem0 AsyncMemoryClient."""
|
|
from mem0 import AsyncMemoryClient
|
|
|
|
mock_client = AsyncMock(spec=AsyncMemoryClient)
|
|
mock_client.add = AsyncMock()
|
|
mock_client.search = AsyncMock()
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock()
|
|
mock_client.async_client = AsyncMock()
|
|
mock_client.async_client.aclose = AsyncMock()
|
|
return mock_client
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_messages() -> list[ChatMessage]:
|
|
"""Create sample chat messages for testing."""
|
|
return [
|
|
ChatMessage(role=Role.USER, text="Hello, how are you?"),
|
|
ChatMessage(role=Role.ASSISTANT, text="I'm doing well, thank you!"),
|
|
ChatMessage(role=Role.SYSTEM, text="You are a helpful assistant"),
|
|
]
|
|
|
|
|
|
class TestMem0ProviderInitialization:
|
|
"""Test initialization and configuration of Mem0Provider."""
|
|
|
|
def test_init_with_all_ids(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test initialization with all IDs provided."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
agent_id="agent123",
|
|
application_id="app123",
|
|
thread_id="thread123",
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
assert provider.user_id == "user123"
|
|
assert provider.agent_id == "agent123"
|
|
assert provider.application_id == "app123"
|
|
assert provider.thread_id == "thread123"
|
|
|
|
def test_init_without_filters_succeeds(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that initialization succeeds even without filters (validation happens during invocation)."""
|
|
provider = Mem0Provider(mem0_client=mock_mem0_client)
|
|
assert provider.user_id is None
|
|
assert provider.agent_id is None
|
|
assert provider.application_id is None
|
|
assert provider.thread_id is None
|
|
|
|
def test_init_with_custom_context_prompt(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test initialization with custom context prompt."""
|
|
custom_prompt = "## Custom Memories\nConsider these memories:"
|
|
provider = Mem0Provider(user_id="user123", context_prompt=custom_prompt, mem0_client=mock_mem0_client)
|
|
assert provider.context_prompt == custom_prompt
|
|
|
|
def test_init_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test initialization with scope_to_per_operation_thread_id enabled."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
scope_to_per_operation_thread_id=True,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
assert provider.scope_to_per_operation_thread_id is True
|
|
|
|
@patch("agent_framework_mem0._provider.AsyncMemoryClient")
|
|
def test_init_creates_default_client_when_none_provided(self, mock_memory_client_class: AsyncMock) -> None:
|
|
"""Test that a default client is created when none is provided."""
|
|
from mem0 import AsyncMemoryClient
|
|
|
|
mock_client = AsyncMock(spec=AsyncMemoryClient)
|
|
mock_memory_client_class.return_value = mock_client
|
|
|
|
provider = Mem0Provider(user_id="user123", api_key="test_api_key")
|
|
|
|
mock_memory_client_class.assert_called_once_with(api_key="test_api_key")
|
|
assert provider.mem0_client == mock_client
|
|
assert provider._should_close_client is True
|
|
|
|
def test_init_with_provided_client_should_not_close(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that provided client should not be closed by provider."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
assert provider._should_close_client is False
|
|
|
|
|
|
class TestMem0ProviderAsyncContextManager:
|
|
"""Test async context manager behavior."""
|
|
|
|
async def test_async_context_manager_entry(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test async context manager entry returns self."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
async with provider as ctx:
|
|
assert ctx is provider
|
|
|
|
async def test_async_context_manager_exit_closes_client_when_should_close(self) -> None:
|
|
"""Test that async context manager closes client when it should."""
|
|
from mem0 import AsyncMemoryClient
|
|
|
|
mock_client = AsyncMock(spec=AsyncMemoryClient)
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock()
|
|
mock_client.async_client = AsyncMock()
|
|
mock_client.async_client.aclose = AsyncMock()
|
|
|
|
with patch("agent_framework_mem0._provider.AsyncMemoryClient", return_value=mock_client):
|
|
provider = Mem0Provider(user_id="user123", api_key="test_key")
|
|
assert provider._should_close_client is True
|
|
|
|
async with provider:
|
|
pass
|
|
|
|
mock_client.__aexit__.assert_called_once()
|
|
|
|
async def test_async_context_manager_exit_does_not_close_provided_client(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that async context manager does not close provided client."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
assert provider._should_close_client is False
|
|
|
|
async with provider:
|
|
pass
|
|
|
|
mock_mem0_client.__aexit__.assert_not_called()
|
|
|
|
|
|
class TestMem0ProviderThreadMethods:
|
|
"""Test thread lifecycle methods."""
|
|
|
|
async def test_thread_created_sets_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that thread_created sets per-operation thread ID."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
|
|
await provider.thread_created("thread123")
|
|
|
|
assert provider._per_operation_thread_id == "thread123"
|
|
|
|
async def test_thread_created_with_existing_thread_id(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test thread_created when thread ID already exists."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
provider._per_operation_thread_id = "existing_thread"
|
|
|
|
await provider.thread_created("thread123")
|
|
|
|
# Should not overwrite existing thread ID
|
|
assert provider._per_operation_thread_id == "existing_thread"
|
|
|
|
async def test_thread_created_validation_with_scope_enabled(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test thread_created validation when scope_to_per_operation_thread_id is enabled."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
scope_to_per_operation_thread_id=True,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
provider._per_operation_thread_id = "existing_thread"
|
|
|
|
with pytest.raises(ValueError) as exc_info:
|
|
await provider.thread_created("different_thread")
|
|
|
|
assert "can only be used with one thread at a time" in str(exc_info.value)
|
|
|
|
async def test_messages_adding_sets_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that invoked sets per-operation thread ID."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
|
|
await provider.thread_created("thread123")
|
|
|
|
assert provider._per_operation_thread_id == "thread123"
|
|
|
|
|
|
class TestMem0ProviderMessagesAdding:
|
|
"""Test invoked method."""
|
|
|
|
async def test_messages_adding_fails_without_filters(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that invoked fails when no filters are provided."""
|
|
provider = Mem0Provider(mem0_client=mock_mem0_client)
|
|
message = ChatMessage(role=Role.USER, text="Hello!")
|
|
|
|
with pytest.raises(ServiceInitializationError) as exc_info:
|
|
await provider.invoked(message)
|
|
|
|
assert "At least one of the filters" in str(exc_info.value)
|
|
|
|
async def test_messages_adding_single_message(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test adding a single message."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
message = ChatMessage(role=Role.USER, text="Hello!")
|
|
|
|
await provider.invoked(message)
|
|
|
|
mock_mem0_client.add.assert_called_once()
|
|
call_args = mock_mem0_client.add.call_args
|
|
assert call_args.kwargs["messages"] == [{"role": "user", "content": "Hello!"}]
|
|
assert call_args.kwargs["user_id"] == "user123"
|
|
|
|
async def test_messages_adding_multiple_messages(
|
|
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
|
) -> None:
|
|
"""Test adding multiple messages."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
|
|
await provider.invoked(sample_messages)
|
|
|
|
mock_mem0_client.add.assert_called_once()
|
|
call_args = mock_mem0_client.add.call_args
|
|
expected_messages = [
|
|
{"role": "user", "content": "Hello, how are you?"},
|
|
{"role": "assistant", "content": "I'm doing well, thank you!"},
|
|
{"role": "system", "content": "You are a helpful assistant"},
|
|
]
|
|
assert call_args.kwargs["messages"] == expected_messages
|
|
|
|
async def test_messages_adding_with_agent_id(
|
|
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
|
) -> None:
|
|
"""Test adding messages with agent_id."""
|
|
provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client)
|
|
|
|
await provider.invoked(sample_messages)
|
|
|
|
call_args = mock_mem0_client.add.call_args
|
|
assert call_args.kwargs["agent_id"] == "agent123"
|
|
assert call_args.kwargs["user_id"] is None
|
|
|
|
async def test_messages_adding_with_application_id(
|
|
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
|
) -> None:
|
|
"""Test adding messages with application_id in metadata."""
|
|
provider = Mem0Provider(user_id="user123", application_id="app123", mem0_client=mock_mem0_client)
|
|
|
|
await provider.invoked(sample_messages)
|
|
|
|
call_args = mock_mem0_client.add.call_args
|
|
assert call_args.kwargs["metadata"] == {"application_id": "app123"}
|
|
|
|
async def test_messages_adding_with_scope_to_per_operation_thread_id(
|
|
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
|
) -> None:
|
|
"""Test adding messages with scope_to_per_operation_thread_id enabled."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
thread_id="base_thread",
|
|
scope_to_per_operation_thread_id=True,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
provider._per_operation_thread_id = "operation_thread"
|
|
|
|
await provider.thread_created(thread_id="operation_thread")
|
|
await provider.invoked(sample_messages)
|
|
|
|
call_args = mock_mem0_client.add.call_args
|
|
assert call_args.kwargs["run_id"] == "operation_thread"
|
|
|
|
async def test_messages_adding_without_scope_uses_base_thread_id(
|
|
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
|
) -> None:
|
|
"""Test adding messages without scope uses base thread_id."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
thread_id="base_thread",
|
|
scope_to_per_operation_thread_id=False,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
|
|
await provider.invoked(sample_messages)
|
|
|
|
call_args = mock_mem0_client.add.call_args
|
|
assert call_args.kwargs["run_id"] == "base_thread"
|
|
|
|
async def test_messages_adding_filters_empty_messages(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that empty or invalid messages are filtered out."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
messages = [
|
|
ChatMessage(role=Role.USER, text=""), # Empty text
|
|
ChatMessage(role=Role.USER, text=" "), # Whitespace only
|
|
ChatMessage(role=Role.USER, text="Valid message"),
|
|
]
|
|
|
|
await provider.invoked(messages)
|
|
|
|
call_args = mock_mem0_client.add.call_args
|
|
# Should only include the valid message
|
|
assert call_args.kwargs["messages"] == [{"role": "user", "content": "Valid message"}]
|
|
|
|
async def test_messages_adding_skips_when_no_valid_messages(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that mem0 client is not called when no valid messages exist."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
messages = [
|
|
ChatMessage(role=Role.USER, text=""),
|
|
ChatMessage(role=Role.USER, text=" "),
|
|
]
|
|
|
|
await provider.invoked(messages)
|
|
|
|
mock_mem0_client.add.assert_not_called()
|
|
|
|
|
|
class TestMem0ProviderModelInvoking:
|
|
"""Test invoking method."""
|
|
|
|
async def test_model_invoking_fails_without_filters(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that invoking fails when no filters are provided."""
|
|
provider = Mem0Provider(mem0_client=mock_mem0_client)
|
|
message = ChatMessage(role=Role.USER, text="What's the weather?")
|
|
|
|
with pytest.raises(ServiceInitializationError) as exc_info:
|
|
await provider.invoking(message)
|
|
|
|
assert "At least one of the filters" in str(exc_info.value)
|
|
|
|
async def test_model_invoking_single_message(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test invoking with a single message."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
message = ChatMessage(role=Role.USER, text="What's the weather?")
|
|
|
|
# Mock search results
|
|
mock_mem0_client.search.return_value = [
|
|
{"memory": "User likes outdoor activities"},
|
|
{"memory": "User lives in Seattle"},
|
|
]
|
|
|
|
context = await provider.invoking(message)
|
|
|
|
mock_mem0_client.search.assert_called_once()
|
|
call_args = mock_mem0_client.search.call_args
|
|
assert call_args.kwargs["query"] == "What's the weather?"
|
|
assert call_args.kwargs["filters"] == {"user_id": "user123"}
|
|
|
|
assert isinstance(context, Context)
|
|
expected_instructions = (
|
|
"## Memories\nConsider the following memories when answering user questions:\n"
|
|
"User likes outdoor activities\nUser lives in Seattle"
|
|
)
|
|
|
|
assert context.messages
|
|
assert context.messages[0].text == expected_instructions
|
|
|
|
async def test_model_invoking_multiple_messages(
|
|
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
|
|
) -> None:
|
|
"""Test invoking with multiple messages."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
|
|
mock_mem0_client.search.return_value = [{"memory": "Previous conversation context"}]
|
|
|
|
await provider.invoking(sample_messages)
|
|
|
|
call_args = mock_mem0_client.search.call_args
|
|
expected_query = "Hello, how are you?\nI'm doing well, thank you!\nYou are a helpful assistant"
|
|
assert call_args.kwargs["query"] == expected_query
|
|
|
|
async def test_model_invoking_with_agent_id(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test invoking with agent_id."""
|
|
provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client)
|
|
message = ChatMessage(role=Role.USER, text="Hello")
|
|
|
|
mock_mem0_client.search.return_value = []
|
|
|
|
await provider.invoking(message)
|
|
|
|
call_args = mock_mem0_client.search.call_args
|
|
assert call_args.kwargs["filters"] == {"agent_id": "agent123"}
|
|
|
|
async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test invoking with scope_to_per_operation_thread_id enabled."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
thread_id="base_thread",
|
|
scope_to_per_operation_thread_id=True,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
provider._per_operation_thread_id = "operation_thread"
|
|
message = ChatMessage(role=Role.USER, text="Hello")
|
|
|
|
mock_mem0_client.search.return_value = []
|
|
|
|
await provider.invoking(message)
|
|
|
|
call_args = mock_mem0_client.search.call_args
|
|
assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "operation_thread"}
|
|
|
|
async def test_model_invoking_no_memories_returns_none_instructions(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that no memories returns context with None instructions."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
message = ChatMessage(role=Role.USER, text="Hello")
|
|
|
|
mock_mem0_client.search.return_value = []
|
|
|
|
context = await provider.invoking(message)
|
|
|
|
assert isinstance(context, Context)
|
|
assert not context.messages
|
|
|
|
async def test_model_invoking_function_approval_response_returns_none_instructions(
|
|
self, mock_mem0_client: AsyncMock
|
|
) -> None:
|
|
"""Test invoking with function approval response content messages returns context with None instructions."""
|
|
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
function_call = Content.from_function_call(call_id="1", name="test_func", arguments='{"arg1": "value1"}')
|
|
message = ChatMessage(
|
|
role=Role.USER,
|
|
contents=[
|
|
Content.from_function_approval_response(
|
|
id="approval_1",
|
|
function_call=function_call,
|
|
approved=True,
|
|
)
|
|
],
|
|
)
|
|
|
|
mock_mem0_client.search.return_value = []
|
|
|
|
context = await provider.invoking(message)
|
|
|
|
assert isinstance(context, Context)
|
|
assert not context.messages
|
|
|
|
async def test_model_invoking_filters_empty_message_text(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that empty message text is filtered out from query."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
messages = [
|
|
ChatMessage(role=Role.USER, text=""),
|
|
ChatMessage(role=Role.USER, text="Valid message"),
|
|
ChatMessage(role=Role.USER, text=" "),
|
|
]
|
|
|
|
mock_mem0_client.search.return_value = []
|
|
|
|
await provider.invoking(messages)
|
|
|
|
call_args = mock_mem0_client.search.call_args
|
|
assert call_args.kwargs["query"] == "Valid message"
|
|
|
|
async def test_model_invoking_custom_context_prompt(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test invoking with custom context prompt."""
|
|
custom_prompt = "## Custom Context\nRemember these details:"
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
context_prompt=custom_prompt,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
message = ChatMessage(role=Role.USER, text="Hello")
|
|
|
|
mock_mem0_client.search.return_value = [{"memory": "Test memory"}]
|
|
|
|
context = await provider.invoking(message)
|
|
|
|
expected_instructions = "## Custom Context\nRemember these details:\nTest memory"
|
|
assert context.messages
|
|
assert context.messages[0].text == expected_instructions
|
|
|
|
|
|
class TestMem0ProviderValidation:
|
|
"""Test validation methods."""
|
|
|
|
def test_validate_per_operation_thread_id_success(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test successful validation of per-operation thread ID."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
scope_to_per_operation_thread_id=True,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
provider._per_operation_thread_id = "thread123"
|
|
|
|
# Should not raise exception for same thread ID
|
|
provider._validate_per_operation_thread_id("thread123")
|
|
|
|
# Should not raise exception for None
|
|
provider._validate_per_operation_thread_id(None)
|
|
|
|
def test_validate_per_operation_thread_id_failure(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test validation failure for conflicting thread IDs."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
scope_to_per_operation_thread_id=True,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
provider._per_operation_thread_id = "thread123"
|
|
|
|
with pytest.raises(ValueError) as exc_info:
|
|
provider._validate_per_operation_thread_id("different_thread")
|
|
|
|
assert "can only be used with one thread at a time" in str(exc_info.value)
|
|
|
|
def test_validate_per_operation_thread_id_disabled_scope(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that validation is skipped when scope is disabled."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
scope_to_per_operation_thread_id=False,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
provider._per_operation_thread_id = "thread123"
|
|
|
|
# Should not raise exception even with different thread ID
|
|
provider._validate_per_operation_thread_id("different_thread")
|
|
|
|
|
|
class TestMem0ProviderBuildFilters:
|
|
"""Test the _build_filters method."""
|
|
|
|
def test_build_filters_with_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test building filters with only user_id."""
|
|
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
|
|
|
|
filters = provider._build_filters()
|
|
assert filters == {"user_id": "user123"}
|
|
|
|
def test_build_filters_with_all_parameters(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test building filters with all initialization parameters."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
agent_id="agent456",
|
|
thread_id="thread789",
|
|
application_id="app999",
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
|
|
filters = provider._build_filters()
|
|
assert filters == {
|
|
"user_id": "user123",
|
|
"agent_id": "agent456",
|
|
"run_id": "thread789",
|
|
"app_id": "app999",
|
|
}
|
|
|
|
def test_build_filters_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that None values are excluded from filters."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
agent_id=None,
|
|
thread_id=None,
|
|
application_id=None,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
|
|
filters = provider._build_filters()
|
|
assert filters == {"user_id": "user123"}
|
|
assert "agent_id" not in filters
|
|
assert "run_id" not in filters
|
|
assert "app_id" not in filters
|
|
|
|
def test_build_filters_with_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that per-operation thread ID takes precedence over base thread_id."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
thread_id="base_thread",
|
|
scope_to_per_operation_thread_id=True,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
provider._per_operation_thread_id = "operation_thread"
|
|
|
|
filters = provider._build_filters()
|
|
assert filters == {
|
|
"user_id": "user123",
|
|
"run_id": "operation_thread", # Per-operation thread, not base_thread
|
|
}
|
|
|
|
def test_build_filters_uses_base_thread_when_no_per_operation(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that base thread_id is used when per-operation thread is not set."""
|
|
provider = Mem0Provider(
|
|
user_id="user123",
|
|
thread_id="base_thread",
|
|
scope_to_per_operation_thread_id=True,
|
|
mem0_client=mock_mem0_client,
|
|
)
|
|
# _per_operation_thread_id is None
|
|
|
|
filters = provider._build_filters()
|
|
assert filters == {
|
|
"user_id": "user123",
|
|
"run_id": "base_thread", # Falls back to base thread_id
|
|
}
|
|
|
|
def test_build_filters_returns_empty_dict_when_no_parameters(self, mock_mem0_client: AsyncMock) -> None:
|
|
"""Test that _build_filters returns an empty dict when no parameters are set."""
|
|
provider = Mem0Provider(mem0_client=mock_mem0_client)
|
|
|
|
filters = provider._build_filters()
|
|
assert filters == {}
|