Files
agent-framework/python/packages/mem0/tests/test_mem0_context_provider.py
Eduard van Valkenburg 0521f5bed8 Python: [BREAKING] Simplify API: ChatAgent -> Agent, ChatMessage -> Message (#3747)
* [BREAKING] Rename ChatAgent -> Agent, ChatMessage -> Message, ChatClientProtocol -> SupportsChatGetResponse

Simplify the public API by removing redundant 'Chat' prefix from core types:
- ChatAgent -> Agent
- RawChatAgent -> RawAgent
- ChatMessage -> Message
- ChatClientProtocol -> SupportsChatGetResponse

Also renamed internal WorkflowMessage (was Message in _runner_context) to avoid collision.

No backward compatibility aliases - this is a clean breaking change.

* [BREAKING] Rename Agent chat_client parameter to client

* Fix rebase issues: WorkflowMessage references and broken markdown links

* Fix formatting and lint issues from code quality checks

* Fix import ordering in workflow sample files

* fixed rebase

* Fix test failures: use WorkflowMessage and A2AMessage after ChatMessage→Message rename

- Replace Message(data=..., source_id=...) with WorkflowMessage(...) in workflow tests
- Fix isinstance check in A2A agent to use A2AMessage instead of Message
- Fix import in test_workflow_observability.py (Message→WorkflowMessage)

* Fix lint, fmt, and sample errors after ChatMessage→Message rename

- Auto-fix 70+ ruff lint issues across samples (ChatMessage→Message refs)
- Fix HostedVectorStoreContent→Content.from_hosted_vector_store in file search sample
- Fix _normalize_messages→normalize_messages in custom agent sample
- Fix context.terminate→raise MiddlewareTermination in middleware samples
- Fix with_update_hook→with_transform_hook in override middleware sample
- Add TOptions_co import back to custom_chat_client sample
- Add noqa for FastAPI File() default in chatkit sample
- Fix B023 loop variable capture in weather agent sample

* fix: update Agent constructor calls from chat_client to client in declaration-only tool tests

* fix: add register_cleanup to devui lazy-loading proxy and type stub

* fixed tests and updated new pieces

* fix agui typevar

* fix merge errors

* fix merge conflicts

* fix merge

* Remove unused links

---------

Co-authored-by: Evan Mattson <evan.mattson@microsoft.com>
2026-02-10 23:04:32 +00:00

604 lines
24 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
# pyright: reportPrivateUsage=false
import importlib
import os
import sys
from unittest.mock import AsyncMock
import pytest
from agent_framework import Content, Context, Message
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.mem0 import Mem0Provider
def test_mem0_provider_import() -> None:
    """Check that the ``agent_framework.mem0`` entry point exposes ``Mem0Provider``."""
    imported_provider = Mem0Provider
    assert imported_provider is not None


@pytest.fixture
def mock_mem0_client() -> AsyncMock:
    """Build an ``AsyncMock`` standing in for mem0's ``AsyncMemoryClient``.

    The mock is spec'd against the real client class. It exposes awaitable ``add``/``search``,
    supports ``async with`` (entering yields the mock itself), and carries an
    ``async_client`` with an awaitable ``aclose``.
    """
    from mem0 import AsyncMemoryClient

    client = AsyncMock(spec=AsyncMemoryClient)
    # Explicit AsyncMocks so individual tests can set return values and inspect calls.
    client.add = AsyncMock()
    client.search = AsyncMock()
    # Async context-manager protocol: entering hands back the same mock.
    client.__aenter__ = AsyncMock(return_value=client)
    client.__aexit__ = AsyncMock()
    # NOTE(review): presumably the path a provider uses to close an owned client — confirm.
    client.async_client = AsyncMock()
    client.async_client.aclose = AsyncMock()
    return client


@pytest.fixture
def sample_messages() -> list[Message]:
    """Provide a short user/assistant/system conversation shared by several tests."""
    greeting = Message(role="user", text="Hello, how are you?")
    reply = Message(role="assistant", text="I'm doing well, thank you!")
    system_prompt = Message(role="system", text="You are a helpful assistant")
    return [greeting, reply, system_prompt]


def test_init_with_all_ids(mock_mem0_client: AsyncMock) -> None:
    """Every scoping ID given to the constructor is exposed unchanged as an attribute."""
    provider = Mem0Provider(
        user_id="user123",
        agent_id="agent123",
        application_id="app123",
        thread_id="thread123",
        mem0_client=mock_mem0_client,
    )

    observed = (provider.user_id, provider.agent_id, provider.application_id, provider.thread_id)
    assert observed == ("user123", "agent123", "app123", "thread123")


def test_init_without_filters_succeeds(mock_mem0_client: AsyncMock) -> None:
    """Construction with no scoping IDs is allowed; filter validation is deferred to invocation."""
    provider = Mem0Provider(mem0_client=mock_mem0_client)

    for unset_id in (provider.user_id, provider.agent_id, provider.application_id, provider.thread_id):
        assert unset_id is None


def test_init_with_custom_context_prompt(mock_mem0_client: AsyncMock) -> None:
    """A caller-supplied context prompt is stored verbatim on the provider."""
    prompt = "## Custom Memories\nConsider these memories:"
    provider = Mem0Provider(
        user_id="user123",
        context_prompt=prompt,
        mem0_client=mock_mem0_client,
    )
    assert provider.context_prompt == prompt


def test_init_with_scope_to_per_operation_thread_id(mock_mem0_client: AsyncMock) -> None:
    """The ``scope_to_per_operation_thread_id`` flag is retained as ``True`` when requested."""
    provider = Mem0Provider(user_id="user123", scope_to_per_operation_thread_id=True, mem0_client=mock_mem0_client)
    assert provider.scope_to_per_operation_thread_id is True


def test_init_with_provided_client_should_not_close(mock_mem0_client: AsyncMock) -> None:
    """A client handed in by the caller is not flagged for closing by the provider."""
    provider = Mem0Provider(mem0_client=mock_mem0_client, user_id="user123")
    should_close = provider._should_close_client
    assert should_close is False


async def test_async_context_manager_entry(mock_mem0_client: AsyncMock) -> None:
    """``async with`` on the provider yields the provider instance itself."""
    provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
    async with provider as entered:
        assert entered is provider


async def test_async_context_manager_exit_does_not_close_provided_client(mock_mem0_client: AsyncMock) -> None:
    """Test that async context manager does not close provided client.

    A client passed in via ``mem0_client`` has ``_should_close_client`` set to False. Leaving the
    ``async with`` block must therefore neither exit the client's own context manager nor
    call ``aclose`` on its ``async_client``.
    """
    provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
    assert provider._should_close_client is False

    async with provider:
        pass

    mock_mem0_client.__aexit__.assert_not_called()
    # The fixture wires up async_client.aclose(); make sure that close path was not taken either.
    mock_mem0_client.async_client.aclose.assert_not_called()


class TestMem0ProviderThreadMethods:
    """Test thread lifecycle methods."""

    async def test_thread_created_sets_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
        """Test that thread_created sets per-operation thread ID."""
        provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)

        await provider.thread_created("thread123")

        assert provider._per_operation_thread_id == "thread123"

    async def test_thread_created_with_existing_thread_id(self, mock_mem0_client: AsyncMock) -> None:
        """Test thread_created when thread ID already exists."""
        provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
        provider._per_operation_thread_id = "existing_thread"

        await provider.thread_created("thread123")

        # Should not overwrite existing thread ID
        assert provider._per_operation_thread_id == "existing_thread"

    async def test_thread_created_validation_with_scope_enabled(self, mock_mem0_client: AsyncMock) -> None:
        """Test thread_created validation when scope_to_per_operation_thread_id is enabled."""
        provider = Mem0Provider(
            user_id="user123",
            scope_to_per_operation_thread_id=True,
            mem0_client=mock_mem0_client,
        )
        provider._per_operation_thread_id = "existing_thread"

        with pytest.raises(ValueError) as exc_info:
            await provider.thread_created("different_thread")

        assert "can only be used with one thread at a time" in str(exc_info.value)

    async def test_messages_adding_sets_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
        """Test that the thread ID recorded by thread_created is used as run_id when invoked stores messages.

        This previously duplicated ``test_thread_created_sets_per_operation_thread_id``. It now also
        exercises ``invoked`` so the recorded thread ID is checked end to end.
        """
        provider = Mem0Provider(
            user_id="user123",
            scope_to_per_operation_thread_id=True,
            mem0_client=mock_mem0_client,
        )

        await provider.thread_created("thread123")
        assert provider._per_operation_thread_id == "thread123"

        await provider.invoked(Message(role="user", text="Hello!"))

        mock_mem0_client.add.assert_called_once()
        # With scoping enabled, stored memories are tagged with the per-operation thread.
        assert mock_mem0_client.add.call_args.kwargs["run_id"] == "thread123"


