mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
93cbf6b3f0
* Parse structuredContent from MCP CallToolResult (#3313) The _parse_tool_result_from_mcp method only iterated over the content field from CallToolResult, ignoring the structuredContent field entirely. MCP servers that return JSON data via structuredContent (e.g., Power BI MCP) appeared to return None. Add handling for structuredContent: when present, serialize it as JSON text and append it to the result list. This preserves the data for the LLM while maintaining backward compatibility with existing behavior. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: Parse MCP CallToolResult.structuredContent field to prevent tool results returning None Fixes #3313 * Address review feedback: add default=str to json.dumps and remove .checkpoints/ - Add default=str to json.dumps for structuredContent serialization so non-JSON-serializable values (e.g. bytes) degrade gracefully instead of raising TypeError - Remove all .checkpoints/ runtime artifacts from the repository - Add **/.checkpoints/ to .gitignore to prevent future accidental commits - Add test for non-serializable structuredContent values Fixes #3313 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #3313: Python: MCP CallToolResult.structuredContent field is not parsed, causing tool results to return None --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
491 lines
22 KiB
Python
491 lines
22 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
# pyright: reportPrivateUsage=false
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from agent_framework import AgentResponse, Message
|
|
from agent_framework._sessions import AgentSession, SessionContext
|
|
|
|
from agent_framework_mem0._context_provider import Mem0ContextProvider
|
|
|
|
|
|
@pytest.fixture
def mock_mem0_client() -> AsyncMock:
    """Provide an ``AsyncMock`` shaped like the Mem0 platform ``AsyncMemoryClient``.

    ``add`` and ``search`` are replaced with fresh ``AsyncMock`` objects so tests can set
    return values and assert awaits on them. ``__aenter__`` resolves to the mock itself so the
    client can be used with ``async with``; ``__aexit__`` is a plain awaitable mock.
    """
    from mem0 import AsyncMemoryClient

    platform_client = AsyncMock(spec=AsyncMemoryClient)
    platform_client.add = AsyncMock()
    platform_client.search = AsyncMock()
    # Async-context-manager hooks: entering yields the same mock back.
    platform_client.__aenter__ = AsyncMock(return_value=platform_client)
    platform_client.__aexit__ = AsyncMock()
    return platform_client
|
|
|
|
|
|
@pytest.fixture
def mock_oss_mem0_client() -> AsyncMock:
    """Provide an ``AsyncMock`` shaped like the open-source Mem0 ``AsyncMemory`` client.

    Only ``add`` and ``search`` are replaced with fresh ``AsyncMock`` objects; no
    async-context-manager hooks are configured on this mock.
    """
    from mem0 import AsyncMemory

    oss_client = AsyncMock(spec=AsyncMemory)
    for method_name in ("add", "search"):
        setattr(oss_client, method_name, AsyncMock())
    return oss_client
|
|
|
|
|
|
# -- Initialization tests ------------------------------------------------------
|
|
|
|
|
|
class TestInit:
    """Test Mem0ContextProvider initialization."""

    def test_init_with_all_params(self, mock_mem0_client: AsyncMock) -> None:
        """Every constructor argument lands on the matching attribute; a supplied client is not owned."""
        provider = Mem0ContextProvider(
            source_id="mem0",
            mem0_client=mock_mem0_client,
            api_key="key-123",
            application_id="app1",
            agent_id="agent1",
            user_id="user1",
            context_prompt="Custom prompt",
        )
        expected_attrs = (
            ("source_id", "mem0"),
            ("api_key", "key-123"),
            ("application_id", "app1"),
            ("agent_id", "agent1"),
            ("user_id", "user1"),
            ("context_prompt", "Custom prompt"),
        )
        for attr_name, expected_value in expected_attrs:
            assert getattr(provider, attr_name) == expected_value, attr_name
        assert provider.mem0_client is mock_mem0_client
        # The caller handed in the client, so the provider must not close it on exit.
        assert provider._should_close_client is False

    def test_init_default_context_prompt(self, mock_mem0_client: AsyncMock) -> None:
        """Omitting context_prompt falls back to the class-level DEFAULT_CONTEXT_PROMPT."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        assert provider.context_prompt == Mem0ContextProvider.DEFAULT_CONTEXT_PROMPT

    def test_init_auto_creates_client_when_none(self) -> None:
        """When no client is provided, a default AsyncMemoryClient is created and flagged for closing."""
        init_target = "mem0.client.main.AsyncMemoryClient.__init__"
        validate_target = "mem0.client.main.AsyncMemoryClient._validate_api_key"
        with patch(init_target, return_value=None) as mock_init, patch(validate_target, return_value=None):
            provider = Mem0ContextProvider(source_id="mem0", api_key="test-key", user_id="u1")
        # The api_key given to the provider is forwarded to the auto-created client.
        mock_init.assert_called_once_with(api_key="test-key")
        assert provider._should_close_client is True

    def test_provided_client_not_flagged_for_close(self, mock_mem0_client: AsyncMock) -> None:
        """A caller-supplied client is never marked for closing by the provider."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        assert provider._should_close_client is False
|
|
|
|
|
|
# -- before_run tests ----------------------------------------------------------
|
|
|
|
|
|
class TestBeforeRun:
    """Test before_run hook."""

    @staticmethod
    async def _run_before_run(provider: Mem0ContextProvider, ctx: SessionContext) -> None:
        """Call ``provider.before_run`` with a fresh session and the provider's per-source state slot.

        ``agent=None`` is not a real agent, so the type-ignore sits on that argument's own line:
        type checkers report the argument diagnostic there, not on the closing parenthesis.
        """
        session = AgentSession(session_id="test-session")
        await provider.before_run(
            agent=None,  # type: ignore[arg-type]
            session=session,
            context=ctx,
            state=session.state.setdefault(provider.source_id, {}),
        )

    @staticmethod
    def _mock_context(text: str) -> MagicMock:
        """Return a ``SessionContext`` mock holding one input message whose ``text`` is *text*."""
        context = MagicMock(spec=SessionContext)
        message = MagicMock()
        message.text = text
        context.input_messages = [message]
        context.response = None
        return context

    async def test_memories_added_to_context(self, mock_mem0_client: AsyncMock) -> None:
        """Mocked mem0 search returns memories → messages added to context with prompt."""
        mock_mem0_client.search.return_value = [
            {"memory": "User likes Python"},
            {"memory": "User prefers dark mode"},
        ]
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1")

        await self._run_before_run(provider, ctx)

        mock_mem0_client.search.assert_awaited_once()
        assert "mem0" in ctx.context_messages
        added = ctx.context_messages["mem0"]
        assert len(added) == 1
        assert "User likes Python" in added[0].text  # type: ignore[operator]
        assert "User prefers dark mode" in added[0].text  # type: ignore[operator]
        assert provider.context_prompt in added[0].text  # type: ignore[operator]

    async def test_empty_input_skips_search(self, mock_mem0_client: AsyncMock) -> None:
        """Input whose text is empty → no search performed and nothing added."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        ctx = SessionContext(input_messages=[Message(role="user", contents=[""])], session_id="s1")

        await self._run_before_run(provider, ctx)

        mock_mem0_client.search.assert_not_awaited()
        assert "mem0" not in ctx.context_messages

    async def test_empty_search_results_no_messages(self, mock_mem0_client: AsyncMock) -> None:
        """Empty search results → no messages added."""
        mock_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1")

        await self._run_before_run(provider, ctx)

        assert "mem0" not in ctx.context_messages

    async def test_validates_filters_before_search(self, mock_mem0_client: AsyncMock) -> None:
        """Raises ValueError when no user/agent/application id is configured."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
        ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1")

        with pytest.raises(ValueError, match="At least one of the filters"):
            await self._run_before_run(provider, ctx)

    async def test_v1_1_response_format(self, mock_mem0_client: AsyncMock) -> None:
        """Search response in v1.1 dict format with 'results' key."""
        mock_mem0_client.search.return_value = {"results": [{"memory": "remembered fact"}]}
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1")

        await self._run_before_run(provider, ctx)

        added = ctx.context_messages["mem0"]
        assert "remembered fact" in added[0].text  # type: ignore[operator]

    async def test_search_query_combines_input_messages(self, mock_mem0_client: AsyncMock) -> None:
        """Multiple input messages are newline-joined into the search query."""
        mock_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        ctx = SessionContext(
            input_messages=[
                Message(role="user", contents=["Hello"]),
                Message(role="user", contents=["World"]),
            ],
            session_id="s1",
        )

        await self._run_before_run(provider, ctx)

        call_kwargs = mock_mem0_client.search.call_args.kwargs
        assert call_kwargs["query"] == "Hello\nWorld"

    async def test_oss_client_passes_direct_kwargs(self, mock_oss_mem0_client: AsyncMock) -> None:
        """OSS AsyncMemory client should receive user_id as direct kwarg, not in filters."""
        mock_oss_mem0_client.search.return_value = [{"memory": "User likes Python"}]
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1")

        await self._run_before_run(provider, ctx)

        call_kwargs = mock_oss_mem0_client.search.call_args.kwargs
        assert call_kwargs["query"] == "Hello"
        assert call_kwargs["user_id"] == "u1"
        assert "filters" not in call_kwargs

    async def test_oss_client_all_scoping_params_except_app_id(self, mock_oss_mem0_client: AsyncMock) -> None:
        """OSS client issues one search per scoping id, each passed as a direct kwarg."""
        mock_oss_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1")

        await provider.before_run(
            agent=MagicMock(), session=MagicMock(spec=AgentSession), context=self._mock_context("hello"), state={}
        )

        # user_id and agent_id are searched separately rather than combined into a single call.
        assert mock_oss_mem0_client.search.call_count == 2
        mock_oss_mem0_client.search.assert_any_call(query="hello", user_id="u1")
        mock_oss_mem0_client.search.assert_any_call(query="hello", agent_id="a1")

    async def test_platform_client_passes_filters_dict_except_app_id(self, mock_mem0_client: AsyncMock) -> None:
        """Platform client issues one search per scoping id, each nested in its own filters dict."""
        mock_mem0_client.search.return_value = []
        provider = Mem0ContextProvider(
            source_id="mem0",
            mem0_client=mock_mem0_client,
            user_id="u1",
            agent_id="a1",
        )

        await provider.before_run(
            agent=MagicMock(), session=MagicMock(spec=AgentSession), context=self._mock_context("hello"), state={}
        )

        # Each call carries only one scoping id in its filters; they are never combined.
        assert mock_mem0_client.search.call_count == 2
        mock_mem0_client.search.assert_any_call(query="hello", filters={"user_id": "u1"})
        mock_mem0_client.search.assert_any_call(query="hello", filters={"agent_id": "a1"})
|
|
|
|
|
|
# -- after_run tests -----------------------------------------------------------
|
|
|
|
|
|
class TestAfterRun:
    """Test after_run hook."""

    @staticmethod
    async def _run_after_run(provider: Mem0ContextProvider, ctx: SessionContext) -> None:
        """Call ``provider.after_run`` with a fresh session and the provider's per-source state slot.

        ``agent=None`` is not a real agent, so the type-ignore sits on that argument's own line:
        type checkers report the argument diagnostic there, not on the closing parenthesis.
        """
        session = AgentSession(session_id="test-session")
        await provider.after_run(
            agent=None,  # type: ignore[arg-type]
            session=session,
            context=ctx,
            state=session.state.setdefault(provider.source_id, {}),
        )

    async def test_stores_input_and_response(self, mock_mem0_client: AsyncMock) -> None:
        """Stores input+response messages to mem0 via client.add."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["question"])], session_id="s1")
        ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["answer"])])

        await self._run_after_run(provider, ctx)

        mock_mem0_client.add.assert_awaited_once()
        call_kwargs = mock_mem0_client.add.call_args.kwargs
        assert call_kwargs["messages"] == [
            {"role": "user", "content": "question"},
            {"role": "assistant", "content": "answer"},
        ]
        assert call_kwargs["user_id"] == "u1"
        assert "run_id" not in call_kwargs

    async def test_only_stores_user_assistant_system(self, mock_mem0_client: AsyncMock) -> None:
        """Only stores user/assistant/system messages with text; tool messages are dropped."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        ctx = SessionContext(
            input_messages=[
                Message(role="user", contents=["hello"]),
                Message(role="tool", contents=["tool output"]),
            ],
            session_id="s1",
        )
        ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["reply"])])

        await self._run_after_run(provider, ctx)

        call_kwargs = mock_mem0_client.add.call_args.kwargs
        roles = [m["role"] for m in call_kwargs["messages"]]
        assert "tool" not in roles
        assert roles == ["user", "assistant"]

    async def test_skips_empty_messages(self, mock_mem0_client: AsyncMock) -> None:
        """Messages with empty or whitespace-only text are skipped, so nothing is stored."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        ctx = SessionContext(
            input_messages=[
                Message(role="user", contents=[""]),
                Message(role="user", contents=["   "]),
            ],
            session_id="s1",
        )
        ctx._response = AgentResponse(messages=[])

        await self._run_after_run(provider, ctx)

        mock_mem0_client.add.assert_not_awaited()

    async def test_no_run_id_in_storage(self, mock_mem0_client: AsyncMock) -> None:
        """run_id is not passed to mem0 add, so memories are not scoped to sessions."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="my-session")
        ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["hey"])])

        await self._run_after_run(provider, ctx)

        assert "run_id" not in mock_mem0_client.add.call_args.kwargs

    async def test_validates_filters(self, mock_mem0_client: AsyncMock) -> None:
        """Raises ValueError when no user/agent/application id is configured."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
        ctx = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="s1")
        ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["hey"])])

        with pytest.raises(ValueError, match="At least one of the filters"):
            await self._run_after_run(provider, ctx)

    async def test_stores_with_application_id_filters(self, mock_mem0_client: AsyncMock) -> None:
        """application_id is passed to client.add inside filters as ``app_id``."""
        provider = Mem0ContextProvider(
            source_id="mem0", mem0_client=mock_mem0_client, user_id="u1", application_id="app1"
        )
        ctx = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="s1")
        ctx._response = AgentResponse(messages=[])

        await self._run_after_run(provider, ctx)

        assert mock_mem0_client.add.call_args.kwargs["filters"] == {"app_id": "app1"}
|
|
|
|
|
|
# -- _validate_filters tests --------------------------------------------------
|
|
|
|
|
|
class TestValidateFilters:
    """Exercise Mem0ContextProvider._validate_filters directly."""

    def test_raises_when_no_filters(self, mock_mem0_client: AsyncMock) -> None:
        """No user_id, agent_id or application_id configured → ValueError."""
        unscoped = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
        with pytest.raises(ValueError, match="At least one of the filters"):
            unscoped._validate_filters()

    def test_passes_with_user_id(self, mock_mem0_client: AsyncMock) -> None:
        """user_id on its own is enough scoping; validation must not raise."""
        Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")._validate_filters()

    def test_passes_with_agent_id(self, mock_mem0_client: AsyncMock) -> None:
        """agent_id on its own is enough scoping; validation must not raise."""
        Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, agent_id="a1")._validate_filters()

    def test_passes_with_application_id(self, mock_mem0_client: AsyncMock) -> None:
        """application_id on its own is enough scoping; validation must not raise."""
        Mem0ContextProvider(
            source_id="mem0", mem0_client=mock_mem0_client, application_id="app1"
        )._validate_filters()
|
|
|
|
|
|
# -- _build_search_kwargs tests -----------------------------------------------------
|
|
|
|
|
|
class TestBuildSearchKwargs:
    """Test _build_search_kwargs method.

    ``mock_mem0_client`` is spec'd on the platform ``AsyncMemoryClient``, which the provider treats
    as the platform client, so every expected payload here nests scoping ids under ``filters``.
    """

    def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
        """A single targeted scope is nested under ``filters`` next to the query."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

        # Arguments: query text, the scoping key being searched, and that key's value.
        result = provider._build_search_kwargs("test query", "user_id", "u1")

        assert result == {"query": "test query", "filters": {"user_id": "u1"}}

    def test_all_params(self, mock_mem0_client: AsyncMock) -> None:
        """application_id is merged (as ``app_id``) into the filters of the targeted scope only."""
        provider = Mem0ContextProvider(
            source_id="mem0",
            mem0_client=mock_mem0_client,
            user_id="u1",
            agent_id="a1",
            application_id="app1",
        )

        # Targeting agent_id: user_id must not leak in, app_id is added alongside.
        result = provider._build_search_kwargs("test query", "agent_id", "a1")

        assert result == {
            "query": "test query",
            "filters": {
                "agent_id": "a1",
                "app_id": "app1",
            },
        }

    def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
        """An unset application_id (None) does not appear as ``app_id`` in the filters."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

        result = provider._build_search_kwargs("test query", "user_id", "u1")

        assert "app_id" not in result.get("filters", {})

    def test_no_run_id_in_search_filters(self, mock_mem0_client: AsyncMock) -> None:
        """run_id is excluded from search filters so memories work across sessions."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

        result = provider._build_search_kwargs("test query", "user_id", "u1")

        assert "run_id" not in result.get("filters", {})
        assert "run_id" not in result

    def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None:
        """With no scoping ids configured, only the explicitly requested key/value appears in filters.

        The builder does not validate the key, so an arbitrary ``custom_key`` passes straight through.
        """
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)

        result = provider._build_search_kwargs("test query", "custom_key", "custom_val")

        assert result == {"query": "test query", "filters": {"custom_key": "custom_val"}}

    async def test_before_run_application_only_fallback(self, mock_mem0_client: AsyncMock) -> None:
        """before_run with only application_id set performs exactly one search and injects its results.

        NOTE(review): this exercises ``before_run`` end-to-end rather than ``_build_search_kwargs``
        directly; it would sit more naturally in ``TestBeforeRun``.
        """
        provider = Mem0ContextProvider(
            source_id="mem0", mem0_client=mock_mem0_client, application_id="app_fallback_test"
        )

        # A SessionContext mock with a single non-empty input message.
        mock_context = MagicMock(spec=SessionContext)
        mock_msg = MagicMock()
        mock_msg.text = "Retrieve systemic fallback memory traces"
        mock_context.input_messages = [mock_msg]
        mock_context.response = None

        mock_mem0_client.search = AsyncMock(return_value=[{"id": "m1", "memory": "System configuration template"}])

        await provider.before_run(
            agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
        )

        # Only the application scope is configured, so exactly one search runs and its hits are added.
        assert mock_mem0_client.search.call_count == 1
        mock_context.extend_messages.assert_called_once()
