mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
1e350ea22f
* PR2: Wire context provider pipeline and update all internal consumers - Replace AgentThread with AgentSession across all packages - Replace ContextProvider with BaseContextProvider across all packages - Replace context_provider param with context_providers (Sequence) - Replace thread= with session= in run() signatures - Replace get_new_thread() with create_session() - Add get_session(service_session_id) to agent interface - DurableAgentThread -> DurableAgentSession - Remove _notify_thread_of_new_messages from WorkflowAgent - Wire before_run/after_run context provider pipeline in RawAgent - Auto-inject InMemoryHistoryProvider when no providers configured * fix: update all tests for context provider pipeline, fix lazy-loaders, remove old test files * refactor: update all sample files for context provider pipeline (AgentThread→AgentSession, ContextProvider→BaseContextProvider) * fix: update remaining ag-ui references (client docstring, getting_started sample) * fix: make get_session service_session_id keyword-only to avoid confusion with session_id * refactor: rename _RunContext.thread_messages to session_messages * refactor: remove _threads.py, _memory.py, and old provider files; migrate devui to use plain message lists * rename: remove _new_ prefix from test files * refactor: rewrite SlidingWindowChatMessageStore as SlidingWindowHistoryProvider(InMemoryHistoryProvider) * fix: read full history from session state directly instead of reaching into provider internals * fix: update stale .pyi stubs, sample imports, and README references for new provider types * fix: remove stale message_store, _notify_thread_of_new_messages, and session_id.key references in samples * refactor: merge context_providers and sessions sample folders into sessions, remove aggregate_context_provider * refactor: UserInfoMemory stores state in session.state instead of instance attributes * feat: add Pydantic BaseModel support to session state serialization Pydantic models stored in 
session.state are now automatically serialized via model_dump() and restored via model_validate() during to_dict()/from_dict() round-trips. Models are auto-registered on first serialization; use register_state_type() for cold-start deserialization. Also export register_state_type as a public API. * fix mem0 * Update sample README links and descriptions for session terminology - Replace 'thread' with 'session' in sample descriptions across all READMEs - Update file links for renamed samples (mem0_sessions, redis_sessions, etc.) - Fix Threads section → Sessions section in main samples/README.md - Update tools, middleware, workflows, durabletask, azure_functions READMEs - Update architecture diagrams in concepts/tools/README.md - Update migration guides (autogen, semantic-kernel) * Fix broken Redis README link to renamed sample * Fix Mem0 OSS client search: pass scoping params as direct kwargs AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs, while AsyncMemoryClient (Platform) expects them in a filters dict. Adds tests for both client types. Port of fix from #3844 to new Mem0ContextProvider. * Fix rebase issues: restore missing _conversation_state.py and checkpoint decode logic - Add back _conversation_state.py (encode/decode_chat_messages) lost in rebase - Fix on_checkpoint_restore to decode cache/conversation with decode_chat_messages - Fix on_checkpoint_restore to use decode_checkpoint_value for pending requests - Add tests/workflow/__init__.py for relative import support - Fix test_agent_executor checkpoint selection (checkpoints[1] not superstep) * Add STORES_BY_DEFAULT ClassVar to skip redundant InMemoryHistoryProvider injection Chat clients that store history server-side by default (OpenAI Responses API, Azure AI Agent) now declare STORES_BY_DEFAULT = True. The agent checks this during auto-injection and skips InMemoryHistoryProvider unless the user explicitly sets store=False. 
* Fix broken markdown links in azure_ai and redis READMEs * Fix getting-started samples to use session API instead of removed thread/ContextProvider API * updates to workflow as agent * fix group chat import * Rename Thread→Session throughout, fix service_session_id propagation, remove stale AGUIThread - Fix: Propagate conversation_id from ChatResponse back to session.service_session_id in both streaming and non-streaming paths in _agents.py - Rename AgentThreadException → AgentSessionException - Remove stale AGUIThread from ag_ui lazy-loader - Rename use_service_thread → use_service_session in ag-ui package - Rename test functions from *_thread_* to *_session_* - Rename sample files from *_thread* to *_session* - Update docstrings and comments: thread → session - Update _mcp.py kwargs filter: add 'session' alongside 'thread' - Fix ContinuationToken docstring example: thread=thread → session=session - Fix _clients.py docstring: 'Agent threads' → 'Agent sessions' * Fix broken markdown links after thread→session file renames * fix azure ai test
408 lines
19 KiB
Python
408 lines
19 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
# pyright: reportPrivateUsage=false
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from agent_framework import AgentResponse, Message
|
|
from agent_framework._sessions import AgentSession, SessionContext
|
|
from agent_framework.exceptions import ServiceInitializationError
|
|
|
|
from agent_framework_mem0._context_provider import Mem0ContextProvider
|
|
|
|
|
|
@pytest.fixture
def mock_mem0_client() -> AsyncMock:
    """Provide an ``AsyncMock`` spec'd against the Mem0 ``AsyncMemoryClient``.

    ``add`` and ``search`` are replaced with fresh async mocks, and the async
    context-manager hooks are mocked so that entering the client yields the
    mock itself.
    """
    from mem0 import AsyncMemoryClient

    platform_client = AsyncMock(spec=AsyncMemoryClient)
    for method_name in ("add", "search"):
        setattr(platform_client, method_name, AsyncMock())
    platform_client.__aenter__ = AsyncMock(return_value=platform_client)
    platform_client.__aexit__ = AsyncMock()
    return platform_client