class TestMem0ProviderMessagesAdding:
    """Test the ``invoked`` method, which forwards conversation messages to ``mem0_client.add``."""

    async def test_messages_adding_fails_without_filters(self, mock_mem0_client: AsyncMock) -> None:
        """Test that invoked fails when no filters are provided."""
        provider = Mem0Provider(mem0_client=mock_mem0_client)
        message = Message(role="user", text="Hello!")

        with pytest.raises(ServiceInitializationError) as exc_info:
            await provider.invoked(message)

        assert "At least one of the filters" in str(exc_info.value)

    async def test_messages_adding_single_message(self, mock_mem0_client: AsyncMock) -> None:
        """Test adding a single message."""
        provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
        message = Message(role="user", text="Hello!")

        await provider.invoked(message)

        mock_mem0_client.add.assert_called_once()
        call_args = mock_mem0_client.add.call_args
        # Messages are sent to mem0 as plain role/content dicts.
        assert call_args.kwargs["messages"] == [{"role": "user", "content": "Hello!"}]
        assert call_args.kwargs["user_id"] == "user123"

    async def test_messages_adding_multiple_messages(
        self, mock_mem0_client: AsyncMock, sample_messages: list[Message]
    ) -> None:
        """Test adding multiple messages."""
        provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)

        await provider.invoked(sample_messages)

        mock_mem0_client.add.assert_called_once()
        call_args = mock_mem0_client.add.call_args
        expected_messages = [
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing well, thank you!"},
            {"role": "system", "content": "You are a helpful assistant"},
        ]
        assert call_args.kwargs["messages"] == expected_messages

    async def test_messages_adding_with_agent_id(
        self, mock_mem0_client: AsyncMock, sample_messages: list[Message]
    ) -> None:
        """Test adding messages with agent_id."""
        provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client)

        await provider.invoked(sample_messages)

        call_args = mock_mem0_client.add.call_args
        assert call_args.kwargs["agent_id"] == "agent123"
        assert call_args.kwargs["user_id"] is None

    async def test_messages_adding_with_application_id(
        self, mock_mem0_client: AsyncMock, sample_messages: list[Message]
    ) -> None:
        """Test adding messages with application_id in metadata."""
        provider = Mem0Provider(user_id="user123", application_id="app123", mem0_client=mock_mem0_client)

        await provider.invoked(sample_messages)

        call_args = mock_mem0_client.add.call_args
        # application_id travels in metadata rather than as a top-level kwarg.
        assert call_args.kwargs["metadata"] == {"application_id": "app123"}

    async def test_messages_adding_with_scope_to_per_operation_thread_id(
        self, mock_mem0_client: AsyncMock, sample_messages: list[Message]
    ) -> None:
        """Test adding messages with scope_to_per_operation_thread_id enabled.

        The per-operation thread ID is established only through the public ``thread_created``
        call, not by assigning the private attribute. That way the test fails if ``thread_created``
        stops scoping ``run_id``.
        """
        provider = Mem0Provider(
            user_id="user123",
            thread_id="base_thread",
            scope_to_per_operation_thread_id=True,
            mem0_client=mock_mem0_client,
        )
        await provider.thread_created(thread_id="operation_thread")

        await provider.invoked(sample_messages)

        call_args = mock_mem0_client.add.call_args
        # The per-operation thread wins over the base thread_id when scoping is enabled.
        assert call_args.kwargs["run_id"] == "operation_thread"

    async def test_messages_adding_without_scope_uses_base_thread_id(
        self, mock_mem0_client: AsyncMock, sample_messages: list[Message]
    ) -> None:
        """Test adding messages without scope uses base thread_id."""
        provider = Mem0Provider(
            user_id="user123",
            thread_id="base_thread",
            scope_to_per_operation_thread_id=False,
            mem0_client=mock_mem0_client,
        )

        await provider.invoked(sample_messages)

        call_args = mock_mem0_client.add.call_args
        assert call_args.kwargs["run_id"] == "base_thread"

    async def test_messages_adding_filters_empty_messages(self, mock_mem0_client: AsyncMock) -> None:
        """Test that empty or whitespace-only messages are filtered out."""
        provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
        messages = [
            Message(role="user", text=""),  # Empty text
            Message(role="user", text="   "),  # Whitespace only
            Message(role="user", text="Valid message"),
        ]

        await provider.invoked(messages)

        call_args = mock_mem0_client.add.call_args
        # Should only include the valid message
        assert call_args.kwargs["messages"] == [{"role": "user", "content": "Valid message"}]

    async def test_messages_adding_skips_when_no_valid_messages(self, mock_mem0_client: AsyncMock) -> None:
        """Test that mem0 client is not called when no valid messages exist."""
        provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
        messages = [
            Message(role="user", text=""),
            Message(role="user", text="   "),
        ]

        await provider.invoked(messages)

        mock_mem0_client.add.assert_not_called()


