agent-framework/python/packages/mem0/tests/test_mem0_context_provider.py
Eduard van Valkenburg 1e350ea22f Python: [BREAKING] PR2 — Wire context provider pipeline, remove old types, update all consumers (#3850)
* PR2: Wire context provider pipeline and update all internal consumers

- Replace AgentThread with AgentSession across all packages
- Replace ContextProvider with BaseContextProvider across all packages
- Replace context_provider param with context_providers (Sequence)
- Replace thread= with session= in run() signatures
- Replace get_new_thread() with create_session()
- Add get_session(service_session_id) to agent interface
- DurableAgentThread -> DurableAgentSession
- Remove _notify_thread_of_new_messages from WorkflowAgent
- Wire before_run/after_run context provider pipeline in RawAgent
- Auto-inject InMemoryHistoryProvider when no providers configured

* fix: update all tests for context provider pipeline, fix lazy-loaders, remove old test files

* refactor: update all sample files for context provider pipeline (AgentThread→AgentSession, ContextProvider→BaseContextProvider)

* fix: update remaining ag-ui references (client docstring, getting_started sample)

* fix: make get_session service_session_id keyword-only to avoid confusion with session_id

* refactor: rename _RunContext.thread_messages to session_messages

* refactor: remove _threads.py, _memory.py, and old provider files; migrate devui to use plain message lists

* rename: remove _new_ prefix from test files

* refactor: rewrite SlidingWindowChatMessageStore as SlidingWindowHistoryProvider(InMemoryHistoryProvider)

* fix: read full history from session state directly instead of reaching into provider internals

* fix: update stale .pyi stubs, sample imports, and README references for new provider types

* fix: remove stale message_store, _notify_thread_of_new_messages, and session_id.key references in samples

* refactor: merge context_providers and sessions sample folders into sessions, remove aggregate_context_provider

* refactor: UserInfoMemory stores state in session.state instead of instance attributes

* feat: add Pydantic BaseModel support to session state serialization

Pydantic models stored in session.state are now automatically serialized
via model_dump() and restored via model_validate() during to_dict()/from_dict()
round-trips. Models are auto-registered on first serialization; use
register_state_type() for cold-start deserialization.

Also export register_state_type as a public API.

* fix mem0

* Update sample README links and descriptions for session terminology

- Replace 'thread' with 'session' in sample descriptions across all READMEs
- Update file links for renamed samples (mem0_sessions, redis_sessions, etc.)
- Fix Threads section → Sessions section in main samples/README.md
- Update tools, middleware, workflows, durabletask, azure_functions READMEs
- Update architecture diagrams in concepts/tools/README.md
- Update migration guides (autogen, semantic-kernel)

* Fix broken Redis README link to renamed sample

* Fix Mem0 OSS client search: pass scoping params as direct kwargs

AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs,
while AsyncMemoryClient (Platform) expects them in a filters dict.
Adds tests for both client types.

Port of fix from #3844 to new Mem0ContextProvider.
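
A minimal sketch of the difference between the two call shapes, assuming the behaviour described above and asserted in the tests further down. `build_search_kwargs` and the `is_platform` flag are hypothetical names for illustration; the provider's real dispatch logic may look different.

```python
from typing import Any

# Scoping keys the commit message names for the OSS client.
OSS_SCOPE_KEYS = ("user_id", "agent_id", "run_id")


def build_search_kwargs(query: str, filters: dict[str, str], *, is_platform: bool) -> dict[str, Any]:
    """Shape the keyword arguments for a search call (hypothetical helper)."""
    if is_platform:
        # Platform client: scoping parameters travel inside one filters dict.
        return {"query": query, "filters": dict(filters)}
    # OSS client: scoping parameters are passed as direct keyword arguments.
    scoped = {key: filters[key] for key in OSS_SCOPE_KEYS if key in filters}
    return {"query": query, **scoped}


platform_kwargs = build_search_kwargs("Hello", {"user_id": "u1"}, is_platform=True)
oss_kwargs = build_search_kwargs("Hello", {"user_id": "u1", "agent_id": "a1"}, is_platform=False)
```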

* Fix rebase issues: restore missing _conversation_state.py and checkpoint decode logic

- Add back _conversation_state.py (encode/decode_chat_messages) lost in rebase
- Fix on_checkpoint_restore to decode cache/conversation with decode_chat_messages
- Fix on_checkpoint_restore to use decode_checkpoint_value for pending requests
- Add tests/workflow/__init__.py for relative import support
- Fix test_agent_executor checkpoint selection (checkpoints[1] not superstep)

* Add STORES_BY_DEFAULT ClassVar to skip redundant InMemoryHistoryProvider injection

Chat clients that store history server-side by default (OpenAI Responses API,
Azure AI Agent) now declare STORES_BY_DEFAULT = True. The agent checks this
during auto-injection and skips InMemoryHistoryProvider unless the user
explicitly sets store=False.
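
The auto-injection check might look something like the sketch below. The two client classes and `needs_in_memory_history` are stand-ins invented for illustration; only the `STORES_BY_DEFAULT` attribute name and the `store=False` override come from the description above. How an explicit `store=True` is treated is not stated there, so the sketch does not special-case it.

```python
from typing import Any, ClassVar, Optional, Sequence


class LocalChatClient:
    """Stand-in for a client that keeps no history on the server."""

    STORES_BY_DEFAULT: ClassVar[bool] = False


class ServerStoredChatClient:
    """Stand-in for a client whose service stores history by default."""

    STORES_BY_DEFAULT: ClassVar[bool] = True


def needs_in_memory_history(client: Any, providers: Sequence[Any], store: Optional[bool] = None) -> bool:
    """Decide whether a default in-memory history provider should be injected."""
    if providers:
        return False  # the user configured providers, so nothing is auto-injected
    if store is False:
        return True  # the user opted out of server-side storage
    # Otherwise inject only when the client does not store history itself.
    return not getattr(client, "STORES_BY_DEFAULT", False)
```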

* Fix broken markdown links in azure_ai and redis READMEs

* Fix getting-started samples to use session API instead of removed thread/ContextProvider API

* updates to workflow as agent

* fix group chat import

* Rename Thread→Session throughout, fix service_session_id propagation, remove stale AGUIThread

- Fix: Propagate conversation_id from ChatResponse back to session.service_session_id
  in both streaming and non-streaming paths in _agents.py
- Rename AgentThreadException → AgentSessionException
- Remove stale AGUIThread from ag_ui lazy-loader
- Rename use_service_thread → use_service_session in ag-ui package
- Rename test functions from *_thread_* to *_session_*
- Rename sample files from *_thread* to *_session*
- Update docstrings and comments: thread → session
- Update _mcp.py kwargs filter: add 'session' alongside 'thread'
- Fix ContinuationToken docstring example: thread=thread → session=session
- Fix _clients.py docstring: 'Agent threads' → 'Agent sessions'

* Fix broken markdown links after thread→session file renames

* fix azure ai test
2026-02-12 21:00:32 +00:00


# Copyright (c) Microsoft. All rights reserved.
# pyright: reportPrivateUsage=false

from __future__ import annotations

from unittest.mock import AsyncMock, patch

import pytest
from agent_framework import AgentResponse, Message
from agent_framework._sessions import AgentSession, SessionContext
from agent_framework.exceptions import ServiceInitializationError

from agent_framework_mem0._context_provider import Mem0ContextProvider


@pytest.fixture
def mock_mem0_client() -> AsyncMock:
    """Create a mock Mem0 AsyncMemoryClient."""
    from mem0 import AsyncMemoryClient

    mock_client = AsyncMock(spec=AsyncMemoryClient)
    mock_client.add = AsyncMock()
    mock_client.search = AsyncMock()
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock()
    return mock_client


@pytest.fixture
def mock_oss_mem0_client() -> AsyncMock:
    """Create a mock Mem0 OSS AsyncMemory client."""
    from mem0 import AsyncMemory

    mock_client = AsyncMock(spec=AsyncMemory)
    mock_client.add = AsyncMock()
    mock_client.search = AsyncMock()
    return mock_client


# -- Initialization tests ------------------------------------------------------


class TestInit:
    """Test Mem0ContextProvider initialization."""

    def test_init_with_all_params(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(
            source_id="mem0",
            mem0_client=mock_mem0_client,
            api_key="key-123",
            application_id="app1",
            agent_id="agent1",
            user_id="user1",
            context_prompt="Custom prompt",
        )
        assert provider.source_id == "mem0"
        assert provider.api_key == "key-123"
        assert provider.application_id == "app1"
        assert provider.agent_id == "agent1"
        assert provider.user_id == "user1"
        assert provider.context_prompt == "Custom prompt"
        assert provider.mem0_client is mock_mem0_client
        assert provider._should_close_client is False

    def test_init_default_context_prompt(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        assert provider.context_prompt == Mem0ContextProvider.DEFAULT_CONTEXT_PROMPT

    def test_init_auto_creates_client_when_none(self) -> None:
        """When no client is provided, a default AsyncMemoryClient is created and flagged for closing."""
        with (
            patch("mem0.client.main.AsyncMemoryClient.__init__", return_value=None) as mock_init,
            patch("mem0.client.main.AsyncMemoryClient._validate_api_key", return_value=None),
        ):
            provider = Mem0ContextProvider(source_id="mem0", api_key="test-key", user_id="u1")
            mock_init.assert_called_once_with(api_key="test-key")
            assert provider._should_close_client is True

    def test_provided_client_not_flagged_for_close(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        assert provider._should_close_client is False


# -- before_run tests ----------------------------------------------------------


class TestBeforeRun:
    """Test before_run hook."""

    async def test_memories_added_to_context(self, mock_mem0_client: AsyncMock) -> None:
        """Mocked mem0 search returns memories → messages added to context with prompt."""
        mock_mem0_client.search.return_value = [
            {"memory": "User likes Python"},
            {"memory": "User prefers dark mode"},
        ]
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        mock_mem0_client.search.assert_awaited_once()
        assert "mem0" in ctx.context_messages
        added = ctx.context_messages["mem0"]
        assert len(added) == 1
        assert "User likes Python" in added[0].text  # type: ignore[operator]
        assert "User prefers dark mode" in added[0].text  # type: ignore[operator]
        assert provider.context_prompt in added[0].text  # type: ignore[operator]

    async def test_empty_input_skips_search(self, mock_mem0_client: AsyncMock) -> None:
        """Empty input messages → no search performed."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="")], session_id="s1")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        mock_mem0_client.search.assert_not_awaited()
        assert "mem0" not in ctx.context_messages

    async def test_empty_search_results_no_messages(self, mock_mem0_client: AsyncMock) -> None:
        """Empty search results → no messages added."""
        mock_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        assert "mem0" not in ctx.context_messages

    async def test_validates_filters_before_search(self, mock_mem0_client: AsyncMock) -> None:
        """Raises ServiceInitializationError when no filters."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1")

        with pytest.raises(ServiceInitializationError, match="At least one of the filters"):
            await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

    async def test_v1_1_response_format(self, mock_mem0_client: AsyncMock) -> None:
        """Search response in v1.1 dict format with 'results' key."""
        mock_mem0_client.search.return_value = {"results": [{"memory": "remembered fact"}]}
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        added = ctx.context_messages["mem0"]
        assert "remembered fact" in added[0].text  # type: ignore[operator]

    async def test_search_query_combines_input_messages(self, mock_mem0_client: AsyncMock) -> None:
        """Multiple input messages are joined for the search query."""
        mock_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(
            input_messages=[
                Message(role="user", text="Hello"),
                Message(role="user", text="World"),
            ],
            session_id="s1",
        )

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        call_kwargs = mock_mem0_client.search.call_args.kwargs
        assert call_kwargs["query"] == "Hello\nWorld"

    async def test_oss_client_passes_direct_kwargs(self, mock_oss_mem0_client: AsyncMock) -> None:
        """OSS AsyncMemory client should receive user_id as direct kwarg, not in filters."""
        mock_oss_mem0_client.search.return_value = [{"memory": "User likes Python"}]
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        call_kwargs = mock_oss_mem0_client.search.call_args.kwargs
        assert call_kwargs["query"] == "Hello"
        assert call_kwargs["user_id"] == "u1"
        assert "filters" not in call_kwargs

    async def test_oss_client_all_scoping_params(self, mock_oss_mem0_client: AsyncMock) -> None:
        """OSS client with all scoping parameters passes them as direct kwargs."""
        mock_oss_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(
            source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1", application_id="app1"
        )
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        call_kwargs = mock_oss_mem0_client.search.call_args.kwargs
        assert call_kwargs["user_id"] == "u1"
        assert call_kwargs["agent_id"] == "a1"
        assert "filters" not in call_kwargs

    async def test_platform_client_passes_filters_dict(self, mock_mem0_client: AsyncMock) -> None:
        """Platform AsyncMemoryClient should receive scoping params in a filters dict."""
        mock_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="Hello")], session_id="s1")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        call_kwargs = mock_mem0_client.search.call_args.kwargs
        assert call_kwargs["query"] == "Hello"
        assert "filters" in call_kwargs
        assert call_kwargs["filters"]["user_id"] == "u1"


# -- after_run tests -----------------------------------------------------------


class TestAfterRun:
    """Test after_run hook."""

    async def test_stores_input_and_response(self, mock_mem0_client: AsyncMock) -> None:
        """Stores input+response messages to mem0 via client.add."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="question")], session_id="s1")
        ctx._response = AgentResponse(messages=[Message(role="assistant", text="answer")])

        await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        mock_mem0_client.add.assert_awaited_once()
        call_kwargs = mock_mem0_client.add.call_args.kwargs
        assert call_kwargs["messages"] == [
            {"role": "user", "content": "question"},
            {"role": "assistant", "content": "answer"},
        ]
        assert call_kwargs["user_id"] == "u1"
        assert call_kwargs["run_id"] == "s1"

    async def test_only_stores_user_assistant_system(self, mock_mem0_client: AsyncMock) -> None:
        """Only stores user/assistant/system messages with text."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(
            input_messages=[
                Message(role="user", text="hello"),
                Message(role="tool", text="tool output"),
            ],
            session_id="s1",
        )
        ctx._response = AgentResponse(messages=[Message(role="assistant", text="reply")])

        await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        call_kwargs = mock_mem0_client.add.call_args.kwargs
        roles = [m["role"] for m in call_kwargs["messages"]]
        assert "tool" not in roles
        assert roles == ["user", "assistant"]

    async def test_skips_empty_messages(self, mock_mem0_client: AsyncMock) -> None:
        """Skips messages with empty text."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(
            input_messages=[
                Message(role="user", text=""),
                Message(role="user", text=" "),
            ],
            session_id="s1",
        )
        ctx._response = AgentResponse(messages=[])

        await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        mock_mem0_client.add.assert_not_awaited()

    async def test_uses_session_id_as_run_id(self, mock_mem0_client: AsyncMock) -> None:
        """Uses session_id as run_id."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="my-session")
        ctx._response = AgentResponse(messages=[Message(role="assistant", text="hey")])

        await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        assert mock_mem0_client.add.call_args.kwargs["run_id"] == "my-session"

    async def test_validates_filters(self, mock_mem0_client: AsyncMock) -> None:
        """Raises ServiceInitializationError when no filters."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="s1")
        ctx._response = AgentResponse(messages=[Message(role="assistant", text="hey")])

        with pytest.raises(ServiceInitializationError, match="At least one of the filters"):
            await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

    async def test_stores_with_application_id_metadata(self, mock_mem0_client: AsyncMock) -> None:
        """application_id is passed in metadata."""
        provider = Mem0ContextProvider(
            source_id="mem0", mem0_client=mock_mem0_client, user_id="u1", application_id="app1"
        )
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="s1")
        ctx._response = AgentResponse(messages=[])

        await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        assert mock_mem0_client.add.call_args.kwargs["metadata"] == {"application_id": "app1"}


# -- _validate_filters tests --------------------------------------------------


class TestValidateFilters:
    """Test _validate_filters method."""

    def test_raises_when_no_filters(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
        with pytest.raises(ServiceInitializationError, match="At least one of the filters"):
            provider._validate_filters()

    def test_passes_with_user_id(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        provider._validate_filters()  # should not raise

    def test_passes_with_agent_id(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, agent_id="a1")
        provider._validate_filters()

    def test_passes_with_application_id(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, application_id="app1")
        provider._validate_filters()


# -- _build_filters tests -----------------------------------------------------


class TestBuildFilters:
    """Test _build_filters method."""

    def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        assert provider._build_filters() == {"user_id": "u1"}

    def test_all_params(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(
            source_id="mem0",
            mem0_client=mock_mem0_client,
            user_id="u1",
            agent_id="a1",
            application_id="app1",
        )
        assert provider._build_filters(session_id="sess1") == {
            "user_id": "u1",
            "agent_id": "a1",
            "run_id": "sess1",
            "app_id": "app1",
        }

    def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        filters = provider._build_filters()
        assert "agent_id" not in filters
        assert "run_id" not in filters
        assert "app_id" not in filters

    def test_session_id_mapped_to_run_id(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        filters = provider._build_filters(session_id="s99")
        assert filters["run_id"] == "s99"

    def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
        assert provider._build_filters() == {}


# -- Context manager tests -----------------------------------------------------


class TestContextManager:
    """Test __aenter__/__aexit__ delegation."""

    async def test_aenter_delegates_to_client(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        result = await provider.__aenter__()
        assert result is provider
        mock_mem0_client.__aenter__.assert_awaited_once()

    async def test_aexit_closes_auto_created_client(self, mock_mem0_client: AsyncMock) -> None:
        """Auto-created clients (_should_close_client=True) are closed on exit."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        provider._should_close_client = True
        await provider.__aexit__(None, None, None)
        mock_mem0_client.__aexit__.assert_awaited_once()

    async def test_aexit_does_not_close_provided_client(self, mock_mem0_client: AsyncMock) -> None:
        """Provided clients (_should_close_client=False) are NOT closed on exit."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        assert provider._should_close_client is False
        await provider.__aexit__(None, None, None)
        mock_mem0_client.__aexit__.assert_not_awaited()

    async def test_async_with_syntax(self, mock_mem0_client: AsyncMock) -> None:
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        async with provider as p:
            assert p is provider