|
|
|
|
|
|
@pytest.fixture
def mock_oss_mem0_client() -> AsyncMock:
    """Provide an ``AsyncMock`` spec'd against the Mem0 OSS ``AsyncMemory`` client.

    ``add`` and ``search`` are replaced with fresh async mocks so tests can
    configure return values and inspect the call arguments.
    """
    from mem0 import AsyncMemory

    oss_client = AsyncMock(spec=AsyncMemory)
    for method_name in ("add", "search"):
        setattr(oss_client, method_name, AsyncMock())
    return oss_client
|
|
|
|
|
|
# -- Initialization tests ------------------------------------------------------
|
|
|
|
|
|
class TestInit:
    """Constructor behaviour of ``Mem0ContextProvider``."""

    def test_init_with_all_params(self, mock_mem0_client: AsyncMock) -> None:
        """Every constructor argument can be read back from the provider unchanged."""
        provider = Mem0ContextProvider(
            mem0_client=mock_mem0_client,
            source_id="mem0",
            api_key="key-123",
            application_id="app1",
            agent_id="agent1",
            user_id="user1",
            context_prompt="Custom prompt",
        )

        observed = (
            provider.source_id,
            provider.api_key,
            provider.application_id,
            provider.agent_id,
            provider.user_id,
            provider.context_prompt,
        )
        assert observed == ("mem0", "key-123", "app1", "agent1", "user1", "Custom prompt")
        # A caller-supplied client is stored as-is and is not flagged for closing.
        assert provider.mem0_client is mock_mem0_client
        assert provider._should_close_client is False

    def test_init_default_context_prompt(self, mock_mem0_client: AsyncMock) -> None:
        """Omitting ``context_prompt`` falls back to the class-level default."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

        expected_prompt = Mem0ContextProvider.DEFAULT_CONTEXT_PROMPT
        assert provider.context_prompt == expected_prompt

    def test_init_auto_creates_client_when_none(self) -> None:
        """Without a client argument, an ``AsyncMemoryClient`` is built and flagged for closing."""
        init_target = "mem0.client.main.AsyncMemoryClient.__init__"
        validate_target = "mem0.client.main.AsyncMemoryClient._validate_api_key"

        with patch(init_target, return_value=None) as patched_init, patch(validate_target, return_value=None):
            provider = Mem0ContextProvider(source_id="mem0", api_key="test-key", user_id="u1")
            patched_init.assert_called_once_with(api_key="test-key")
            assert provider._should_close_client is True

    def test_provided_client_not_flagged_for_close(self, mock_mem0_client: AsyncMock) -> None:
        """A client handed in by the caller is not marked for closing by the provider."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

        close_flag = provider._should_close_client
        assert close_flag is False