class TestMem0ProviderModelInvoking:
    """Test the ``invoking`` method, which queries ``mem0_client.search`` and returns memories as context."""

    async def test_model_invoking_fails_without_filters(self, mock_mem0_client: AsyncMock) -> None:
        """Test that invoking fails when no filters are provided."""
        provider = Mem0Provider(mem0_client=mock_mem0_client)
        message = Message(role="user", text="What's the weather?")

        with pytest.raises(ServiceInitializationError) as exc_info:
            await provider.invoking(message)

        assert "At least one of the filters" in str(exc_info.value)

    async def test_model_invoking_single_message(self, mock_mem0_client: AsyncMock) -> None:
        """Test invoking with a single message: its text is the query and the memories become context."""
        provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
        message = Message(role="user", text="What's the weather?")
        # Mock search results
        mock_mem0_client.search.return_value = [
            {"memory": "User likes outdoor activities"},
            {"memory": "User lives in Seattle"},
        ]

        context = await provider.invoking(message)

        mock_mem0_client.search.assert_called_once()
        call_args = mock_mem0_client.search.call_args
        assert call_args.kwargs["query"] == "What's the weather?"
        assert call_args.kwargs["filters"] == {"user_id": "user123"}

        assert isinstance(context, Context)
        # Default prompt header, then one memory per line.
        expected_instructions = (
            "## Memories\nConsider the following memories when answering user questions:\n"
            "User likes outdoor activities\nUser lives in Seattle"
        )
        assert context.messages
        assert context.messages[0].text == expected_instructions

    async def test_model_invoking_multiple_messages(
        self, mock_mem0_client: AsyncMock, sample_messages: list[Message]
    ) -> None:
        """Test invoking with multiple messages: their texts are joined with newlines into one query."""
        provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
        mock_mem0_client.search.return_value = [{"memory": "Previous conversation context"}]

        await provider.invoking(sample_messages)

        call_args = mock_mem0_client.search.call_args
        expected_query = "Hello, how are you?\nI'm doing well, thank you!\nYou are a helpful assistant"
        assert call_args.kwargs["query"] == expected_query

    async def test_model_invoking_with_agent_id(self, mock_mem0_client: AsyncMock) -> None:
        """Test invoking with agent_id."""
        provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client)
        message = Message(role="user", text="Hello")
        mock_mem0_client.search.return_value = []

        await provider.invoking(message)

        call_args = mock_mem0_client.search.call_args
        assert call_args.kwargs["filters"] == {"agent_id": "agent123"}

    async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
        """Test invoking with scope_to_per_operation_thread_id enabled."""
        provider = Mem0Provider(
            user_id="user123",
            thread_id="base_thread",
            scope_to_per_operation_thread_id=True,
            mem0_client=mock_mem0_client,
        )
        provider._per_operation_thread_id = "operation_thread"
        message = Message(role="user", text="Hello")
        mock_mem0_client.search.return_value = []

        await provider.invoking(message)

        call_args = mock_mem0_client.search.call_args
        # The per-operation thread replaces base_thread in the search filters.
        assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "operation_thread"}

    async def test_model_invoking_no_memories_returns_none_instructions(self, mock_mem0_client: AsyncMock) -> None:
        """Test that an empty search result yields a Context with no messages."""
        provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
        message = Message(role="user", text="Hello")
        mock_mem0_client.search.return_value = []

        context = await provider.invoking(message)

        assert isinstance(context, Context)
        assert not context.messages

    async def test_model_invoking_function_approval_response_returns_none_instructions(
        self, mock_mem0_client: AsyncMock
    ) -> None:
        """Test that a message holding only a function approval response (no text) yields a Context with no messages."""
        provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
        function_call = Content.from_function_call(call_id="1", name="test_func", arguments='{"arg1": "value1"}')
        message = Message(
            role="user",
            contents=[
                Content.from_function_approval_response(
                    id="approval_1",
                    function_call=function_call,
                    approved=True,
                )
            ],
        )
        mock_mem0_client.search.return_value = []

        context = await provider.invoking(message)

        assert isinstance(context, Context)
        assert not context.messages

    async def test_model_invoking_filters_empty_message_text(self, mock_mem0_client: AsyncMock) -> None:
        """Test that empty and whitespace-only message text is filtered out of the query."""
        provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
        messages = [
            Message(role="user", text=""),
            Message(role="user", text="Valid message"),
            Message(role="user", text="   "),
        ]
        mock_mem0_client.search.return_value = []

        await provider.invoking(messages)

        call_args = mock_mem0_client.search.call_args
        assert call_args.kwargs["query"] == "Valid message"

    async def test_model_invoking_custom_context_prompt(self, mock_mem0_client: AsyncMock) -> None:
        """Test invoking with custom context prompt."""
        custom_prompt = "## Custom Context\nRemember these details:"
        provider = Mem0Provider(
            user_id="user123",
            context_prompt=custom_prompt,
            mem0_client=mock_mem0_client,
        )
        message = Message(role="user", text="Hello")
        mock_mem0_client.search.return_value = [{"memory": "Test memory"}]

        context = await provider.invoking(message)

        # The custom prompt replaces the default header; memories follow on new lines.
        expected_instructions = "## Custom Context\nRemember these details:\nTest memory"
        assert context.messages
        assert context.messages[0].text == expected_instructions