|
|
|
|
|
|
# -- Context manager tests -----------------------------------------------------
|
|
|
|
|
|
class TestContextManager:
    """Test __aenter__/__aexit__ delegation."""

    async def test_aenter_delegates_to_client(self, mock_mem0_client: AsyncMock) -> None:
        """Entering the provider returns the provider itself and enters the underlying client once."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        result = await provider.__aenter__()
        assert result is provider
        mock_mem0_client.__aenter__.assert_awaited_once()

    async def test_aexit_closes_auto_created_client(self, mock_mem0_client: AsyncMock) -> None:
        """Auto-created clients (_should_close_client=True) are closed on exit."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        # Simulate an auto-created client without going through the real constructor path.
        provider._should_close_client = True
        await provider.__aexit__(None, None, None)
        mock_mem0_client.__aexit__.assert_awaited_once()

    async def test_aexit_does_not_close_provided_client(self, mock_mem0_client: AsyncMock) -> None:
        """Provided clients (_should_close_client=False) are NOT closed on exit."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        assert provider._should_close_client is False
        await provider.__aexit__(None, None, None)
        mock_mem0_client.__aexit__.assert_not_awaited()

    async def test_async_with_syntax(self, mock_mem0_client: AsyncMock) -> None:
        """``async with`` yields the provider, enters the client, and leaves a caller-provided client open."""
        provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
        async with provider as p:
            assert p is provider
        mock_mem0_client.__aenter__.assert_awaited_once()
        # The client was supplied by the caller (_should_close_client is False), so exit must not close it.
        mock_mem0_client.__aexit__.assert_not_awaited()
|