|
|
|
|
|
|
# -- before_run tests ----------------------------------------------------------
|
|
|
|
|
|
class TestBeforeRun:
    """Test the ``before_run`` hook: memory search and injection into the session context."""

    @staticmethod
    def _make_run(*texts: str) -> tuple[AgentSession, SessionContext]:
        """Build a fresh session and a context holding one user message per entry in *texts*.

        The session uses id ``"test-session"`` and the context uses id ``"s1"``.
        """
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(
            input_messages=[Message(role="user", text=text) for text in texts],
            session_id="s1",
        )
        return session, ctx

    async def test_memories_added_to_context(self, mock_mem0_client: AsyncMock) -> None:
        """Mocked mem0 search returns memories → a single message with prompt and memories is added."""
        memories = ["User likes Python", "User prefers dark mode"]
        mock_mem0_client.search.return_value = [{"memory": memory} for memory in memories]
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session, ctx = self._make_run("Hello")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        mock_mem0_client.search.assert_awaited_once()
        assert "mem0" in ctx.context_messages
        added = ctx.context_messages["mem0"]
        assert len(added) == 1
        for expected_fragment in (*memories, provider.context_prompt):
            assert expected_fragment in added[0].text  # type: ignore[operator]

    async def test_empty_input_skips_search(self, mock_mem0_client: AsyncMock) -> None:
        """Empty input messages → no search performed and nothing added to the context."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session, ctx = self._make_run("")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        mock_mem0_client.search.assert_not_awaited()
        assert "mem0" not in ctx.context_messages

    async def test_empty_search_results_no_messages(self, mock_mem0_client: AsyncMock) -> None:
        """Empty search results → no messages added."""
        mock_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session, ctx = self._make_run("test")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        # Confirm the search really ran; otherwise this test would also pass
        # if before_run returned early and never reached the empty-results path.
        mock_mem0_client.search.assert_awaited_once()
        assert "mem0" not in ctx.context_messages

    async def test_validates_filters_before_search(self, mock_mem0_client: AsyncMock) -> None:
        """Raises ServiceInitializationError when no scoping filter is configured."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
        session, ctx = self._make_run("test")

        with pytest.raises(ServiceInitializationError, match="At least one of the filters"):
            await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

    async def test_v1_1_response_format(self, mock_mem0_client: AsyncMock) -> None:
        """Search response in v1.1 dict format with 'results' key."""
        mock_mem0_client.search.return_value = {"results": [{"memory": "remembered fact"}]}
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session, ctx = self._make_run("test")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        added = ctx.context_messages["mem0"]
        assert "remembered fact" in added[0].text  # type: ignore[operator]

    async def test_search_query_combines_input_messages(self, mock_mem0_client: AsyncMock) -> None:
        """Multiple input messages are joined with newlines for the search query."""
        mock_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session, ctx = self._make_run("Hello", "World")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        call_kwargs = mock_mem0_client.search.call_args.kwargs
        assert call_kwargs["query"] == "Hello\nWorld"

    async def test_oss_client_passes_direct_kwargs(self, mock_oss_mem0_client: AsyncMock) -> None:
        """OSS AsyncMemory client should receive user_id as direct kwarg, not in filters."""
        mock_oss_mem0_client.search.return_value = [{"memory": "User likes Python"}]
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1")
        session, ctx = self._make_run("Hello")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        call_kwargs = mock_oss_mem0_client.search.call_args.kwargs
        assert call_kwargs["query"] == "Hello"
        assert call_kwargs["user_id"] == "u1"
        assert "filters" not in call_kwargs

    async def test_oss_client_all_scoping_params(self, mock_oss_mem0_client: AsyncMock) -> None:
        """OSS client scoped by user, agent and application gets user_id/agent_id directly, no filters dict."""
        mock_oss_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(
            source_id="mem0",
            mem0_client=mock_oss_mem0_client,
            user_id="u1",
            agent_id="a1",
            application_id="app1",
        )
        session, ctx = self._make_run("Hello")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        call_kwargs = mock_oss_mem0_client.search.call_args.kwargs
        assert call_kwargs["user_id"] == "u1"
        assert call_kwargs["agent_id"] == "a1"
        assert "filters" not in call_kwargs

    async def test_platform_client_passes_filters_dict(self, mock_mem0_client: AsyncMock) -> None:
        """Platform AsyncMemoryClient should receive scoping params in a filters dict."""
        mock_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session, ctx = self._make_run("Hello")

        await provider.before_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        call_kwargs = mock_mem0_client.search.call_args.kwargs
        assert call_kwargs["query"] == "Hello"
        assert "filters" in call_kwargs
        assert call_kwargs["filters"]["user_id"] == "u1"