class TestMem0ProviderValidation:
    """Tests for ``_validate_per_operation_thread_id``."""

    @staticmethod
    def _provider_bound_to_thread(client: AsyncMock, *, scoped: bool) -> Mem0Provider:
        """Return a provider whose per-operation thread is already ``thread123``."""
        provider = Mem0Provider(
            user_id="user123",
            scope_to_per_operation_thread_id=scoped,
            mem0_client=client,
        )
        provider._per_operation_thread_id = "thread123"
        return provider

    def test_validate_per_operation_thread_id_success(self, mock_mem0_client: AsyncMock) -> None:
        """With scoping on, the already-bound thread ID and ``None`` both pass validation."""
        provider = self._provider_bound_to_thread(mock_mem0_client, scoped=True)
        # Neither the same ID nor a missing ID should raise.
        for candidate in ("thread123", None):
            provider._validate_per_operation_thread_id(candidate)

    def test_validate_per_operation_thread_id_failure(self, mock_mem0_client: AsyncMock) -> None:
        """With scoping on, a different thread ID is rejected with ``ValueError``."""
        provider = self._provider_bound_to_thread(mock_mem0_client, scoped=True)
        with pytest.raises(ValueError, match="can only be used with one thread at a time"):
            provider._validate_per_operation_thread_id("different_thread")

    def test_validate_per_operation_thread_id_disabled_scope(self, mock_mem0_client: AsyncMock) -> None:
        """With scoping off, validation tolerates a different thread ID."""
        provider = self._provider_bound_to_thread(mock_mem0_client, scoped=False)
        provider._validate_per_operation_thread_id("different_thread")


class TestMem0ProviderBuildFilters:
    """Tests for ``_build_filters``, which maps configured IDs onto mem0 filter keys."""

    def test_build_filters_with_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
        """Only ``user_id`` configured produces a single-key filter."""
        built = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)._build_filters()
        assert built == {"user_id": "user123"}

    def test_build_filters_with_all_parameters(self, mock_mem0_client: AsyncMock) -> None:
        """Each ID maps to its filter key: thread_id becomes ``run_id`` and application_id becomes ``app_id``."""
        provider = Mem0Provider(
            user_id="user123",
            agent_id="agent456",
            thread_id="thread789",
            application_id="app999",
            mem0_client=mock_mem0_client,
        )
        expected = {"user_id": "user123", "agent_id": "agent456", "run_id": "thread789", "app_id": "app999"}
        assert provider._build_filters() == expected

    def test_build_filters_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
        """IDs explicitly set to ``None`` contribute no key at all."""
        provider = Mem0Provider(
            user_id="user123",
            agent_id=None,
            thread_id=None,
            application_id=None,
            mem0_client=mock_mem0_client,
        )
        built = provider._build_filters()
        assert built == {"user_id": "user123"}
        assert all(absent not in built for absent in ("agent_id", "run_id", "app_id"))

    def test_build_filters_with_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
        """When scoping is on and a per-operation thread exists, it overrides the base thread_id."""
        provider = Mem0Provider(
            user_id="user123",
            thread_id="base_thread",
            scope_to_per_operation_thread_id=True,
            mem0_client=mock_mem0_client,
        )
        provider._per_operation_thread_id = "operation_thread"

        # The per-operation thread is used, not base_thread.
        assert provider._build_filters() == {"user_id": "user123", "run_id": "operation_thread"}

    def test_build_filters_uses_base_thread_when_no_per_operation(self, mock_mem0_client: AsyncMock) -> None:
        """With scoping on but no per-operation thread yet, the base thread_id is used."""
        provider = Mem0Provider(
            user_id="user123",
            thread_id="base_thread",
            scope_to_per_operation_thread_id=True,
            mem0_client=mock_mem0_client,
        )
        # _per_operation_thread_id is still None here, so this falls back to the base thread_id.
        assert provider._build_filters() == {"user_id": "user123", "run_id": "base_thread"}

    def test_build_filters_returns_empty_dict_when_no_parameters(self, mock_mem0_client: AsyncMock) -> None:
        """No configured IDs produce an empty filter dict."""
        assert Mem0Provider(mem0_client=mock_mem0_client)._build_filters() == {}