|
|
|
|
|
|
# -- after_run tests -----------------------------------------------------------
|
|
|
|
|
|
class TestAfterRun:
    """Test the ``after_run`` hook: persisting the exchanged messages through ``client.add``."""

    @staticmethod
    def _make_run(
        inputs: list[Message],
        replies: list[Message],
        session_id: str = "s1",
    ) -> tuple[AgentSession, SessionContext]:
        """Build a session and a context whose response has already been set to *replies*."""
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=inputs, session_id=session_id)
        ctx._response = AgentResponse(messages=replies)
        return session, ctx

    async def test_stores_input_and_response(self, mock_mem0_client: AsyncMock) -> None:
        """Input and response messages are sent to mem0 as role/content dicts."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session, ctx = self._make_run(
            [Message(role="user", text="question")],
            [Message(role="assistant", text="answer")],
        )

        await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        mock_mem0_client.add.assert_awaited_once()
        add_kwargs = mock_mem0_client.add.call_args.kwargs
        expected_payload = [
            {"role": "user", "content": "question"},
            {"role": "assistant", "content": "answer"},
        ]
        assert add_kwargs["messages"] == expected_payload
        assert add_kwargs["user_id"] == "u1"
        assert add_kwargs["run_id"] == "s1"

    async def test_only_stores_user_assistant_system(self, mock_mem0_client: AsyncMock) -> None:
        """Tool-role messages are left out of the stored payload."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session, ctx = self._make_run(
            [Message(role="user", text="hello"), Message(role="tool", text="tool output")],
            [Message(role="assistant", text="reply")],
        )

        await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        stored = mock_mem0_client.add.call_args.kwargs["messages"]
        roles = [entry["role"] for entry in stored]
        assert "tool" not in roles
        assert roles == ["user", "assistant"]

    async def test_skips_empty_messages(self, mock_mem0_client: AsyncMock) -> None:
        """Nothing is stored when every message is empty or whitespace-only."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        blank_inputs = [Message(role="user", text=blank) for blank in ("", "   ")]
        session, ctx = self._make_run(blank_inputs, [])

        await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        mock_mem0_client.add.assert_not_awaited()

    async def test_uses_session_id_as_run_id(self, mock_mem0_client: AsyncMock) -> None:
        """The context's session_id is forwarded to mem0 as ``run_id``."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        session, ctx = self._make_run(
            [Message(role="user", text="hi")],
            [Message(role="assistant", text="hey")],
            session_id="my-session",
        )

        await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        forwarded_run_id = mock_mem0_client.add.call_args.kwargs["run_id"]
        assert forwarded_run_id == "my-session"

    async def test_validates_filters(self, mock_mem0_client: AsyncMock) -> None:
        """Raises ServiceInitializationError when no scoping filter is configured."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
        session, ctx = self._make_run(
            [Message(role="user", text="hi")],
            [Message(role="assistant", text="hey")],
        )

        with pytest.raises(ServiceInitializationError, match="At least one of the filters"):
            await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

    async def test_stores_with_application_id_metadata(self, mock_mem0_client: AsyncMock) -> None:
        """The configured application_id is sent in the ``metadata`` argument."""
        provider = Mem0ContextProvider(
            source_id="mem0",
            mem0_client=mock_mem0_client,
            user_id="u1",
            application_id="app1",
        )
        session, ctx = self._make_run([Message(role="user", text="hi")], [])

        await provider.after_run(agent=None, session=session, context=ctx, state=session.state)  # type: ignore[arg-type]

        sent_metadata = mock_mem0_client.add.call_args.kwargs["metadata"]
        assert sent_metadata == {"application_id": "app1"}
|
|
|
|
|
|
# -- _validate_filters tests --------------------------------------------------
|
|
|
|
|
|
class TestValidateFilters:
    """Test ``_validate_filters``: at least one scoping identifier must be configured."""

    def test_raises_when_no_filters(self, mock_mem0_client: AsyncMock) -> None:
        """A provider with no user, agent or application id is rejected."""
        unscoped = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)

        with pytest.raises(ServiceInitializationError, match="At least one of the filters"):
            unscoped._validate_filters()

    def test_passes_with_user_id(self, mock_mem0_client: AsyncMock) -> None:
        """A user_id alone is enough; the call must return without raising."""
        scoped = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        scoped._validate_filters()

    def test_passes_with_agent_id(self, mock_mem0_client: AsyncMock) -> None:
        """An agent_id alone is enough; the call must return without raising."""
        scoped = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, agent_id="a1")
        scoped._validate_filters()

    def test_passes_with_application_id(self, mock_mem0_client: AsyncMock) -> None:
        """An application_id alone is enough; the call must return without raising."""
        scoped = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, application_id="app1")
        scoped._validate_filters()
|
|
|
|
|
|
# -- _build_filters tests -----------------------------------------------------
|
|
|
|
|
|
class TestBuildFilters:
    """Test ``_build_filters``: mapping of the provider's scoping ids onto a filters dict."""

    def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
        """Only the configured user_id appears in the result."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

        built = provider._build_filters()
        assert built == {"user_id": "u1"}

    def test_all_params(self, mock_mem0_client: AsyncMock) -> None:
        """All ids are emitted; session_id becomes ``run_id`` and application_id becomes ``app_id``."""
        provider = Mem0ContextProvider(
            source_id="mem0",
            mem0_client=mock_mem0_client,
            user_id="u1",
            agent_id="a1",
            application_id="app1",
        )

        expected = {
            "user_id": "u1",
            "agent_id": "a1",
            "run_id": "sess1",
            "app_id": "app1",
        }
        assert provider._build_filters(session_id="sess1") == expected

    def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
        """Keys for identifiers that were never set are absent from the result."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

        built = provider._build_filters()
        for absent_key in ("agent_id", "run_id", "app_id"):
            assert absent_key not in built

    def test_session_id_mapped_to_run_id(self, mock_mem0_client: AsyncMock) -> None:
        """The ``session_id`` argument is returned under the ``run_id`` key."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

        built = provider._build_filters(session_id="s99")
        assert built["run_id"] == "s99"

    def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None:
        """With no identifiers configured the result is an empty dict."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)

        built = provider._build_filters()
        assert built == {}