class TestMem0Telemetry:
    """Test telemetry configuration for Mem0."""

    @staticmethod
    def _purge_mem0_modules(monkeypatch: pytest.MonkeyPatch) -> None:
        """Drop cached ``agent_framework_mem0*`` modules so the next import re-runs module init.

        ``monkeypatch.delitem`` puts the original module objects back into ``sys.modules`` at
        teardown. Without this, later tests would see freshly imported modules whose classes
        differ from the ``Mem0Provider`` imported at the top of this file.
        """
        # Materialize the names first: sys.modules must not change while it is being iterated.
        cached = [name for name in sys.modules if name.startswith("agent_framework_mem0")]
        for name in cached:
            monkeypatch.delitem(sys.modules, name)

    def test_mem0_telemetry_disabled_by_default(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Test that MEM0_TELEMETRY is set to 'false' by default when importing the package."""
        # Ensure MEM0_TELEMETRY is not set before importing the module under test
        monkeypatch.delenv("MEM0_TELEMETRY", raising=False)
        # Force re-import so module-level initialization runs again
        self._purge_mem0_modules(monkeypatch)

        # Import (and reload) the module so that it can set MEM0_TELEMETRY when unset
        import agent_framework_mem0

        importlib.reload(agent_framework_mem0)

        # The environment variable should be set to "false" after importing
        assert os.environ.get("MEM0_TELEMETRY") == "false"

    def test_mem0_telemetry_respects_user_setting(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Test that user-set MEM0_TELEMETRY value is not overwritten."""
        # Force re-import so module-level initialization runs again
        self._purge_mem0_modules(monkeypatch)
        # Set user preference before import
        monkeypatch.setenv("MEM0_TELEMETRY", "true")

        # Re-import the module
        import agent_framework_mem0

        importlib.reload(agent_framework_mem0)

        # User setting should be preserved
        assert os.environ.get("MEM0_TELEMETRY") == "true"