|
|
|
|
|
|
# -- Context manager tests -----------------------------------------------------
|
|
|
|
|
|
class TestContextManager:
    """Test how ``__aenter__``/``__aexit__`` on the provider interact with the wrapped client."""

    async def test_aenter_delegates_to_client(self, mock_mem0_client: AsyncMock) -> None:
        """Entering the provider returns the provider and enters the client once."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

        entered = await provider.__aenter__()

        assert entered is provider
        mock_mem0_client.__aenter__.assert_awaited_once()

    async def test_aexit_closes_auto_created_client(self, mock_mem0_client: AsyncMock) -> None:
        """Auto-created clients (_should_close_client=True) are closed on exit."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        # Simulate an auto-created client by flipping the ownership flag by hand.
        provider._should_close_client = True

        await provider.__aexit__(None, None, None)

        mock_mem0_client.__aexit__.assert_awaited_once()

    async def test_aexit_does_not_close_provided_client(self, mock_mem0_client: AsyncMock) -> None:
        """Provided clients (_should_close_client=False) are NOT closed on exit."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        assert provider._should_close_client is False

        await provider.__aexit__(None, None, None)

        mock_mem0_client.__aexit__.assert_not_awaited()

    async def test_async_with_syntax(self, mock_mem0_client: AsyncMock) -> None:
        """``async with provider`` binds the provider itself to the ``as`` target."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

        async with provider as bound:
            assert bound is provider
|