Python: [BREAKING] cleanup of thread API and serialization (#893)

* cleanup of threads and serialization

* fix for sliding window

* fix redis test

* updated from comments

* updated context provider and threads

* updated lock

* add asyncio default

* fix redis tests

* fix tests

* fix tests

* renamed to invoking

* fixed tests

* fix for instructions
This commit is contained in:
Eduard van Valkenburg
2025-09-29 18:22:34 +02:00
committed by GitHub
Unverified
parent bf5931932e
commit 10d10364a9
52 changed files with 1642 additions and 1411 deletions
@@ -3,7 +3,7 @@
import json
import os
import sys
from collections.abc import AsyncIterable, MutableMapping, MutableSequence
from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence
from typing import Any, ClassVar, TypeVar
from agent_framework import (
@@ -269,7 +269,8 @@ class AzureAIAgentClient(BaseChatClient):
**kwargs: Any,
) -> ChatResponse:
return await ChatResponse.from_chat_response_generator(
updates=self._inner_get_streaming_response(messages=messages, chat_options=chat_options, **kwargs)
updates=self._inner_get_streaming_response(messages=messages, chat_options=chat_options, **kwargs),
output_format_type=chat_options.response_format,
)
async def _inner_get_streaming_response(
@@ -660,7 +661,7 @@ class AzureAIAgentClient(BaseChatClient):
)
)
instructions: list[str] = []
instructions: list[str] = [chat_options.instructions] if chat_options and chat_options.instructions else []
required_action_results: list[FunctionResultContent | FunctionApprovalResponseContent] | None = None
additional_messages: list[ThreadMessageOptions] | None = None
@@ -708,7 +709,7 @@ class AzureAIAgentClient(BaseChatClient):
return run_options, required_action_results
async def _prep_tools(
self, tools: list["ToolProtocol | MutableMapping[str, Any]"]
self, tools: Sequence["ToolProtocol | MutableMapping[str, Any]"]
) -> list[ToolDefinition | dict[str, Any]]:
"""Prepare tool definitions for the run options."""
tool_definitions: list[ToolDefinition | dict[str, Any]] = []
@@ -4,11 +4,14 @@ from collections.abc import AsyncIterable
from typing import Any, ClassVar
from agent_framework import (
AgentMiddlewares,
AgentRunResponse,
AgentRunResponseUpdate,
AgentThread,
AggregateContextProvider,
BaseAgent,
ChatMessage,
ContextProvider,
Role,
TextContent,
)
@@ -53,20 +56,16 @@ class CopilotStudioSettings(AFBaseSettings):
class CopilotStudioAgent(BaseAgent):
"""A Copilot Studio Agent."""
client: CopilotClient
settings: ConnectionSettings | None
token: str | None
cloud: PowerPlatformCloud | None
agent_type: AgentType | None
custom_power_platform_cloud: str | None
username: str | None
token_cache: Any | None
scopes: list[str] | None
def __init__(
self,
client: CopilotClient | None = None,
settings: ConnectionSettings | None = None,
*,
id: str | None = None,
name: str | None = None,
description: str | None = None,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
middleware: AgentMiddlewares | list[AgentMiddlewares] | None = None,
environment_id: str | None = None,
agent_identifier: str | None = None,
client_id: str | None = None,
@@ -88,6 +87,11 @@ class CopilotStudioAgent(BaseAgent):
a new client will be created using the other parameters.
settings: Optional pre-configured ConnectionSettings. If not provided,
settings will be created from the other parameters.
id: id of the CopilotAgent
name: Name of the CopilotAgent
description: Description of the CopilotAgent
context_providers: Context Providers, to be used by the copilot agent.
middleware: Agent middlewares used by the agent.
environment_id: Environment ID of the Power Platform environment containing
the Copilot Studio app. Can also be set via COPILOTSTUDIOAGENT__ENVIRONMENTID
environment variable.
@@ -113,6 +117,13 @@ class CopilotStudioAgent(BaseAgent):
Raises:
ServiceInitializationError: If required configuration is missing or invalid.
"""
super().__init__(
id=id,
name=name,
description=description,
context_providers=context_providers,
middleware=middleware,
)
if not client:
try:
copilot_studio_settings = CopilotStudioSettings(
@@ -169,17 +180,13 @@ class CopilotStudioAgent(BaseAgent):
client = CopilotClient(settings=settings, token=token)
super().__init__(
client=client, # type: ignore[reportCallIssue]
settings=settings, # type: ignore[reportCallIssue]
token=token, # type: ignore[reportCallIssue]
cloud=cloud, # type: ignore[reportCallIssue]
agent_type=agent_type, # type: ignore[reportCallIssue]
custom_power_platform_cloud=custom_power_platform_cloud, # type: ignore[reportCallIssue]
username=username, # type: ignore[reportCallIssue]
token_cache=token_cache, # type: ignore[reportCallIssue]
scopes=scopes, # type: ignore[reportCallIssue]
)
self.client = client
self.cloud = cloud
self.agent_type = agent_type
self.custom_power_platform_cloud = custom_power_platform_cloud
self.username = username
self.token_cache = token_cache
self.scopes = scopes
async def run(
self,
@@ -121,7 +121,6 @@ class TestCopilotStudioAgent:
with pytest.raises(ServiceInitializationError, match="agent identifier"):
CopilotStudioAgent()
@pytest.mark.asyncio
async def test_run_with_string_message(self, mock_copilot_client: MagicMock, mock_activity: MagicMock) -> None:
"""Test run method with string message."""
agent = CopilotStudioAgent(client=mock_copilot_client)
@@ -141,7 +140,6 @@ class TestCopilotStudioAgent:
assert content.text == "Test response"
assert response.messages[0].role == Role.ASSISTANT
@pytest.mark.asyncio
async def test_run_with_chat_message(self, mock_copilot_client: MagicMock, mock_activity: MagicMock) -> None:
"""Test run method with ChatMessage."""
agent = CopilotStudioAgent(client=mock_copilot_client)
@@ -162,7 +160,6 @@ class TestCopilotStudioAgent:
assert content.text == "Test response"
assert response.messages[0].role == Role.ASSISTANT
@pytest.mark.asyncio
async def test_run_with_thread(self, mock_copilot_client: MagicMock, mock_activity: MagicMock) -> None:
"""Test run method with existing thread."""
agent = CopilotStudioAgent(client=mock_copilot_client)
@@ -180,7 +177,6 @@ class TestCopilotStudioAgent:
assert len(response.messages) == 1
assert thread.service_thread_id == "test-conversation-id"
@pytest.mark.asyncio
async def test_run_start_conversation_failure(self, mock_copilot_client: MagicMock) -> None:
"""Test run method when conversation start fails."""
agent = CopilotStudioAgent(client=mock_copilot_client)
@@ -190,7 +186,6 @@ class TestCopilotStudioAgent:
with pytest.raises(ServiceException, match="Failed to start a new conversation"):
await agent.run("test message")
@pytest.mark.asyncio
async def test_run_stream_with_string_message(self, mock_copilot_client: MagicMock) -> None:
"""Test run_stream method with string message."""
agent = CopilotStudioAgent(client=mock_copilot_client)
@@ -217,7 +212,6 @@ class TestCopilotStudioAgent:
assert response_count == 1
@pytest.mark.asyncio
async def test_run_stream_with_thread(self, mock_copilot_client: MagicMock) -> None:
"""Test run_stream method with existing thread."""
agent = CopilotStudioAgent(client=mock_copilot_client)
@@ -246,7 +240,6 @@ class TestCopilotStudioAgent:
assert response_count == 1
assert thread.service_thread_id == "test-conversation-id"
@pytest.mark.asyncio
async def test_run_stream_no_typing_activity(self, mock_copilot_client: MagicMock) -> None:
"""Test run_stream method with non-typing activity."""
agent = CopilotStudioAgent(client=mock_copilot_client)
@@ -268,7 +261,6 @@ class TestCopilotStudioAgent:
assert response_count == 0
@pytest.mark.asyncio
async def test_run_multiple_activities(self, mock_copilot_client: MagicMock) -> None:
"""Test run method with multiple message activities."""
agent = CopilotStudioAgent(client=mock_copilot_client)
@@ -296,7 +288,6 @@ class TestCopilotStudioAgent:
assert isinstance(response, AgentRunResponse)
assert len(response.messages) == 2
@pytest.mark.asyncio
async def test_run_list_of_messages(self, mock_copilot_client: MagicMock, mock_activity: MagicMock) -> None:
"""Test run method with list of messages."""
agent = CopilotStudioAgent(client=mock_copilot_client)
@@ -313,7 +304,6 @@ class TestCopilotStudioAgent:
assert isinstance(response, AgentRunResponse)
assert len(response.messages) == 1
@pytest.mark.asyncio
async def test_run_stream_start_conversation_failure(self, mock_copilot_client: MagicMock) -> None:
"""Test run_stream method when conversation start fails."""
agent = CopilotStudioAgent(client=mock_copilot_client)
@@ -227,17 +227,9 @@ class AgentFrameworkExecutor:
async def deserialize_thread(self, thread_id: str, agent_id: str, serialized_state: dict[str, Any]) -> bool:
"""Deserialize thread state from persistence."""
try:
# Create new thread
thread = AgentThread()
# Use AgentThread's built-in deserialization
from agent_framework._threads import deserialize_thread_state
await deserialize_thread_state(thread, serialized_state)
thread = await AgentThread.deserialize(serialized_state)
# Store the restored thread
self.thread_storage[thread_id] = thread
if agent_id not in self.agent_threads:
self.agent_threads[agent_id] = []
self.agent_threads[agent_id].append(thread_id)
@@ -20,7 +20,7 @@ def test_entities_dir():
return str(samples_dir.resolve())
@pytest.mark.asyncio
@pytest.mark.skip("Skipping while we fix discovery")
async def test_discover_agents(test_entities_dir):
"""Test that agent discovery works and returns valid agent entities."""
discovery = EntityDiscovery(test_entities_dir)
@@ -39,7 +39,6 @@ async def test_discover_agents(test_entities_dir):
assert hasattr(agent, "description"), "Agent should have description attribute"
@pytest.mark.asyncio
async def test_discover_workflows(test_entities_dir):
"""Test that workflow discovery works and returns valid workflow entities."""
discovery = EntityDiscovery(test_entities_dir)
@@ -58,7 +57,6 @@ async def test_discover_workflows(test_entities_dir):
assert hasattr(workflow, "description"), "Workflow should have description attribute"
@pytest.mark.asyncio
async def test_empty_directory():
"""Test discovery with empty directory."""
with tempfile.TemporaryDirectory() as temp_dir:
@@ -36,7 +36,6 @@ async def executor(test_entities_dir):
return executor
@pytest.mark.asyncio
async def test_executor_entity_discovery(executor):
"""Test executor entity discovery."""
entities = await executor.discover_entities()
@@ -55,7 +54,6 @@ async def test_executor_entity_discovery(executor):
assert entity.type in ["agent", "workflow"], "Entity should have valid type"
@pytest.mark.asyncio
async def test_executor_get_entity_info(executor):
"""Test getting entity info by ID."""
entities = await executor.discover_entities()
@@ -68,7 +66,6 @@ async def test_executor_get_entity_info(executor):
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="requires OpenAI API key")
@pytest.mark.asyncio
async def test_executor_sync_execution(executor):
"""Test synchronous execution."""
entities = await executor.discover_entities()
@@ -90,7 +87,7 @@ async def test_executor_sync_execution(executor):
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="requires OpenAI API key")
@pytest.mark.asyncio
@pytest.mark.skip("Skipping while we fix discovery")
async def test_executor_streaming_execution(executor):
"""Test streaming execution."""
entities = await executor.discover_entities()
@@ -121,14 +118,12 @@ async def test_executor_streaming_execution(executor):
assert len(text_events) > 0
@pytest.mark.asyncio
async def test_executor_invalid_entity_id(executor):
"""Test execution with invalid entity ID."""
with pytest.raises(EntityNotFoundError):
executor.get_entity_info("nonexistent_agent")
@pytest.mark.asyncio
async def test_executor_missing_entity_id(executor):
"""Test execution without entity ID."""
request = AgentFrameworkRequest(
@@ -56,7 +56,6 @@ def test_request() -> AgentFrameworkRequest:
)
@pytest.mark.asyncio
async def test_critical_isinstance_bug_detection(mapper: MessageMapper, test_request: AgentFrameworkRequest) -> None:
"""CRITICAL: Test that would have caught the isinstance vs hasattr bug."""
@@ -79,7 +78,6 @@ async def test_critical_isinstance_bug_detection(mapper: MessageMapper, test_req
assert all(event.type != "unknown" for event in events)
@pytest.mark.asyncio
async def test_text_content_mapping(mapper: MessageMapper, test_request: AgentFrameworkRequest) -> None:
"""Test TextContent mapping."""
content = create_test_content("text", text="Hello, clean test!")
@@ -92,7 +90,6 @@ async def test_text_content_mapping(mapper: MessageMapper, test_request: AgentFr
assert events[0].delta == "Hello, clean test!"
@pytest.mark.asyncio
async def test_function_call_mapping(mapper: MessageMapper, test_request: AgentFrameworkRequest) -> None:
"""Test FunctionCallContent mapping."""
content = create_test_content("function_call", name="test_func", arguments={"location": "TestCity"})
@@ -108,7 +105,6 @@ async def test_function_call_mapping(mapper: MessageMapper, test_request: AgentF
assert "TestCity" in full_json
@pytest.mark.asyncio
async def test_error_content_mapping(mapper: MessageMapper, test_request: AgentFrameworkRequest) -> None:
"""Test ErrorContent mapping."""
content = create_test_content("error", message="Test error", code="test_code")
@@ -122,7 +118,6 @@ async def test_error_content_mapping(mapper: MessageMapper, test_request: AgentF
assert events[0].code == "test_code"
@pytest.mark.asyncio
async def test_mixed_content_types(mapper: MessageMapper, test_request: AgentFrameworkRequest) -> None:
"""Test multiple content types together."""
contents = [
@@ -142,7 +137,6 @@ async def test_mixed_content_types(mapper: MessageMapper, test_request: AgentFra
assert "response.function_call_arguments.delta" in event_types
@pytest.mark.asyncio
async def test_unknown_content_fallback(mapper: MessageMapper, test_request: AgentFrameworkRequest) -> None:
"""Test graceful handling of unknown content types."""
# Test the fallback path directly since we can't create invalid AgentRunResponseUpdate
+1 -4
View File
@@ -20,7 +20,6 @@ def test_entities_dir():
return str(samples_dir.resolve())
@pytest.mark.asyncio
async def test_server_health_endpoint(test_entities_dir):
"""Test /health endpoint."""
server = DevServer(entities_dir=test_entities_dir)
@@ -32,7 +31,7 @@ async def test_server_health_endpoint(test_entities_dir):
# Framework name is now hardcoded since we simplified to single framework
@pytest.mark.asyncio
@pytest.mark.skip("Skipping while we fix discovery")
async def test_server_entities_endpoint(test_entities_dir):
"""Test /v1/entities endpoint."""
server = DevServer(entities_dir=test_entities_dir)
@@ -47,7 +46,6 @@ async def test_server_entities_endpoint(test_entities_dir):
assert "WeatherAgent" in agent_names
@pytest.mark.asyncio
async def test_server_execution_sync(test_entities_dir):
"""Test sync execution endpoint."""
server = DevServer(entities_dir=test_entities_dir)
@@ -68,7 +66,6 @@ async def test_server_execution_sync(test_entities_dir):
assert len(response.output) > 0
@pytest.mark.asyncio
async def test_server_execution_streaming(test_entities_dir):
"""Test streaming execution endpoint."""
server = DevServer(entities_dir=test_entities_dir)
+3 -1
View File
@@ -130,7 +130,9 @@ test-tau2 = "pytest tau2/tests --cov=agent_framework_lab_tau2 --cov-report=term-
[tool.pytest.ini_options]
pythonpath = ["."]
addopts = "--strict-markers --strict-config"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
markers = [
"unit: marks tests as unit tests",
"integration: marks tests as integration tests",
]
]
@@ -5,13 +5,12 @@ from collections.abc import Sequence
from typing import Any
import tiktoken
from agent_framework._threads import ChatMessageList
from agent_framework._types import ChatMessage, Role
from agent_framework import ChatMessage, ChatMessageStore, Role
from loguru import logger
class SlidingWindowChatMessageList(ChatMessageList):
"""A token-aware sliding window implementation of ChatMessageList.
class SlidingWindowChatMessageStore(ChatMessageStore):
"""A token-aware sliding window implementation of ChatMessageStore.
Maintains two message lists: complete history and truncated window.
Automatically removes oldest messages when token limit is exceeded.
@@ -25,8 +24,8 @@ class SlidingWindowChatMessageList(ChatMessageList):
system_message: str | None = None,
tool_definitions: Any | None = None,
):
super().__init__(messages)
self._truncated_messages = self._messages.copy() # Separate truncated view
super().__init__(messages=messages)
self.truncated_messages = self.messages.copy()
self.max_tokens = max_tokens
self.system_message = system_message # Included in token count
self.tool_definitions = tool_definitions
@@ -36,25 +35,25 @@ class SlidingWindowChatMessageList(ChatMessageList):
async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
await super().add_messages(messages)
self._truncated_messages = self._messages.copy()
self.truncated_messages = self.messages.copy()
self.truncate_messages()
async def list_messages(self) -> list[ChatMessage]:
"""Get the current list of messages, which may be truncated."""
return self._truncated_messages
return self.truncated_messages
async def list_all_messages(self) -> list[ChatMessage]:
"""Get all messages from the store including the truncated ones."""
return self._messages
return self.messages
def truncate_messages(self) -> None:
while len(self._truncated_messages) > 0 and self.get_token_count() > self.max_tokens:
while len(self.truncated_messages) > 0 and self.get_token_count() > self.max_tokens:
logger.warning("Messages exceed max tokens. Truncating oldest message.")
self._truncated_messages.pop(0)
self.truncated_messages.pop(0)
# Remove leading tool messages
while len(self._truncated_messages) > 0 and self._truncated_messages[0].role == Role.TOOL:
while len(self.truncated_messages) > 0 and self.truncated_messages[0].role == Role.TOOL:
logger.warning("Removing leading tool message because tool result cannot be the first message.")
self._truncated_messages.pop(0)
self.truncated_messages.pop(0)
def get_token_count(self) -> int:
"""Estimate token count for a list of messages using tiktoken.
@@ -72,7 +71,7 @@ class SlidingWindowChatMessageList(ChatMessageList):
total_tokens += len(self.encoding.encode(self.system_message))
total_tokens += 4 # Extra tokens for system message formatting
for msg in self._truncated_messages:
for msg in self.truncated_messages:
# Add 4 tokens per message for role, formatting, etc.
total_tokens += 4
@@ -29,7 +29,7 @@ from tau2.user.user_simulator import ( # type: ignore[import-untyped]
from tau2.utils.utils import get_now # type: ignore[import-untyped]
from ._message_utils import flip_messages, log_messages
from ._sliding_window import SlidingWindowChatMessageList
from ._sliding_window import SlidingWindowChatMessageStore
from ._tau2_utils import convert_agent_framework_messages_to_tau2_messages, convert_tau2_tool_to_ai_function
# Agent instructions matching tau2's LLMAgent
@@ -196,7 +196,7 @@ class TaskRunner:
instructions=assistant_system_prompt,
tools=ai_functions, # type: ignore
temperature=self.assistant_sampling_temperature,
chat_message_store_factory=lambda: SlidingWindowChatMessageList(
chat_message_store_factory=lambda: SlidingWindowChatMessageStore(
system_message=assistant_system_prompt,
tool_definitions=[tool.openai_schema for tool in tools],
max_tokens=self.assistant_window_size,
@@ -352,7 +352,7 @@ class TaskRunner:
# 2. The assistant's message store (not just the truncated window)
# 3. The final user message (if any)
assistant_executor = cast(AgentExecutor, self._assistant_executor)
message_store = cast(SlidingWindowChatMessageList, assistant_executor._agent_thread.message_store)
message_store = cast(SlidingWindowChatMessageStore, assistant_executor._agent_thread.message_store)
full_conversation = [first_message] + await message_store.list_all_messages()
if self._final_user_message is not None:
full_conversation.extend(self._final_user_message)
@@ -4,20 +4,19 @@
from unittest.mock import patch
import pytest
from agent_framework._types import ChatMessage, FunctionCallContent, FunctionResultContent, Role, TextContent
from agent_framework_lab_tau2._sliding_window import SlidingWindowChatMessageList
from agent_framework_lab_tau2._sliding_window import SlidingWindowChatMessageStore
def test_initialization_empty():
"""Test initializing with no messages."""
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
assert sliding_window.max_tokens == 1000
assert sliding_window.system_message is None
assert sliding_window.tool_definitions is None
assert len(sliding_window._messages) == 0
assert len(sliding_window._truncated_messages) == 0
assert len(sliding_window.messages) == 0
assert len(sliding_window.truncated_messages) == 0
def test_initialization_with_parameters():
@@ -25,7 +24,7 @@ def test_initialization_with_parameters():
system_msg = "You are a helpful assistant"
tool_defs = [{"name": "test_tool", "description": "A test tool"}]
sliding_window = SlidingWindowChatMessageList(
sliding_window = SlidingWindowChatMessageStore(
max_tokens=2000, system_message=system_msg, tool_definitions=tool_defs
)
@@ -41,16 +40,15 @@ def test_initialization_with_messages():
ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hi there!")]),
]
sliding_window = SlidingWindowChatMessageList(messages=messages, max_tokens=1000)
sliding_window = SlidingWindowChatMessageStore(messages=messages, max_tokens=1000)
assert len(sliding_window._messages) == 2
assert len(sliding_window._truncated_messages) == 2
assert len(sliding_window.messages) == 2
assert len(sliding_window.truncated_messages) == 2
@pytest.mark.asyncio
async def test_add_messages_simple():
"""Test adding messages without truncation."""
sliding_window = SlidingWindowChatMessageList(max_tokens=10000) # Large limit
sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit
new_messages = [
ChatMessage(role=Role.USER, contents=[TextContent(text="What's the weather?")]),
@@ -65,10 +63,9 @@ async def test_add_messages_simple():
assert messages[1].text == "I can help with that."
@pytest.mark.asyncio
async def test_list_all_messages_vs_list_messages():
"""Test difference between list_all_messages and list_messages."""
sliding_window = SlidingWindowChatMessageList(max_tokens=50) # Small limit to force truncation
sliding_window = SlidingWindowChatMessageStore(max_tokens=50) # Small limit to force truncation
# Add many messages to trigger truncation
messages = [
@@ -89,8 +86,8 @@ async def test_list_all_messages_vs_list_messages():
def test_get_token_count_basic():
"""Test basic token counting."""
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
sliding_window._truncated_messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
sliding_window.truncated_messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
token_count = sliding_window.get_token_count()
@@ -101,13 +98,13 @@ def test_get_token_count_basic():
def test_get_token_count_with_system_message():
"""Test token counting includes system message."""
system_msg = "You are a helpful assistant"
sliding_window = SlidingWindowChatMessageList(max_tokens=1000, system_message=system_msg)
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000, system_message=system_msg)
# Without messages
token_count_empty = sliding_window.get_token_count()
# Add a message
sliding_window._truncated_messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
sliding_window.truncated_messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
token_count_with_message = sliding_window.get_token_count()
# With message should be more tokens
@@ -119,8 +116,8 @@ def test_get_token_count_function_call():
"""Test token counting with function calls."""
function_call = FunctionCallContent(call_id="call_123", name="test_function", arguments={"param": "value"})
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
sliding_window._truncated_messages = [ChatMessage(role=Role.ASSISTANT, contents=[function_call])]
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
sliding_window.truncated_messages = [ChatMessage(role=Role.ASSISTANT, contents=[function_call])]
token_count = sliding_window.get_token_count()
assert token_count > 0
@@ -130,8 +127,8 @@ def test_get_token_count_function_result():
"""Test token counting with function results."""
function_result = FunctionResultContent(call_id="call_123", result={"success": True, "data": "result"})
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
sliding_window._truncated_messages = [ChatMessage(role=Role.TOOL, contents=[function_result])]
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
sliding_window.truncated_messages = [ChatMessage(role=Role.TOOL, contents=[function_result])]
token_count = sliding_window.get_token_count()
assert token_count > 0
@@ -140,7 +137,7 @@ def test_get_token_count_function_result():
@patch("agent_framework_lab_tau2._sliding_window.logger")
def test_truncate_messages_removes_old_messages(mock_logger):
"""Test that truncation removes old messages when token limit exceeded."""
sliding_window = SlidingWindowChatMessageList(max_tokens=20) # Very small limit
sliding_window = SlidingWindowChatMessageStore(max_tokens=20) # Very small limit
# Create messages that will exceed the limit
messages = [
@@ -155,11 +152,11 @@ def test_truncate_messages_removes_old_messages(mock_logger):
ChatMessage(role=Role.USER, contents=[TextContent(text="Short msg")]),
]
sliding_window._truncated_messages = messages.copy()
sliding_window.truncated_messages = messages.copy()
sliding_window.truncate_messages()
# Should have fewer messages after truncation
assert len(sliding_window._truncated_messages) < len(messages)
assert len(sliding_window.truncated_messages) < len(messages)
# Should have logged warnings
assert mock_logger.warning.called
@@ -168,18 +165,18 @@ def test_truncate_messages_removes_old_messages(mock_logger):
@patch("agent_framework_lab_tau2._sliding_window.logger")
def test_truncate_messages_removes_leading_tool_messages(mock_logger):
"""Test that truncation removes leading tool messages."""
sliding_window = SlidingWindowChatMessageList(max_tokens=10000) # Large limit
sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit
# Create messages starting with tool message
tool_message = ChatMessage(role=Role.TOOL, contents=[FunctionResultContent(call_id="call_123", result="result")])
user_message = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])
sliding_window._truncated_messages = [tool_message, user_message]
sliding_window.truncated_messages = [tool_message, user_message]
sliding_window.truncate_messages()
# Tool message should be removed from the beginning
assert len(sliding_window._truncated_messages) == 1
assert sliding_window._truncated_messages[0].role == Role.USER
assert len(sliding_window.truncated_messages) == 1
assert sliding_window.truncated_messages[0].role == Role.USER
# Should have logged warning about removing tool message
mock_logger.warning.assert_called()
@@ -187,7 +184,7 @@ def test_truncate_messages_removes_leading_tool_messages(mock_logger):
def test_estimate_any_object_token_count_dict():
"""Test token counting for dictionary objects."""
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
test_dict = {"key": "value", "number": 42}
token_count = sliding_window.estimate_any_object_token_count(test_dict)
@@ -197,7 +194,7 @@ def test_estimate_any_object_token_count_dict():
def test_estimate_any_object_token_count_string():
"""Test token counting for string objects."""
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
test_string = "This is a test string"
token_count = sliding_window.estimate_any_object_token_count(test_string)
@@ -207,7 +204,7 @@ def test_estimate_any_object_token_count_string():
def test_estimate_any_object_token_count_non_serializable():
"""Test token counting for non-JSON-serializable objects."""
sliding_window = SlidingWindowChatMessageList(max_tokens=1000)
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
# Create an object that can't be JSON serialized
class CustomObject:
@@ -221,10 +218,9 @@ def test_estimate_any_object_token_count_non_serializable():
assert token_count > 0
@pytest.mark.asyncio
async def test_real_world_scenario():
"""Test a realistic conversation scenario."""
sliding_window = SlidingWindowChatMessageList(
sliding_window = SlidingWindowChatMessageStore(
max_tokens=30,
system_message="You are a helpful assistant", # Moderate limit
)
+227 -147
View File
@@ -4,18 +4,19 @@ import inspect
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from typing import Any, ClassVar, Literal, Protocol, TypeVar, runtime_checkable
from copy import copy
from itertools import chain
from typing import Any, ClassVar, Literal, Protocol, TypeVar, cast, runtime_checkable
from uuid import uuid4
from pydantic import BaseModel, Field, PrivateAttr, create_model
from pydantic import BaseModel, Field, create_model
from ._clients import BaseChatClient, ChatClientProtocol
from ._logging import get_logger
from ._mcp import MCPTool
from ._memory import AggregateContextProvider, Context, ContextProvider
from ._middleware import Middleware, use_agent_middleware
from ._pydantic import AFBaseModel
from ._threads import AgentThread, ChatMessageStore, deserialize_thread_state, thread_on_new_messages
from ._threads import AgentThread, ChatMessageStoreProtocol
from ._tools import FUNCTION_INVOKING_CHAT_CLIENT_MARKER, AIFunction, ToolProtocol
from ._types import (
AgentRunResponse,
@@ -30,6 +31,10 @@ from ._types import (
from .exceptions import AgentExecutionException
from .observability import use_agent_observability
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
else:
@@ -122,7 +127,7 @@ class AgentProtocol(Protocol):
"""
...
def get_new_thread(self) -> AgentThread:
def get_new_thread(self, **kwargs: Any) -> AgentThread:
"""Creates a new conversation thread for the agent."""
...
@@ -130,7 +135,7 @@ class AgentProtocol(Protocol):
# region BaseAgent
class BaseAgent(AFBaseModel):
class BaseAgent:
"""Base class for all Agent Framework agents.
Attributes:
@@ -143,18 +148,55 @@ class BaseAgent(AFBaseModel):
middleware: List of middleware to intercept agent and function invocations.
"""
id: str = Field(default_factory=lambda: str(uuid4()))
name: str | None = None
description: str | None = None
context_providers: AggregateContextProvider | None = None
middleware: Middleware | list[Middleware] | None = None
def __init__(
self,
id: str | None = None,
name: str | None = None,
description: str | None = None,
context_providers: ContextProvider | Sequence[ContextProvider] | None = None,
middleware: Middleware | Sequence[Middleware] | None = None,
**kwargs: Any,
) -> None:
"""Base class for all Agent Framework agents.
Args:
id: The unique identifier of the agent If no id is provided,
a new UUID will be generated.
name: The name of the agent, can be None.
description: The description of the agent.
display_name: The display name of the agent, which is either the name or id.
context_providers: The collection of multiple context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
kwargs: will be stored in `additional_properties`
"""
if id is None:
id = str(uuid4())
self.id = id
self.name = name
self.description = description
self.context_provider = self._prepare_context_providers(context_providers)
if middleware is None or isinstance(middleware, Sequence):
self.middleware: list[Middleware] | None = cast(list[Middleware], middleware) if middleware else None
else:
self.middleware = [middleware]
self.additional_properties = kwargs
async def _notify_thread_of_new_messages(
self, thread: AgentThread, new_messages: ChatMessage | Sequence[ChatMessage]
self,
thread: AgentThread,
input_messages: ChatMessage | Sequence[ChatMessage],
response_messages: ChatMessage | Sequence[ChatMessage],
) -> None:
"""Notify the thread of new messages."""
if isinstance(new_messages, ChatMessage) or len(new_messages) > 0:
await thread_on_new_messages(thread, new_messages)
"""Notify the thread of new messages.
This also calls the invoked method of a potential context provider on the thread.
"""
if isinstance(input_messages, ChatMessage) or len(input_messages) > 0:
await thread.on_new_messages(input_messages)
if isinstance(response_messages, ChatMessage) or len(response_messages) > 0:
await thread.on_new_messages(response_messages)
if thread.context_provider:
await thread.context_provider.invoked(input_messages, response_messages)
@property
def display_name(self) -> str:
@@ -164,14 +206,14 @@ class BaseAgent(AFBaseModel):
"""
return self.name or self.id
def get_new_thread(self) -> AgentThread:
def get_new_thread(self, **kwargs: Any) -> AgentThread:
"""Returns AgentThread instance that is compatible with the agent."""
return AgentThread()
return AgentThread(**kwargs, context_provider=self.context_provider)
async def deserialize_thread(self, serialized_thread: Any, **kwargs: Any) -> AgentThread:
"""Deserializes the thread."""
thread: AgentThread = self.get_new_thread()
await deserialize_thread_state(thread, serialized_thread, **kwargs)
await thread.deserialize(serialized_thread, **kwargs)
return thread
def as_tool(
@@ -236,7 +278,12 @@ class BaseAgent(AFBaseModel):
# Create final text from accumulated updates
return AgentRunResponse.from_agent_run_response_updates(response_updates).text
return AIFunction(name=tool_name, description=tool_description, func=agent_wrapper, input_model=input_model)
return AIFunction(
name=tool_name,
description=tool_description,
func=agent_wrapper,
input_model=input_model,
)
def _normalize_messages(
self,
@@ -253,6 +300,18 @@ class BaseAgent(AFBaseModel):
return [ChatMessage(role=Role.USER, text=msg) if isinstance(msg, str) else msg for msg in messages]
def _prepare_context_providers(
self,
context_providers: ContextProvider | Sequence[ContextProvider] | None = None,
) -> AggregateContextProvider | None:
if not context_providers:
return None
if isinstance(context_providers, AggregateContextProvider):
return context_providers
return AggregateContextProvider(context_providers)
# region ChatAgent
@@ -263,12 +322,6 @@ class ChatAgent(BaseAgent):
"""A Chat Client Agent."""
AGENT_SYSTEM_NAME: ClassVar[str] = "microsoft.agent_framework"
chat_client: ChatClientProtocol
instructions: str | None = None
chat_options: ChatOptions
chat_message_store_factory: Callable[[], ChatMessageStore] | None = None
_local_mcp_tools: list[MCPTool] = PrivateAttr(default_factory=list) # type: ignore[reportUnknownVariableType]
_async_exit_stack: AsyncExitStack = PrivateAttr(default_factory=AsyncExitStack)
def __init__(
self,
@@ -278,6 +331,9 @@ class ChatAgent(BaseAgent):
id: str | None = None,
name: str | None = None,
description: str | None = None,
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
middleware: Middleware | list[Middleware] | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
max_tokens: int | None = None,
@@ -297,10 +353,7 @@ class ChatAgent(BaseAgent):
| None = None,
top_p: float | None = None,
user: str | None = None,
additional_properties: dict[str, Any] | None = None,
chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
middleware: Middleware | list[Middleware] | None = None,
request_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Create a ChatAgent.
@@ -317,6 +370,10 @@ class ChatAgent(BaseAgent):
id: The unique identifier for the agent, will be created automatically if not provided.
name: The name of the agent.
description: A brief description of the agent's purpose.
chat_message_store_factory: factory function to create an instance of ChatMessageStoreProtocol.
If not provided, the default in-memory store will be used.
context_providers: The collection of multiple context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
frequency_penalty: the frequency penalty to use.
logit_bias: the logit bias to use.
max_tokens: The maximum number of tokens to generate.
@@ -332,64 +389,54 @@ class ChatAgent(BaseAgent):
tools: the tools to use for the request.
top_p: the nucleus sampling probability to use.
user: the user to associate with the request.
additional_properties: additional properties to include in the request.
chat_message_store_factory: factory function to create an instance of ChatMessageStore. If not provided,
the default in-memory store will be used.
context_providers: The collection of multiple context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
kwargs: any additional keyword arguments.
Unused, can be used by subclasses of this Agent.
request_kwargs: a dictionary of other values that will be passed through
to the chat_client `get_response` and `get_streaming_response` methods.
kwargs: any additional keyword arguments. Will be stored as `additional_properties`
"""
if not hasattr(chat_client, FUNCTION_INVOKING_CHAT_CLIENT_MARKER) and isinstance(chat_client, BaseChatClient):
logger.warning(
"The provided chat client does not support function invoking, this might limit agent capabilities."
)
kwargs.update(additional_properties or {})
aggregate_context_providers = self._prepare_context_providers(context_providers)
super().__init__(
id=id,
name=name,
description=description,
context_providers=context_providers,
middleware=middleware,
**kwargs,
)
self.chat_client = chat_client
self.chat_message_store_factory = chat_message_store_factory
# We ignore the MCP Servers here and store them separately,
# we add their functions to the tools list at runtime
normalized_tools = [] if tools is None else tools if isinstance(tools, list) else [tools]
local_mcp_tools = [tool for tool in normalized_tools if isinstance(tool, MCPTool)]
final_tools = [tool for tool in normalized_tools if not isinstance(tool, MCPTool)]
args: dict[str, Any] = {
"chat_client": chat_client,
"chat_message_store_factory": chat_message_store_factory,
"context_providers": aggregate_context_providers,
"middleware": middleware,
"chat_options": ChatOptions(
ai_model_id=model,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
max_tokens=max_tokens,
metadata=metadata,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
store=store,
temperature=temperature,
tool_choice=tool_choice,
tools=final_tools, # type: ignore[reportArgumentType]
top_p=top_p,
user=user,
additional_properties=kwargs,
),
}
if instructions is not None:
args["instructions"] = instructions
if name is not None:
args["name"] = name
if description is not None:
args["description"] = description
if id is not None:
args["id"] = id
super().__init__(**args)
normalized_tools: list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] = ( # type:ignore[reportUnknownVariableType]
[] if tools is None else tools if isinstance(tools, list) else [tools]
)
self._local_mcp_tools = [tool for tool in normalized_tools if isinstance(tool, MCPTool)]
agent_tools = [tool for tool in normalized_tools if not isinstance(tool, MCPTool)]
self.chat_options = ChatOptions(
ai_model_id=model,
frequency_penalty=frequency_penalty,
instructions=instructions,
logit_bias=logit_bias,
max_tokens=max_tokens,
metadata=metadata,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
store=store,
temperature=temperature,
tool_choice=tool_choice,
tools=agent_tools, # type: ignore[reportArgumentType]
top_p=top_p,
user=user,
additional_properties=request_kwargs or {}, # type: ignore
)
self._async_exit_stack = AsyncExitStack()
self._update_agent_name()
self._local_mcp_tools = local_mcp_tools # type: ignore[assignment]
async def __aenter__(self) -> "Self":
"""Async context manager entry.
@@ -399,16 +446,17 @@ class ChatAgent(BaseAgent):
This list might be extended in the future.
"""
context_managers = [self.chat_client, *self._local_mcp_tools]
if self.context_providers:
context_managers.append(self.context_providers)
for context_manager in context_managers:
for context_manager in chain([self.chat_client], self._local_mcp_tools):
if isinstance(context_manager, AbstractAsyncContextManager):
await self._async_exit_stack.enter_async_context(context_manager)
return self
async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> None:
"""Async context manager exit.
Close the async exit stack to ensure all context managers are exited properly.
@@ -443,11 +491,9 @@ class ChatAgent(BaseAgent):
temperature: float | None = None,
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: ToolProtocol
| list[ToolProtocol]
| Callable[..., Any]
| list[Callable[..., Any]]
| MutableMapping[str, Any]
| list[MutableMapping[str, Any]]
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None = None,
top_p: float | None = None,
user: str | None = None,
@@ -485,28 +531,32 @@ class ChatAgent(BaseAgent):
will only be passed to functions that are called.
"""
input_messages = self._normalize_messages(messages)
context = await self.context_providers.model_invoking(input_messages) if self.context_providers else None
thread, thread_messages = await self._prepare_thread_and_messages(
thread=thread, context=context, input_messages=input_messages
thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages(
thread=thread, input_messages=input_messages
)
normalized_tools: list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] = ( # type:ignore[reportUnknownVariableType]
[] if tools is None else tools if isinstance(tools, list) else [tools]
)
agent_name = self._get_agent_name()
# Resolve final tool list (runtime provided tools + local MCP server tools)
final_tools: list[ToolProtocol | Callable[..., Any] | dict[str, Any]] = []
# Normalize tools argument to a list without mutating the original parameter
normalized_tools = [] if tools is None else tools if isinstance(tools, list) else [tools]
for tool in normalized_tools:
if isinstance(tool, MCPTool):
if not tool.is_connected:
await self._async_exit_stack.enter_async_context(tool)
final_tools.extend(tool.functions) # type: ignore
else:
final_tools.append(tool) # type: ignore
for mcp_server in self._local_mcp_tools:
if not mcp_server.is_connected:
await self._async_exit_stack.enter_async_context(mcp_server)
final_tools.extend(mcp_server.functions)
response = await self.chat_client.get_response(
messages=thread_messages,
chat_options=self.chat_options
chat_options=run_chat_options
& ChatOptions(
ai_model_id=model,
conversation_id=thread.service_thread_id,
@@ -529,7 +579,7 @@ class ChatAgent(BaseAgent):
**kwargs,
)
self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
# Ensure that the author name is set for each message in the response.
for message in response.messages:
@@ -538,13 +588,7 @@ class ChatAgent(BaseAgent):
# Only notify the thread of new messages if the chatResponse was successful
# to avoid inconsistent messages state in the thread.
await self._notify_thread_of_new_messages(thread, input_messages)
await self._notify_thread_of_new_messages(thread, response.messages)
if self.context_providers:
await self.context_providers.thread_created(response.conversation_id)
await self.context_providers.messages_adding(thread.service_thread_id, input_messages + response.messages)
await self._notify_thread_of_new_messages(thread, input_messages, response.messages)
return AgentRunResponse(
messages=response.messages,
response_id=response.response_id,
@@ -614,29 +658,34 @@ class ChatAgent(BaseAgent):
"""
input_messages = self._normalize_messages(messages)
context = await self.context_providers.model_invoking(input_messages) if self.context_providers else None
thread, thread_messages = await self._prepare_thread_and_messages(
thread=thread, context=context, input_messages=input_messages
thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages(
thread=thread, input_messages=input_messages
)
agent_name = self._get_agent_name()
response_updates: list[ChatResponseUpdate] = []
# Resolve final tool list (runtime provided tools + local MCP server tools)
final_tools: list[ToolProtocol | MutableMapping[str, Any] | Callable[..., Any]] = []
normalized_tools: list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] = ( # type: ignore[reportUnknownVariableType]
[] if tools is None else tools if isinstance(tools, list) else [tools]
)
# Normalize tools argument to a list without mutating the original parameter
normalized_tools = [] if tools is None else tools if isinstance(tools, list) else [tools]
for tool in normalized_tools:
if isinstance(tool, MCPTool):
if not tool.is_connected:
await self._async_exit_stack.enter_async_context(tool)
final_tools.extend(tool.functions) # type: ignore
else:
final_tools.append(tool)
for mcp_server in self._local_mcp_tools:
if not mcp_server.is_connected:
await self._async_exit_stack.enter_async_context(mcp_server)
final_tools.extend(mcp_server.functions)
async for update in self.chat_client.get_streaming_response(
messages=thread_messages,
chat_options=self.chat_options
chat_options=run_chat_options
& ChatOptions(
conversation_id=thread.service_thread_id,
frequency_penalty=frequency_penalty,
@@ -675,27 +724,46 @@ class ChatAgent(BaseAgent):
)
response = ChatResponse.from_chat_response_updates(response_updates)
await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
await self._notify_thread_of_new_messages(thread, input_messages, response.messages)
self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
@override
def get_new_thread(
self,
*,
service_thread_id: str | None = None,
**kwargs: Any,
) -> AgentThread:
"""Get a new conversation thread for the agent.
# Only notify the thread of new messages if the chatResponse was successful
# to avoid inconsistent messages state in the thread.
await self._notify_thread_of_new_messages(thread, input_messages)
await self._notify_thread_of_new_messages(thread, response.messages)
If you supply a service_thread_id, the thread will be marked as service managed.
if self.context_providers:
await self.context_providers.thread_created(response.conversation_id)
await self.context_providers.messages_adding(thread.service_thread_id, input_messages + response.messages)
If you don't supply a service_thread_id but have a chat_message_store_factory configured on the agent,
that factory will be used to create a message store for the thread and the thread will be
managed locally.
def get_new_thread(self) -> AgentThread:
message_store: ChatMessageStore | None = None
When neither is present, the thread will be created without a service ID or message store,
this will be updated based on usage, when you run the agent with this thread.
If you run with store=True, the response will respond with a thread_id and that will be set.
Otherwise a messages store is created from the default factory.
if self.chat_message_store_factory:
message_store = self.chat_message_store_factory()
Args:
service_thread_id: Optional service managed thread ID.
kwargs: not used at present.
"""
if service_thread_id is not None:
return AgentThread(
service_thread_id=service_thread_id,
context_provider=self.context_provider,
)
if self.chat_message_store_factory is not None:
return AgentThread(
message_store=self.chat_message_store_factory(),
context_provider=self.context_provider,
)
return AgentThread(context_provider=self.context_provider)
return AgentThread() if message_store is None else AgentThread(message_store=message_store)
def _update_thread_with_type_and_conversation_id(
async def _update_thread_with_type_and_conversation_id(
self, thread: AgentThread, response_conversation_id: str | None
) -> None:
"""Update thread with storage type and conversation ID.
@@ -719,6 +787,8 @@ class ChatAgent(BaseAgent):
# If we got a conversation id back from the chat client, it means that the service
# supports server side thread storage so we should update the thread with the new id.
thread.service_thread_id = response_conversation_id
if thread.context_provider:
await thread.context_provider.thread_created(thread.service_thread_id)
elif thread.message_store is None and self.chat_message_store_factory is not None:
# If the service doesn't use service side thread storage (i.e. we got no id back from invocation), and
# the thread has no message_store yet, and we have a custom messages store, we should update the thread
@@ -729,14 +799,14 @@ class ChatAgent(BaseAgent):
self,
*,
thread: AgentThread | None,
context: Context | None,
input_messages: list[ChatMessage] | None = None,
) -> tuple[AgentThread, list[ChatMessage]]:
) -> tuple[AgentThread, ChatOptions, list[ChatMessage]]:
"""Prepare the messages for agent execution.
Also updates the chat_options of the agent, with
Args:
thread: The conversation thread.
context: Context to include in messages.
input_messages: Messages to process.
Returns:
@@ -745,32 +815,42 @@ class ChatAgent(BaseAgent):
Raises:
AgentExecutionException: If the thread is not of the expected type.
"""
chat_options = copy(self.chat_options) if self.chat_options else ChatOptions()
thread = thread or self.get_new_thread()
messages: list[ChatMessage] = []
if self.instructions:
messages.append(ChatMessage(role=Role.SYSTEM, text=self.instructions))
if context and context.contents:
messages.append(ChatMessage(role=Role.SYSTEM, contents=context.contents))
if thread.service_thread_id and thread.context_provider:
await thread.context_provider.thread_created(thread.service_thread_id)
thread_messages: list[ChatMessage] = []
if thread.message_store:
messages.extend(await thread.message_store.list_messages() or [])
messages.extend(input_messages or [])
return thread, messages
thread_messages.extend(await thread.message_store.list_messages() or [])
context: Context | None = None
if self.context_provider:
async with self.context_provider:
context = await self.context_provider.invoking(input_messages or [])
if context:
if context.messages:
thread_messages.extend(context.messages)
if context.tools:
if chat_options.tools is not None:
chat_options.tools.extend(context.tools)
else:
chat_options.tools = list(context.tools)
if context.instructions:
chat_options.instructions = (
context.instructions
if not chat_options.instructions
else f"{chat_options.instructions}\n{context.instructions}"
)
thread_messages.extend(input_messages or [])
if (
thread.service_thread_id
and chat_options.conversation_id
and thread.service_thread_id != chat_options.conversation_id
):
raise AgentExecutionException(
"The conversation_id set on the agent is different from the one set on the thread, "
"only one ID can be used for a run."
)
return thread, chat_options, thread_messages
def _get_agent_name(self) -> str:
return self.name or "UnnamedAgent"
def _prepare_context_providers(
self,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
) -> AggregateContextProvider | None:
if not context_providers:
return None
if isinstance(context_providers, AggregateContextProvider):
return context_providers
if isinstance(context_providers, ContextProvider):
return AggregateContextProvider([context_providers])
return AggregateContextProvider(context_providers)
@@ -18,7 +18,7 @@ from ._middleware import (
Middleware,
)
from ._pydantic import AFBaseModel
from ._threads import ChatMessageStore
from ._threads import ChatMessageStoreProtocol
from ._tools import ToolProtocol
from ._types import (
ChatMessage,
@@ -207,9 +207,12 @@ class BaseChatClient(AFBaseModel, ABC):
# This is used for OTel setup, should be overridden in subclasses
def prepare_messages(
self, messages: str | ChatMessage | list[str] | list[ChatMessage]
self, messages: str | ChatMessage | list[str] | list[ChatMessage], chat_options: ChatOptions
) -> MutableSequence[ChatMessage]:
"""Turn the allowed input into a list of chat messages."""
if chat_options.instructions:
system_msg = ChatMessage(role="system", text=chat_options.instructions)
return [system_msg, *prepare_messages(messages)]
return prepare_messages(messages)
def _filter_internal_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
@@ -368,7 +371,7 @@ class BaseChatClient(AFBaseModel, ABC):
user=user,
additional_properties=additional_properties or {},
)
prepped_messages = self.prepare_messages(messages)
prepped_messages = self.prepare_messages(messages, chat_options)
self._prepare_tool_choice(chat_options=chat_options)
filtered_kwargs = self._filter_internal_kwargs(kwargs)
@@ -449,7 +452,7 @@ class BaseChatClient(AFBaseModel, ABC):
user=user,
additional_properties=additional_properties or {},
)
prepped_messages = self.prepare_messages(messages)
prepped_messages = self.prepare_messages(messages, chat_options)
self._prepare_tool_choice(chat_options=chat_options)
filtered_kwargs = self._filter_internal_kwargs(kwargs)
@@ -492,7 +495,7 @@ class BaseChatClient(AFBaseModel, ABC):
| MutableMapping[str, Any]
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None = None,
chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
middleware: Middleware | list[Middleware] | None = None,
**kwargs: Any,
@@ -503,8 +506,8 @@ class BaseChatClient(AFBaseModel, ABC):
name: The name of the agent.
instructions: The instructions for the agent.
tools: Optional list of tools to associate with the agent.
chat_message_store_factory: Factory function to create an instance of ChatMessageStore. If not provided,
the default in-memory store will be used.
chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol.
If not provided, the default in-memory store will be used.
context_providers: Context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
**kwargs: Additional keyword arguments to pass to the agent.
@@ -246,6 +246,7 @@ class MCPTool:
self.request_timeout = request_timeout
self.chat_client = chat_client
self.functions: list[AIFunction[Any, Any]] = []
self.is_connected: bool = False
def __str__(self) -> str:
return f"MCPTool(name={self.name}, description={self.description})"
@@ -282,6 +283,7 @@ class MCPTool:
# If the session is not initialized, we need to reinitialize it
await self.session.initialize()
logger.debug("Connected to MCP server: %s", self.session)
self.is_connected = True
if self.load_tools_flag:
await self.load_tools()
if self.load_prompts_flag:
@@ -434,6 +436,7 @@ class MCPTool:
"""Disconnect from the MCP server."""
await self._exit_stack.aclose()
self.session = None
self.is_connected = False
@abstractmethod
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
+83 -38
View File
@@ -6,11 +6,15 @@ from abc import ABC, abstractmethod
from collections.abc import MutableSequence, Sequence
from contextlib import AsyncExitStack
from types import TracebackType
from typing import ClassVar
from typing import Any, Final, cast
from ._pydantic import AFBaseModel
from ._types import ChatMessage, Contents
from ._tools import ToolProtocol
from ._types import ChatMessage
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
else:
@@ -21,7 +25,7 @@ else:
__all__ = ["AggregateContextProvider", "Context", "ContextProvider"]
class Context(AFBaseModel):
class Context:
"""A class containing any context that should be provided to the AI model as supplied by an ContextProvider.
Each ContextProvider has the ability to provide its own context for each invocation.
@@ -30,28 +34,39 @@ class Context(AFBaseModel):
This context is per invocation, and will not be stored as part of the chat history.
"""
contents: list[Contents] | None = None
"""
Any content to pass to the AI model in addition to any other prompts
that it may already have (in the case of an agent), or chat history that may already exist.
"""
def __init__(
self,
instructions: str | None = None,
messages: Sequence[ChatMessage] | None = None,
tools: Sequence[ToolProtocol] | None = None,
):
"""Create a new Context object.
Args:
instructions: Instructions to provide to the AI model.
messages: a list of messages.
tools: a list of tools to provide to this run.
"""
self.instructions = instructions
self.messages: Sequence[ChatMessage] = messages or []
self.tools: Sequence[ToolProtocol] = tools or []
# region ContextProvider
class ContextProvider(AFBaseModel, ABC):
class ContextProvider(ABC):
"""Base class for all context providers.
A context provider is a component that can be used to enhance the AI's context management.
It can listen to changes in the conversation and provide additional context to the AI model
just before invocation.
It also has a default memory prompt that can be used by all providers.
"""
# Default prompt to be used by all context providers when assembling memories/instructions
DEFAULT_CONTEXT_PROMPT: ClassVar[str] = (
"## Memories\nConsider the following memories when answering user questions:"
)
DEFAULT_CONTEXT_PROMPT: Final[str] = "## Memories\nConsider the following memories when answering user questions:"
async def thread_created(self, thread_id: str | None) -> None:
"""Called just after a new thread is created.
@@ -65,19 +80,27 @@ class ContextProvider(AFBaseModel, ABC):
"""
pass
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
"""Called just before messages are added to the chat by any participant.
async def invoked(
self,
request_messages: ChatMessage | Sequence[ChatMessage],
response_messages: ChatMessage | Sequence[ChatMessage] | None = None,
invoke_exception: Exception | None = None,
**kwargs: Any,
) -> None:
"""Called after the agent has received a response from the underlying inference service.
Inheritors can use this method to update their context based on new messages.
You can inspect the request and response messages, and update the state of the context provider
Args:
thread_id: The ID of the thread for the new message.
new_messages: New messages to add.
request_messages: messages that were sent to the model/agent
response_messages: messages that were returned by the model/agent
invoke_exception: exception that was thrown, if any.
kwargs: not used at present.
"""
pass
@abstractmethod
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context:
"""Called just before the Model/Agent/etc. is invoked.
Implementers can load any additional context required at this time,
@@ -85,6 +108,7 @@ class ContextProvider(AFBaseModel, ABC):
Args:
messages: The most recent messages that the agent is being invoked with.
kwargs: not used at present.
"""
pass
@@ -125,16 +149,16 @@ class AggregateContextProvider(ContextProvider):
It delegates events to multiple context providers and aggregates responses from those events before returning.
"""
providers: list[ContextProvider]
"""List of registered context providers."""
def __init__(self, context_providers: Sequence[ContextProvider] | None = None) -> None:
"""Initialize AggregateContextProvider with context providers.
def __init__(self, context_providers: ContextProvider | Sequence[ContextProvider] | None = None) -> None:
"""Initialize the AggregateContextProvider with context providers.
Args:
context_providers: Context providers to add.
"""
super().__init__(providers=list(context_providers or [])) # type: ignore
if isinstance(context_providers, ContextProvider):
self.providers = [context_providers]
else:
self.providers = cast(list[ContextProvider], context_providers) or []
self._exit_stack: AsyncExitStack | None = None
def add(self, context_provider: ContextProvider) -> None:
@@ -145,24 +169,44 @@ class AggregateContextProvider(ContextProvider):
"""
self.providers.append(context_provider)
@override
async def thread_created(self, thread_id: str | None = None) -> None:
await asyncio.gather(*[x.thread_created(thread_id) for x in self.providers])
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
await asyncio.gather(*[x.messages_adding(thread_id, new_messages) for x in self.providers])
@override
async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context:
contexts = await asyncio.gather(*[provider.invoking(messages, **kwargs) for provider in self.providers])
instructions: str = ""
return_messages: list[ChatMessage] = []
tools: list[ToolProtocol] = []
for ctx in contexts:
if ctx.instructions:
instructions += ctx.instructions
if ctx.messages:
return_messages.extend(ctx.messages)
if ctx.tools:
tools.extend(ctx.tools)
return Context(instructions=instructions, messages=return_messages, tools=tools)
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
sub_contexts = await asyncio.gather(*[x.model_invoking(messages) for x in self.providers])
combined_context = Context()
# Flatten the list of lists and filter out None values
all_contents = []
for ctx in sub_contexts:
if ctx.contents:
all_contents.extend(ctx.contents)
combined_context.contents = all_contents if all_contents else None
return combined_context
@override
async def invoked(
self,
request_messages: ChatMessage | Sequence[ChatMessage],
response_messages: ChatMessage | Sequence[ChatMessage] | None = None,
invoke_exception: Exception | None = None,
**kwargs: Any,
) -> None:
await asyncio.gather(*[
x.invoked(
request_messages=request_messages,
response_messages=response_messages,
invoke_exception=invoke_exception,
**kwargs,
)
for x in self.providers
])
@override
async def __aenter__(self) -> "Self":
"""Enter async context manager and set up all providers.
@@ -178,6 +222,7 @@ class AggregateContextProvider(ContextProvider):
return self
@override
async def __aexit__(
self,
exc_type: type[BaseException] | None,
@@ -35,6 +35,7 @@ class MiddlewareType(Enum):
__all__ = [
"AgentMiddleware",
"AgentMiddlewares",
"AgentRunContext",
"ChatContext",
"ChatMiddleware",
@@ -230,6 +231,7 @@ Middleware: TypeAlias = (
| ChatMiddleware
| ChatMiddlewareCallable
)
AgentMiddlewares: TypeAlias = AgentMiddleware | AgentMiddlewareCallable
# Middleware type markers for decorators
@@ -1009,7 +1011,7 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien
pipeline = ChatMiddlewarePipeline(chat_middleware_list) # type: ignore[arg-type]
context = ChatContext(
chat_client=self,
messages=self.prepare_messages(messages),
messages=self.prepare_messages(messages, chat_options),
chat_options=chat_options,
is_streaming=False,
kwargs=kwargs,
@@ -1059,7 +1061,7 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien
pipeline = ChatMiddlewarePipeline(all_middleware) # type: ignore[arg-type]
context = ChatContext(
chat_client=self,
messages=self.prepare_messages(messages),
messages=self.prepare_messages(messages, chat_options),
chat_options=chat_options,
is_streaming=True,
kwargs=kwargs,
+232 -248
View File
@@ -1,15 +1,19 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import Sequence
from typing import Any, Protocol, overload
from typing import Any, Protocol, TypeVar
from pydantic import model_validator
from ._memory import AggregateContextProvider
from ._pydantic import AFBaseModel
from ._types import ChatMessage
from .exceptions import AgentThreadException
__all__ = ["AgentThread", "ChatMessageList", "ChatMessageStore"]
__all__ = ["AgentThread", "ChatMessageStore", "ChatMessageStoreProtocol"]
class ChatMessageStore(Protocol):
class ChatMessageStoreProtocol(Protocol):
"""Defines methods for storing and retrieving chat messages associated with a specific thread.
Implementations of this protocol are responsible for managing the storage of chat messages,
@@ -24,7 +28,7 @@ class ChatMessageStore(Protocol):
If the messages stored in the store become very large, it is up to the store to
truncate, summarize or otherwise limit the number of messages returned.
When using implementations of ChatMessageStore, a new one should be created for each thread
When using implementations of ChatMessageStoreProtocol, a new one should be created for each thread
since they may contain state that is specific to a thread.
"""
...
@@ -33,66 +37,187 @@ class ChatMessageStore(Protocol):
"""Adds messages to the store."""
...
async def deserialize_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
"""Deserializes the state into the properties on this store.
@classmethod
async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "ChatMessageStoreProtocol":
"""Creates a new instance of the store from previously serialized state.
This method, together with serialize_state can be used to save and load messages from a persistent store
if this store only has messages in memory.
"""
...
async def serialize_state(self, **kwargs: Any) -> Any:
async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
"""Update the current ChatMessageStore instance from serialized state data.
Args:
serialized_store_state: Previously serialized state data containing messages.
**kwargs: Additional arguments for deserialization.
"""
...
async def serialize(self, **kwargs: Any) -> Any:
"""Serializes the current object's state.
This method, together with deserialize_state can be used to save and load messages from a persistent store
This method, together with deserialize can be used to save and load messages from a persistent store
if this store only has messages in memory.
"""
...
class AgentThread(AFBaseModel):
"""Base class for agent threads."""
class AgentThreadState(AFBaseModel):
"""State model for serializing and deserializing thread information.
_service_thread_id: str | None = None
_message_store: ChatMessageStore | None = None
Attributes:
service_thread_id: Optional ID of the thread managed by the agent service.
chat_message_store_state: Optional serialized state of the chat message store.
"""
@overload
def __init__(self) -> None:
"""Initialize an empty AgentThread with no service thread ID or message store."""
...
service_thread_id: str | None = None
chat_message_store_state: Any | None = None
@overload
def __init__(self, service_thread_id: str) -> None:
"""Initialize an AgentThread with a service thread ID.
@model_validator(mode="before")
def validate_only_one(cls, values: dict[str, Any]) -> dict[str, Any]:
if (
isinstance(values, dict)
and values.get("service_thread_id") is not None
and values.get("chat_message_store_state") is not None
):
raise AgentThreadException("Only one of service_thread_id or chat_message_store_state may be set.")
return values
class ChatMessageStoreState(AFBaseModel):
"""State model for serializing and deserializing chat message store data.
Attributes:
messages: List of chat messages stored in the message store.
"""
messages: list[ChatMessage]
TChatMessageStore = TypeVar("TChatMessageStore", bound="ChatMessageStore")
class ChatMessageStore:
"""An in-memory implementation of ChatMessageStoreProtocol that stores messages in a list.
This implementation provides a simple, list-based storage for chat messages
with support for serialization and deserialization. It implements all the
required methods of the ChatMessageStoreProtocol protocol.
The store maintains messages in memory and provides methods to serialize
and deserialize the state for persistence purposes.
Args:
messages: Optional initial list of ChatMessage objects to populate the store.
"""
def __init__(self, messages: Sequence[ChatMessage] | None = None):
"""Create a ChatMessageStore for use in a thread.
Args:
service_thread_id: The ID of the thread managed by the agent service.
messages: The messages to store.
"""
...
self.messages = list(messages) if messages else []
@overload
def __init__(self, *, message_store: ChatMessageStore) -> None:
"""Initialize an AgentThread with a custom message store.
async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
"""Add messages to the store.
Args:
message_store: The ChatMessageStore implementation for managing chat messages.
messages: Sequence of ChatMessage objects to add to the store.
"""
...
self.messages.extend(messages)
def __init__(self, service_thread_id: str | None = None, *, message_store: ChatMessageStore | None = None) -> None:
"""Initialize an AgentThread.
async def list_messages(self) -> list[ChatMessage]:
"""Get all messages from the store in chronological order.
Returns:
List of ChatMessage objects, ordered from oldest to newest.
"""
return self.messages
@classmethod
async def deserialize(
cls: type[TChatMessageStore], serialized_store_state: Any, **kwargs: Any
) -> TChatMessageStore:
"""Create a new ChatMessageStore instance from serialized state data.
Args:
serialized_store_state: Previously serialized state data containing messages.
**kwargs: Additional arguments for deserialization.
Returns:
A new ChatMessageStore instance populated with messages from the serialized state.
"""
state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
if state.messages:
return cls(messages=state.messages)
return cls()
async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
"""Update the current ChatMessageStore instance from serialized state data.
Args:
serialized_store_state: Previously serialized state data containing messages.
**kwargs: Additional arguments for deserialization.
"""
if not serialized_store_state:
return
state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
if state.messages:
self.messages = state.messages
async def serialize(self, **kwargs: Any) -> Any:
"""Serialize the current store state for persistence.
Args:
**kwargs: Additional arguments for serialization.
Returns:
Serialized state data that can be used with deserialize_state.
"""
state = ChatMessageStoreState(messages=self.messages)
return state.model_dump(**kwargs)
TAgentThread = TypeVar("TAgentThread", bound="AgentThread")
class AgentThread:
"""The Agent thread class, this can represent both a locally managed thread or a thread managed by the service."""
def __init__(
self,
*,
service_thread_id: str | None = None,
message_store: ChatMessageStoreProtocol | None = None,
context_provider: AggregateContextProvider | None = None,
) -> None:
"""Initialize an AgentThread, do not use this method manually, always use: agent.get_new_thread().
Args:
service_thread_id: Optional ID of the thread managed by the agent service.
message_store: Optional ChatMessageStore implementation for managing chat messages.
context_provider: Optional ContextProvider for the thread.
Note:
Either service_thread_id or message_store may be set, but not both.
"""
super().__init__()
if service_thread_id is not None and message_store is not None:
raise AgentThreadException("Only the service_thread_id or message_store may be set, but not both.")
self.service_thread_id = service_thread_id
self.message_store = message_store
self._service_thread_id = service_thread_id
self._message_store = message_store
self.context_provider = context_provider
@property
def is_initialized(self) -> bool:
"""Indicates if the thread is initialized.
This means either the service_thread_id or the message_store is set.
"""
return self._service_thread_id is not None or self._message_store is not None
@property
def service_thread_id(self) -> str | None:
@@ -105,39 +230,53 @@ class AgentThread(AFBaseModel):
Note that either service_thread_id or message_store may be set, but not both.
"""
if not self._service_thread_id and not service_thread_id:
if service_thread_id is None:
return
if self._message_store is not None:
raise ValueError(
raise AgentThreadException(
"Only the service_thread_id or message_store may be set, "
"but not both and switching from one to another is not supported."
)
self._service_thread_id = service_thread_id
@property
def message_store(self) -> ChatMessageStore | None:
"""Gets the ChatMessageStore used by this thread, when messages should be stored in a custom location."""
def message_store(self) -> ChatMessageStoreProtocol | None:
"""Gets the ChatMessageStoreProtocol used by this thread."""
return self._message_store
@message_store.setter
def message_store(self, message_store: ChatMessageStore | None) -> None:
"""Sets the ChatMessageStore used by this thread, when messages should be stored in a custom location.
def message_store(self, message_store: ChatMessageStoreProtocol | None) -> None:
"""Sets the ChatMessageStoreProtocol used by this thread.
Note that either service_thread_id or message_store may be set, but not both.
"""
if self._message_store is None and message_store is None:
if message_store is None:
return
if self._service_thread_id:
raise ValueError(
if self._service_thread_id is not None:
raise AgentThreadException(
"Only the service_thread_id or message_store may be set, "
"but not both and switching from one to another is not supported."
)
self._message_store = message_store
async def on_new_messages(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
"""Invoked when a new message has been contributed to the chat by any participant."""
if self._service_thread_id is not None:
# If the thread messages are stored in the service there is nothing to do here,
# since invoking the service should already update the thread.
return
if self._message_store is None:
# If there is no conversation id, and no store we can
# create a default in memory store.
self._message_store = ChatMessageStore()
# If a store has been provided, we need to add the messages to the store.
if isinstance(new_messages, ChatMessage):
new_messages = [new_messages]
await self._message_store.add_messages(new_messages)
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
"""Serializes the current object's state.
@@ -146,226 +285,71 @@ class AgentThread(AFBaseModel):
"""
chat_message_store_state = None
if self._message_store is not None:
chat_message_store_state = await self._message_store.serialize_state(**kwargs)
chat_message_store_state = await self._message_store.serialize(**kwargs)
state = ThreadState(
state = AgentThreadState(
service_thread_id=self._service_thread_id, chat_message_store_state=chat_message_store_state
)
return state.model_dump()
async def thread_on_new_messages(thread: AgentThread, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
"""Invoked when a new message has been contributed to the chat by any participant."""
if thread.service_thread_id is not None:
# If the thread messages are stored in the service there is nothing to do here,
# since invoking the service should already update the thread.
return
if thread.message_store is None:
# If there is no conversation id, and no store we can
# create a default in memory store.
thread.message_store = ChatMessageList()
# If a store has been provided, we need to add the messages to the store.
if isinstance(new_messages, ChatMessage):
new_messages = [new_messages]
await thread.message_store.add_messages(new_messages)
async def deserialize_thread_state(
thread: AgentThread,
serialized_thread: dict[str, Any],
**kwargs: Any,
) -> None:
"""Deserializes the state from a dictionary into the thread properties."""
state = ThreadState.model_validate(serialized_thread)
if state.service_thread_id:
thread.service_thread_id = state.service_thread_id
# Since we have an ID, we should not have a chat message store and we can return here.
return
# If we don't have any ChatMessageStore state return here.
if state.chat_message_store_state is None:
return
if thread.message_store is None:
# If we don't have a chat message store yet, create an in-memory one.
thread.message_store = ChatMessageList()
await thread.message_store.deserialize_state(state.chat_message_store_state, **kwargs)
class ThreadState(AFBaseModel):
"""State model for serializing and deserializing thread information.
Attributes:
service_thread_id: Optional ID of the thread managed by the agent service.
chat_message_store_state: Optional serialized state of the chat message store.
"""
service_thread_id: str | None = None
chat_message_store_state: Any | None = None
class StoreState(AFBaseModel):
"""State model for serializing and deserializing chat message store data.
Attributes:
messages: List of chat messages stored in the message store.
"""
messages: list[ChatMessage]
class ChatMessageList:
"""An in-memory implementation of ChatMessageStore that stores messages in a list.
This implementation provides a simple, list-based storage for chat messages
with support for serialization and deserialization. It implements all the
required methods of the ChatMessageStore protocol and provides additional
list-like operations for direct message manipulation.
The store maintains messages in memory and provides methods to serialize
and deserialize the state for persistence purposes.
"""
def __init__(self, messages: Sequence[ChatMessage] | None = None) -> None:
"""Initialize the message store with optional initial messages.
@classmethod
async def deserialize(
cls: type[TAgentThread],
serialized_thread_state: dict[str, Any],
*,
message_store: ChatMessageStoreProtocol | None = None,
**kwargs: Any,
) -> TAgentThread:
"""Deserializes the state from a dictionary into a new AgentThread instance.
Args:
messages: Optional collection of initial ChatMessage objects to store.
"""
self._messages: list[ChatMessage] = []
if messages:
self._messages.extend(messages)
async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
"""Add messages to the store.
Args:
messages: Sequence of ChatMessage objects to add to the store.
"""
self._messages.extend(messages)
async def list_messages(self) -> list[ChatMessage]:
"""Get all messages from the store in chronological order.
Returns:
List of ChatMessage objects, ordered from oldest to newest.
"""
return self._messages
async def deserialize_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
"""Deserialize state data into this store instance.
Args:
serialized_store_state: Previously serialized state data containing messages.
serialized_thread_state: The serialized thread state as a dictionary.
message_store: Optional ChatMessageStoreProtocol to use for managing messages.
If not provided, a new ChatMessageStore will be created if needed.
**kwargs: Additional arguments for deserialization.
"""
if serialized_store_state:
state = StoreState.model_validate(obj=serialized_store_state, **kwargs)
if state.messages:
self._messages.extend(state.messages)
async def serialize_state(self, **kwargs: Any) -> Any:
"""Serialize the current store state for persistence.
Args:
**kwargs: Additional arguments for serialization.
Returns:
Serialized state data that can be used with deserialize_state.
A new AgentThread instance with properties set from the serialized state.
"""
state = StoreState(messages=self._messages)
return state.model_dump(**kwargs)
state = AgentThreadState.model_validate(serialized_thread_state)
def __len__(self) -> int:
"""Return the number of messages in the store.
if state.service_thread_id is not None:
return cls(service_thread_id=state.service_thread_id)
Returns:
The count of messages currently stored.
"""
return len(self._messages)
# If we don't have any ChatMessageStoreProtocol state return here.
if state.chat_message_store_state is None:
return cls()
def __getitem__(self, index: int) -> ChatMessage:
"""Get a message by index.
if message_store is not None:
try:
await message_store.update_from_state(state.chat_message_store_state, **kwargs)
except Exception as ex:
raise AgentThreadException("Failed to deserialize the provided message store.") from ex
return cls(message_store=message_store)
try:
message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs)
except Exception as ex:
raise AgentThreadException("Failed to deserialize the message store.") from ex
return cls(message_store=message_store)
Args:
index: The index of the message to retrieve.
async def update_from_thread_state(
self,
serialized_thread_state: dict[str, Any],
**kwargs: Any,
) -> None:
"""Deserializes the state from a dictionary into the thread properties."""
state = AgentThreadState.model_validate(serialized_thread_state)
Returns:
The ChatMessage at the specified index.
"""
return self._messages[index]
def __setitem__(self, index: int, item: ChatMessage) -> None:
"""Set a message at the specified index.
Args:
index: The index at which to set the message.
item: The ChatMessage to set at the specified index.
"""
self._messages[index] = item
def append(self, item: ChatMessage) -> None:
"""Append a message to the end of the store.
Args:
item: The ChatMessage to append.
"""
self._messages.append(item)
def clear(self) -> None:
"""Remove all messages from the store."""
self._messages.clear()
def index(self, item: ChatMessage) -> int:
"""Return the index of the first occurrence of the specified message.
Args:
item: The ChatMessage to find.
Returns:
The index of the first occurrence of the message.
Raises:
ValueError: If the message is not found in the store.
"""
return self._messages.index(item)
def insert(self, index: int, item: ChatMessage) -> None:
"""Insert a message at the specified index.
Args:
index: The index at which to insert the message.
item: The ChatMessage to insert.
"""
self._messages.insert(index, item)
def remove(self, item: ChatMessage) -> None:
"""Remove the first occurrence of the specified message from the store.
Args:
item: The ChatMessage to remove.
Raises:
ValueError: If the message is not found in the store.
"""
self._messages.remove(item)
def pop(self, index: int = -1) -> ChatMessage:
"""Remove and return a message at the specified index.
Args:
index: The index of the message to remove and return. Defaults to -1 (last item).
Returns:
The ChatMessage that was removed.
Raises:
IndexError: If the index is out of range.
"""
return self._messages.pop(index)
if state.service_thread_id is not None:
self.service_thread_id = state.service_thread_id
# Since we have an ID, we should not have a chat message store and we can return here.
return
# If we don't have any ChatMessageStoreProtocol state return here.
if state.chat_message_store_state is None:
return
if self.message_store is not None:
await self.message_store.update_from_state(state.chat_message_store_state, **kwargs)
# If we don't have a chat message store yet, create an in-memory one.
return
# Create the message store from the default.
self.message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs) # type: ignore
@@ -1799,6 +1799,7 @@ class ChatOptions(AFBaseModel):
allow_multiple_tool_calls: bool | None = None
conversation_id: str | None = None
frequency_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None
instructions: str | None = None
logit_bias: MutableMapping[str | int, float] | None = None
max_tokens: Annotated[int | None, Field(gt=0)] = None
metadata: MutableMapping[str, str] | None = None
@@ -1811,7 +1812,7 @@ class ChatOptions(AFBaseModel):
store: bool | None = None
temperature: Annotated[float | None, Field(ge=0.0, le=2.0)] = None
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None = None
tools: list[ToolProtocol | MutableMapping[str, Any]] | None = None
tools: MutableSequence[ToolProtocol | MutableMapping[str, Any]] | None = None
top_p: Annotated[float | None, Field(ge=0.0, le=1.0)] = None
user: str | None = None
@@ -1902,11 +1903,13 @@ class ChatOptions(AFBaseModel):
# tool_choice has a specialized serialize method. Save it here so we can fix it later.
tool_choice = other.tool_choice or self.tool_choice
updated_values = other.model_dump(exclude_none=True, exclude={"tools"})
logit_bias = updated_values.pop("logit_bias", {})
metadata = updated_values.pop("metadata", {})
additional_properties = updated_values.pop("additional_properties", {})
combined = self.model_copy(update=updated_values)
combined.tool_choice = tool_choice
combined.instructions = " ".join([combined.instructions or "", other.instructions or ""])
combined.logit_bias = {**(combined.logit_bias or {}), **logit_bias}
combined.metadata = {**(combined.metadata or {}), **metadata}
combined.additional_properties = {**(combined.additional_properties or {}), **additional_properties}
@@ -187,5 +187,3 @@ __all__ = [
with contextlib.suppress(AttributeError, TypeError, ValueError):
# Rebuild WorkflowExecutor to resolve Workflow forward reference
WorkflowExecutor.model_rebuild()
# Rebuild WorkflowAgent to resolve Workflow forward reference
WorkflowAgent.model_rebuild()
@@ -6,7 +6,7 @@ from collections.abc import AsyncIterable, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
from pydantic import Field
from pydantic import BaseModel
from agent_framework import (
AgentRunResponse,
@@ -21,7 +21,6 @@ from agent_framework import (
UsageDetails,
)
from .._pydantic import AFBaseModel
from ..exceptions import AgentExecutionException
from ._events import (
AgentRunUpdateEvent,
@@ -41,15 +40,10 @@ class WorkflowAgent(BaseAgent):
# Class variable for the request info function name
REQUEST_INFO_FUNCTION_NAME: ClassVar[str] = "request_info"
class RequestInfoFunctionArgs(AFBaseModel):
class RequestInfoFunctionArgs(BaseModel):
request_id: str
data: Any
workflow: "Workflow" = Field(description="The workflow wrapped as an agent")
pending_requests: dict[str, RequestInfoEvent] = Field(
default_factory=dict, description="Pending request info events"
)
def __init__(
self,
workflow: "Workflow",
@@ -70,9 +64,6 @@ class WorkflowAgent(BaseAgent):
"""
if id is None:
id = f"WorkflowAgent_{uuid.uuid4().hex[:8]}"
# Initialize with standard BaseAgent parameters first
kwargs["workflow"] = workflow
# Validate the workflow's start executor can handle agent-facing message inputs
try:
start_executor = workflow.get_start_executor()
@@ -84,6 +75,9 @@ class WorkflowAgent(BaseAgent):
super().__init__(id=id, name=name, description=description, **kwargs)
self.workflow: "Workflow" = workflow
self.pending_requests: dict[str, RequestInfoEvent] = {}
async def run(
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
@@ -116,8 +110,7 @@ class WorkflowAgent(BaseAgent):
response = self.merge_updates(response_updates, response_id)
# Notify thread of new messages (both input and response messages)
await self._notify_thread_of_new_messages(thread, input_messages)
await self._notify_thread_of_new_messages(thread, response.messages)
await self._notify_thread_of_new_messages(thread, input_messages, response.messages)
return response
@@ -151,8 +144,7 @@ class WorkflowAgent(BaseAgent):
response = self.merge_updates(response_updates, response_id)
# Notify thread of new messages (both input and response messages)
await self._notify_thread_of_new_messages(thread, input_messages)
await self._notify_thread_of_new_messages(thread, response.messages)
await self._notify_thread_of_new_messages(thread, input_messages, response.messages)
async def _run_stream_impl(
self,
@@ -49,6 +49,12 @@ class AgentInitializationError(AgentException):
pass
class AgentThreadException(AgentException):
"""An error occurred while managing the agent thread."""
pass
class ChatClientException(AgentFrameworkException):
"""An error occurred while dealing with a chat client."""
@@ -407,7 +407,7 @@ class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient):
"json_schema": chat_options.response_format.model_json_schema(),
}
instructions: list[str] = []
instructions: list[str] = [chat_options.instructions] if chat_options and chat_options.instructions else []
tool_results: list[FunctionResultContent] | None = None
additional_messages: list[AdditionalMessage] | None = None
@@ -119,7 +119,7 @@ class OpenAIBaseChatClient(OpenAIBase, BaseChatClient):
# region content creation
def _chat_to_tool_spec(self, tools: list[ToolProtocol | MutableMapping[str, Any]]) -> list[dict[str, Any]]:
def _chat_to_tool_spec(self, tools: Sequence[ToolProtocol | MutableMapping[str, Any]]) -> list[dict[str, Any]]:
chat_tools: list[dict[str, Any]] = []
for tool in tools:
if isinstance(tool, ToolProtocol):
@@ -132,7 +132,9 @@ class OpenAIBaseChatClient(OpenAIBase, BaseChatClient):
chat_tools.append(tool if isinstance(tool, dict) else dict(tool))
return chat_tools
def _process_web_search_tool(self, tools: list[ToolProtocol | MutableMapping[str, Any]]) -> dict[str, Any] | None:
def _process_web_search_tool(
self, tools: Sequence[ToolProtocol | MutableMapping[str, Any]]
) -> dict[str, Any] | None:
for tool in tools:
if isinstance(tool, HostedWebSearchTool):
# Web search tool requires special handling
@@ -152,6 +154,9 @@ class OpenAIBaseChatClient(OpenAIBase, BaseChatClient):
def _prepare_options(self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions) -> dict[str, Any]:
# Preprocess web search tool if it exists
options_dict = chat_options.to_provider_settings()
instructions = options_dict.pop("instructions", None)
if instructions:
messages = [ChatMessage(role="system", text=instructions), *messages]
if messages and "messages" not in options_dict:
options_dict["messages"] = self._prepare_chat_history_for_request(messages)
if "messages" not in options_dict:
@@ -172,7 +172,7 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
# region Prep methods
def _tools_to_response_tools(
self, tools: list[ToolProtocol | MutableMapping[str, Any]]
self, tools: Sequence[ToolProtocol | MutableMapping[str, Any]]
) -> list[ToolParam | dict[str, Any]]:
response_tools: list[ToolParam | dict[str, Any]] = []
for tool in tools:
@@ -314,6 +314,8 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
options_dict["user"] = chat_options.user
# messages
if instructions := options_dict.pop("instructions", None):
messages = [ChatMessage(role="system", text=instructions), *messages]
request_input = self._prepare_chat_messages_for_request(messages)
if not request_input:
raise ServiceInvalidRequestError("Messages are required for chat completions")
+1 -1
View File
@@ -141,7 +141,7 @@ class MockBaseChatClient(BaseChatClient):
logger.debug(f"Running base chat client inner, with: {messages=}, {chat_options=}, {kwargs=}")
self.call_count += 1
if not self.run_responses:
return ChatResponse(messages=ChatMessage(role="assistant", text=f"test response - {messages[0].text}"))
return ChatResponse(messages=ChatMessage(role="assistant", text=f"test response - {messages[-1].text}"))
response = self.run_responses.pop(0)
+75 -95
View File
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, MutableSequence, Sequence
from typing import Any
from uuid import uuid4
from pytest import raises
@@ -10,17 +11,18 @@ from agent_framework import (
AgentRunResponse,
AgentRunResponseUpdate,
AgentThread,
AggregateContextProvider,
ChatAgent,
ChatClientProtocol,
ChatMessage,
ChatMessageList,
ChatMessageStore,
ChatResponse,
Contents,
Context,
ContextProvider,
HostedCodeInterpreterTool,
Role,
TextContent,
)
from agent_framework._memory import AggregateContextProvider, Context, ContextProvider
from agent_framework.exceptions import AgentExecutionException
@@ -98,11 +100,10 @@ async def test_chat_client_agent_get_new_thread(chat_client: ChatClientProtocol)
async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatClientProtocol) -> None:
agent = ChatAgent(chat_client=chat_client)
message = ChatMessage(role=Role.USER, text="Hello")
thread = AgentThread(message_store=ChatMessageList(messages=[message]))
thread = AgentThread(message_store=ChatMessageStore(messages=[message]))
_, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
_, _, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=thread,
context=Context(),
input_messages=[ChatMessage(role=Role.USER, text="Test")],
)
@@ -152,7 +153,7 @@ async def test_chat_client_agent_update_thread_conversation_id_missing(chat_clie
thread = AgentThread(service_thread_id="123")
with raises(AgentExecutionException, match="Service did not return a valid conversation id"):
agent._update_thread_with_type_and_conversation_id(thread, None) # type: ignore[reportPrivateUsage]
await agent._update_thread_with_type_and_conversation_id(thread, None) # type: ignore[reportPrivateUsage]
async def test_chat_client_agent_default_author_name(chat_client: ChatClientProtocol) -> None:
@@ -191,49 +192,49 @@ async def test_chat_client_agent_author_name_is_used_from_response(chat_client_b
# Mock context provider for testing
class MockContextProvider(ContextProvider):
context_contents: list[Contents] | None = None
thread_created_called: bool = False
messages_adding_called: bool = False
model_invoking_called: bool = False
thread_created_thread_id: str | None = None
messages_adding_thread_id: str | None = None
new_messages: list[ChatMessage] = []
def __init__(self, contents: list[Contents] | None = None) -> None:
super().__init__()
self.context_contents = contents
def __init__(self, messages: list[ChatMessage] | None = None) -> None:
self.context_messages = messages
self.thread_created_called = False
self.messages_adding_called = False
self.model_invoking_called = False
self.invoked_called = False
self.invoking_called = False
self.thread_created_thread_id = None
self.messages_adding_thread_id = None
self.new_messages = []
self.invoked_thread_id = None
self.new_messages: list[ChatMessage] = []
async def thread_created(self, thread_id: str | None) -> None:
self.thread_created_called = True
self.thread_created_thread_id = thread_id
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
self.messages_adding_called = True
self.messages_adding_thread_id = thread_id
if isinstance(new_messages, ChatMessage):
self.new_messages.append(new_messages)
async def invoked(
self,
request_messages: ChatMessage | Sequence[ChatMessage],
response_messages: ChatMessage | Sequence[ChatMessage] | None = None,
invoke_exception: Any = None,
**kwargs: Any,
) -> None:
self.invoked_called = True
if isinstance(request_messages, ChatMessage):
self.new_messages.append(request_messages)
else:
self.new_messages.extend(new_messages)
self.new_messages.extend(request_messages)
if isinstance(response_messages, ChatMessage):
self.new_messages.append(response_messages)
else:
self.new_messages.extend(response_messages)
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
self.model_invoking_called = True
return Context(contents=self.context_contents)
async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context:
self.invoking_called = True
return Context(messages=self.context_messages)
async def test_chat_agent_context_providers_model_invoking(chat_client: ChatClientProtocol) -> None:
"""Test that context providers' model_invoking is called during agent run."""
mock_provider = MockContextProvider(contents=[TextContent("Test context instructions")])
"""Test that context providers' invoking is called during agent run."""
mock_provider = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Test context instructions")])
agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider)
await agent.run("Hello")
assert mock_provider.model_invoking_called
assert mock_provider.invoking_called
async def test_chat_agent_context_providers_thread_created(chat_client_base: ChatClientProtocol) -> None:
@@ -255,75 +256,54 @@ async def test_chat_agent_context_providers_thread_created(chat_client_base: Cha
async def test_chat_agent_context_providers_messages_adding(chat_client: ChatClientProtocol) -> None:
"""Test that context providers' messages_adding is called during agent run."""
"""Test that context providers' invoked is called during agent run."""
mock_provider = MockContextProvider()
agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider)
await agent.run("Hello")
assert mock_provider.messages_adding_called
assert mock_provider.invoked_called
# Should be called with both input and response messages
assert len(mock_provider.new_messages) >= 2
async def test_chat_agent_context_instructions_in_messages(chat_client: ChatClientProtocol) -> None:
"""Test that AI context instructions are included in messages."""
mock_provider = MockContextProvider(contents=[TextContent("Context-specific instructions")])
mock_provider = MockContextProvider(messages=[ChatMessage(role="system", text="Context-specific instructions")])
agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_providers=mock_provider)
# We need to test the _prepare_thread_and_messages method directly
context = Context(contents=[TextContent("Context-specific instructions")])
_, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=None, context=context, input_messages=[ChatMessage(role=Role.USER, text="Hello")]
_, _, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=None, input_messages=[ChatMessage(role=Role.USER, text="Hello")]
)
# Should have agent instructions, context instructions, and user message
assert len(messages) == 3
assert messages[0].role == Role.SYSTEM
assert messages[0].text == "Agent instructions"
assert messages[1].role == Role.SYSTEM
assert messages[1].text == "Context-specific instructions"
assert messages[2].role == Role.USER
assert messages[2].text == "Hello"
async def test_chat_agent_context_instructions_without_agent_instructions(chat_client: ChatClientProtocol) -> None:
"""Test that AI context instructions work when agent has no instructions."""
agent = ChatAgent(chat_client=chat_client) # No instructions
context = Context(contents=[TextContent("Context-only instructions")])
_, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=None, context=context, input_messages=[ChatMessage(role=Role.USER, text="Hello")]
)
# Should have context instructions and user message only
# Should have context instructions, and user message
assert len(messages) == 2
assert messages[0].role == Role.SYSTEM
assert messages[0].text == "Context-only instructions"
assert messages[0].text == "Context-specific instructions"
assert messages[1].role == Role.USER
assert messages[1].text == "Hello"
# instructions system message is added by a chat_client
async def test_chat_agent_no_context_instructions(chat_client: ChatClientProtocol) -> None:
"""Test behavior when AI context has no instructions."""
agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions")
context = Context() # No instructions
mock_provider = MockContextProvider()
agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_providers=mock_provider)
_, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=None, context=context, input_messages=[ChatMessage(role=Role.USER, text="Hello")]
_, _, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=None, input_messages=[ChatMessage(role=Role.USER, text="Hello")]
)
# Should have agent instructions and user message only
assert len(messages) == 2
assert messages[0].role == Role.SYSTEM
assert messages[0].text == "Agent instructions"
assert messages[1].role == Role.USER
assert messages[1].text == "Hello"
assert len(messages) == 1
assert messages[0].role == Role.USER
assert messages[0].text == "Hello"
async def test_chat_agent_run_stream_context_providers(chat_client: ChatClientProtocol) -> None:
"""Test that context providers work with run_stream method."""
mock_provider = MockContextProvider(contents=[TextContent("Stream context instructions")])
mock_provider = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Stream context instructions")])
agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider)
# Collect all stream updates
@@ -332,47 +312,48 @@ async def test_chat_agent_run_stream_context_providers(chat_client: ChatClientPr
updates.append(update)
# Verify context provider was called
assert mock_provider.model_invoking_called
assert mock_provider.thread_created_called
assert mock_provider.messages_adding_called
assert mock_provider.invoking_called
# no conversation id is created, so no need to thread_create to be called.
assert not mock_provider.thread_created_called
assert mock_provider.invoked_called
async def test_chat_agent_multiple_context_providers(chat_client: ChatClientProtocol) -> None:
"""Test that multiple context providers work together."""
provider1 = MockContextProvider(contents=[TextContent("First provider instructions")])
provider2 = MockContextProvider(contents=[TextContent("Second provider instructions")])
provider1 = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="First provider instructions")])
provider2 = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Second provider instructions")])
agent = ChatAgent(chat_client=chat_client, context_providers=[provider1, provider2])
await agent.run("Hello")
# Both providers should be called
assert provider1.model_invoking_called
assert provider1.thread_created_called
assert provider1.messages_adding_called
assert provider1.invoking_called
assert not provider1.thread_created_called
assert provider1.invoked_called
assert provider2.model_invoking_called
assert provider2.thread_created_called
assert provider2.messages_adding_called
assert provider2.invoking_called
assert not provider2.thread_created_called
assert provider2.invoked_called
async def test_chat_agent_aggregate_context_provider_combines_instructions() -> None:
"""Test that AggregateContextProvider combines instructions from multiple providers."""
provider1 = MockContextProvider(contents=[TextContent("First instruction")])
provider2 = MockContextProvider(contents=[TextContent("Second instruction")])
provider1 = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="First instruction")])
provider2 = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Second instruction")])
aggregate = AggregateContextProvider()
aggregate.providers.append(provider1)
aggregate.providers.append(provider2)
# Test model_invoking combines instructions
result = await aggregate.model_invoking([ChatMessage(role=Role.USER, text="Test")])
# Test invoking combines instructions
result = await aggregate.invoking([ChatMessage(role=Role.USER, text="Test")])
assert result.contents
assert isinstance(result.contents[0], TextContent)
assert isinstance(result.contents[1], TextContent)
assert result.contents[0].text == "First instruction"
assert result.contents[1].text == "Second instruction"
assert result.messages
assert isinstance(result.messages[0], ChatMessage)
assert isinstance(result.messages[1], ChatMessage)
assert result.messages[0].text == "First instruction"
assert result.messages[1].text == "Second instruction"
async def test_chat_agent_context_providers_with_thread_service_id(chat_client_base: ChatClientProtocol) -> None:
@@ -388,12 +369,11 @@ async def test_chat_agent_context_providers_with_thread_service_id(chat_client_b
agent = ChatAgent(chat_client=chat_client_base, context_providers=mock_provider)
# Use existing service-managed thread
thread = AgentThread(service_thread_id="existing-thread-id")
thread = agent.get_new_thread(service_thread_id="existing-thread-id")
await agent.run("Hello", thread=thread)
# messages_adding should be called with the service thread ID from response
assert mock_provider.messages_adding_called
assert mock_provider.messages_adding_thread_id == "service-thread-123" # Updated thread ID from response
# invoked should be called with the service thread ID from response
assert mock_provider.invoked_called
# Tests for as_tool method
+114 -120
View File
@@ -1,33 +1,23 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import MutableSequence, Sequence
from collections.abc import MutableSequence
from typing import Any
from unittest.mock import AsyncMock, Mock
from agent_framework import ChatMessage, Contents, Role, TextContent
from agent_framework import ChatMessage, Role, TextContent
from agent_framework._memory import AggregateContextProvider, Context, ContextProvider
class MockContextProvider(ContextProvider):
"""Mock ContextProvider for testing."""
context_contents: list[Contents] | None = None
thread_created_called: bool = False
messages_adding_called: bool = False
model_invoking_called: bool = False
thread_created_thread_id: str | None = None
messages_adding_thread_id: str | None = None
messages_adding_new_messages: ChatMessage | Sequence[ChatMessage] | None = None
model_invoking_messages: ChatMessage | MutableSequence[ChatMessage] | None = None
def __init__(self, context_contents: list[Contents] | None = None) -> None:
super().__init__()
self.context_contents = context_contents
def __init__(self, messages: list[ChatMessage] | None = None) -> None:
self.context_messages = messages
self.thread_created_called = False
self.messages_adding_called = False
self.model_invoking_called = False
self.invoked_called = False
self.invoking_called = False
self.thread_created_thread_id = None
self.messages_adding_thread_id = None
self.messages_adding_new_messages = None
self.new_messages = None
self.model_invoking_messages = None
async def thread_created(self, thread_id: str | None) -> None:
@@ -35,18 +25,23 @@ class MockContextProvider(ContextProvider):
self.thread_created_called = True
self.thread_created_thread_id = thread_id
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
"""Track messages_adding calls."""
self.messages_adding_called = True
self.messages_adding_thread_id = thread_id
self.messages_adding_new_messages = new_messages
async def invoked(
self,
request_messages: Any,
response_messages: Any | None = None,
invoke_exception: Exception | None = None,
**kwargs: Any,
) -> None:
"""Track invoked calls."""
self.invoked_called = True
self.new_messages = request_messages
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
"""Track model_invoking calls and return context."""
self.model_invoking_called = True
async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context:
"""Track invoking calls and return context."""
self.invoking_called = True
self.model_invoking_messages = messages
context = Context()
context.contents = self.context_contents
context.messages = self.context_messages
return context
@@ -65,19 +60,21 @@ class TestAggregateContextProvider:
def test_init_with_providers(self) -> None:
"""Test initialization with providers."""
provider1 = MockContextProvider([TextContent("instructions1")])
provider2 = MockContextProvider([TextContent("instructions1")])
providers = [provider1, provider2]
provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")])
provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")])
provider3 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 3")])
providers = [provider1, provider2, provider3]
aggregate = AggregateContextProvider(providers)
assert len(aggregate.providers) == 2
assert len(aggregate.providers) == 3
assert aggregate.providers[0] is provider1
assert aggregate.providers[1] is provider2
assert aggregate.providers[2] is provider3
def test_add_provider(self) -> None:
"""Test adding a provider."""
aggregate = AggregateContextProvider()
provider = MockContextProvider([TextContent("instructions")])
provider = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions")])
aggregate.add(provider)
assert len(aggregate.providers) == 1
@@ -86,8 +83,8 @@ class TestAggregateContextProvider:
def test_add_multiple_providers(self) -> None:
"""Test adding multiple providers."""
aggregate = AggregateContextProvider()
provider1 = MockContextProvider([TextContent("instructions1")])
provider2 = MockContextProvider([TextContent("instructions2")])
provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")])
provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")])
aggregate.add(provider1)
aggregate.add(provider2)
@@ -105,8 +102,8 @@ class TestAggregateContextProvider:
async def test_thread_created_with_providers(self) -> None:
"""Test thread_created calls all providers."""
provider1 = MockContextProvider([TextContent("instructions1")])
provider2 = MockContextProvider([TextContent("instructions2")])
provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")])
provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")])
aggregate = AggregateContextProvider([provider1, provider2])
thread_id = "thread-123"
@@ -119,7 +116,7 @@ class TestAggregateContextProvider:
async def test_thread_created_with_none_thread_id(self) -> None:
"""Test thread_created with None thread_id."""
provider = MockContextProvider([TextContent("instructions")])
provider = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions")])
aggregate = AggregateContextProvider([provider])
await aggregate.thread_created(None)
@@ -128,155 +125,150 @@ class TestAggregateContextProvider:
assert provider.thread_created_thread_id is None
async def test_messages_adding_with_no_providers(self) -> None:
"""Test messages_adding with no providers."""
"""Test invoked with no providers."""
aggregate = AggregateContextProvider()
message = ChatMessage(text="Hello", role=Role.USER)
# Should not raise an exception
await aggregate.messages_adding("thread-123", message)
await aggregate.invoked(message)
async def test_messages_adding_with_single_message(self) -> None:
"""Test messages_adding with a single message."""
provider1 = MockContextProvider([TextContent("instructions1")])
provider2 = MockContextProvider([TextContent("instructions2")])
"""Test invoked with a single message."""
provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")])
provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")])
aggregate = AggregateContextProvider([provider1, provider2])
thread_id = "thread-123"
message = ChatMessage(text="Hello", role=Role.USER)
await aggregate.messages_adding(thread_id, message)
await aggregate.invoked(message)
assert provider1.messages_adding_called
assert provider1.messages_adding_thread_id == thread_id
assert provider1.messages_adding_new_messages == message
assert provider2.messages_adding_called
assert provider2.messages_adding_thread_id == thread_id
assert provider2.messages_adding_new_messages == message
assert provider1.invoked_called
assert provider1.new_messages == message
assert provider2.invoked_called
assert provider2.new_messages == message
async def test_messages_adding_with_message_sequence(self) -> None:
"""Test messages_adding with a sequence of messages."""
provider = MockContextProvider([TextContent("instructions")])
"""Test invoked with a sequence of messages."""
provider = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions")])
aggregate = AggregateContextProvider([provider])
thread_id = "thread-123"
messages = [
ChatMessage(text="Hello", role=Role.USER),
ChatMessage(text="Hi there", role=Role.ASSISTANT),
]
await aggregate.messages_adding(thread_id, messages)
await aggregate.invoked(messages)
assert provider.messages_adding_called
assert provider.messages_adding_thread_id == thread_id
assert provider.messages_adding_new_messages == messages
assert provider.invoked_called
assert provider.new_messages == messages
async def test_model_invoking_with_no_providers(self) -> None:
"""Test model_invoking with no providers."""
"""Test invoking with no providers."""
aggregate = AggregateContextProvider()
message = ChatMessage(text="Hello", role=Role.USER)
context = await aggregate.model_invoking(message)
context = await aggregate.invoking(message)
assert isinstance(context, Context)
assert not context.contents
assert not context.messages
async def test_model_invoking_with_single_provider(self) -> None:
"""Test model_invoking with a single provider."""
provider = MockContextProvider([TextContent("Test instructions")])
"""Test invoking with a single provider."""
provider = MockContextProvider(messages=[ChatMessage(role="user", text="Test instructions")])
aggregate = AggregateContextProvider([provider])
message = ChatMessage(text="Hello", role=Role.USER)
context = await aggregate.model_invoking(message)
message = [ChatMessage(text="Hello", role=Role.USER)]
context = await aggregate.invoking(message)
assert provider.model_invoking_called
assert provider.invoking_called
assert provider.model_invoking_messages == message
assert isinstance(context, Context)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert context.contents[0].text == "Test instructions"
assert context.messages
assert isinstance(context.messages[0].contents[0], TextContent)
assert context.messages[0].text == "Test instructions"
async def test_model_invoking_with_multiple_providers(self) -> None:
"""Test model_invoking combines contexts from multiple providers."""
provider1 = MockContextProvider([TextContent("Instructions 1")])
provider2 = MockContextProvider([TextContent("Instructions 2")])
provider3 = MockContextProvider([TextContent("Instructions 3")])
"""Test invoking combines contexts from multiple providers."""
provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")])
provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")])
provider3 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 3")])
aggregate = AggregateContextProvider([provider1, provider2, provider3])
messages = [ChatMessage(text="Hello", role=Role.USER)]
context = await aggregate.model_invoking(messages)
context = await aggregate.invoking(messages)
assert provider1.model_invoking_called
assert provider1.invoking_called
assert provider1.model_invoking_messages == messages
assert provider2.model_invoking_called
assert provider2.invoking_called
assert provider2.model_invoking_messages == messages
assert provider3.model_invoking_called
assert provider3.invoking_called
assert provider3.model_invoking_messages == messages
assert isinstance(context, Context)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert isinstance(context.contents[1], TextContent)
assert isinstance(context.contents[2], TextContent)
assert context.contents[0].text == "Instructions 1"
assert context.contents[1].text == "Instructions 2"
assert context.contents[2].text == "Instructions 3"
assert context.messages
assert isinstance(context.messages[0].contents[0], TextContent)
assert isinstance(context.messages[1].contents[0], TextContent)
assert isinstance(context.messages[2].contents[0], TextContent)
assert context.messages[0].text == "Instructions 1"
assert context.messages[1].text == "Instructions 2"
assert context.messages[2].text == "Instructions 3"
async def test_model_invoking_with_none_instructions(self) -> None:
"""Test model_invoking filters out None instructions."""
provider1 = MockContextProvider([TextContent("Instructions 1")])
provider2 = MockContextProvider(None) # None instructions
provider3 = MockContextProvider([TextContent("Instructions 3")])
"""Test invoking filters out None instructions."""
provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")])
provider2 = MockContextProvider(messages=None) # None instructions
provider3 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 3")])
aggregate = AggregateContextProvider([provider1, provider2, provider3])
message = ChatMessage(text="Hello", role=Role.USER)
context = await aggregate.model_invoking(message)
context = await aggregate.invoking(message)
assert isinstance(context, Context)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert isinstance(context.contents[1], TextContent)
assert context.contents[0].text == "Instructions 1"
assert context.contents[1].text == "Instructions 3"
assert context.messages
assert isinstance(context.messages[0].contents[0], TextContent)
assert isinstance(context.messages[1].contents[0], TextContent)
assert context.messages[0].text == "Instructions 1"
assert context.messages[1].text == "Instructions 3"
async def test_model_invoking_with_all_none_instructions(self) -> None:
"""Test model_invoking when all providers return None instructions."""
"""Test invoking when all providers return None instructions."""
provider1 = MockContextProvider(None)
provider2 = MockContextProvider(None)
aggregate = AggregateContextProvider([provider1, provider2])
message = ChatMessage(text="Hello", role=Role.USER)
context = await aggregate.model_invoking(message)
context = await aggregate.invoking(message)
assert isinstance(context, Context)
assert not context.contents
assert not context.messages
async def test_model_invoking_with_mutable_sequence(self) -> None:
"""Test model_invoking with MutableSequence of messages."""
provider = MockContextProvider([TextContent("Test instructions")])
"""Test invoking with MutableSequence of messages."""
provider = MockContextProvider(messages=[ChatMessage(role="user", text="Test instructions")])
aggregate = AggregateContextProvider([provider])
messages = [ChatMessage(text="Hello", role=Role.USER)]
context = await aggregate.model_invoking(messages)
context = await aggregate.invoking(messages)
assert provider.model_invoking_called
assert provider.invoking_called
assert provider.model_invoking_messages == messages
assert isinstance(context, Context)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert context.contents[0].text == "Test instructions"
assert context.messages
assert isinstance(context.messages[0].contents[0], TextContent)
assert context.messages[0].text == "Test instructions"
async def test_async_methods_concurrent_execution(self) -> None:
"""Test that async methods execute providers concurrently."""
# Use AsyncMock to verify concurrent execution
provider1 = Mock(spec=ContextProvider)
provider1.thread_created = AsyncMock()
provider1.messages_adding = AsyncMock()
provider1.model_invoking = AsyncMock(return_value=Context(contents=[TextContent("Test 1")]))
provider1.invoked = AsyncMock()
provider1.invoking = AsyncMock(return_value=Context(messages=[ChatMessage(role="user", text="Test 1")]))
provider2 = Mock(spec=ContextProvider)
provider2.thread_created = AsyncMock()
provider2.messages_adding = AsyncMock()
provider2.model_invoking = AsyncMock(return_value=Context(contents=[TextContent("Test 2")]))
provider2.invoked = AsyncMock()
provider2.invoking = AsyncMock(return_value=Context(messages=[ChatMessage(role="user", text="Test 2")]))
aggregate = AggregateContextProvider([provider1, provider2])
@@ -285,18 +277,20 @@ class TestAggregateContextProvider:
provider1.thread_created.assert_called_once_with("thread-123")
provider2.thread_created.assert_called_once_with("thread-123")
# Test messages_adding
# Test invoked
message = ChatMessage(text="Hello", role=Role.USER)
await aggregate.messages_adding("thread-123", message)
provider1.messages_adding.assert_called_once_with("thread-123", message)
provider2.messages_adding.assert_called_once_with("thread-123", message)
await aggregate.invoked(message)
provider1.invoked.assert_called_once_with(
request_messages=message, response_messages=None, invoke_exception=None
)
provider2.invoked.assert_called_once_with(
request_messages=message, response_messages=None, invoke_exception=None
)
# Test model_invoking
context = await aggregate.model_invoking(message)
provider1.model_invoking.assert_called_once_with(message)
provider2.model_invoking.assert_called_once_with(message)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert isinstance(context.contents[1], TextContent)
assert context.contents[0].text == "Test 1"
assert context.contents[1].text == "Test 2"
# Test invoking
context = await aggregate.invoking(message)
provider1.invoking.assert_called_once_with(message)
provider2.invoking.assert_called_once_with(message)
assert context.messages
assert context.messages[0].text == "Test 1"
assert context.messages[1].text == "Test 2"
@@ -1505,9 +1505,13 @@ class TestChatAgentChatMiddleware:
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
# Modify the first message by adding a prefix
if context.messages and len(context.messages) > 0:
original_text = context.messages[0].text or ""
context.messages[0] = ChatMessage(role=context.messages[0].role, text=f"MODIFIED: {original_text}")
if context.messages:
for idx, msg in enumerate(context.messages):
if msg.role.value == "system":
continue
original_text = msg.text or ""
context.messages[idx] = ChatMessage(role=msg.role, text=f"MODIFIED: {original_text}")
break
await next(context)
# Create ChatAgent with message-modifying middleware
@@ -1519,8 +1523,7 @@ class TestChatAgentChatMiddleware:
response = await agent.run(messages)
# Verify that the message was modified (MockBaseChatClient echoes back the input)
assert response is not None
assert len(response.messages) > 0
assert response and response.messages
assert "MODIFIED: test message" in response.messages[0].text
async def test_chat_middleware_can_override_response(self) -> None:
+72 -156
View File
@@ -5,12 +5,13 @@ from typing import Any
import pytest
from agent_framework import AgentThread, ChatMessage, ChatMessageList, Role
from agent_framework._threads import StoreState, ThreadState, deserialize_thread_state, thread_on_new_messages
from agent_framework import AgentThread, ChatMessage, ChatMessageStore, Role
from agent_framework._threads import AgentThreadState, ChatMessageStoreState
from agent_framework.exceptions import AgentThreadException
class MockChatMessageStore:
"""Mock implementation of ChatMessageStore for testing."""
"""Mock implementation of ChatMessageStoreProtocol for testing."""
def __init__(self, messages: list[ChatMessage] | None = None) -> None:
self._messages = messages or []
@@ -23,15 +24,21 @@ class MockChatMessageStore:
async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
self._messages.extend(messages)
async def serialize_state(self, **kwargs: Any) -> Any:
async def serialize(self, **kwargs: Any) -> Any:
self._serialize_calls += 1
return {"messages": [msg.__dict__ for msg in self._messages], "kwargs": kwargs}
async def deserialize_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
self._deserialize_calls += 1
if serialized_store_state and "messages" in serialized_store_state:
self._messages = serialized_store_state["messages"]
@classmethod
async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "MockChatMessageStore":
instance = cls()
await instance.update_from_state(serialized_store_state, **kwargs)
return instance
@pytest.fixture
def sample_messages() -> list[ChatMessage]:
@@ -67,7 +74,7 @@ class TestAgentThread:
def test_init_with_message_store(self) -> None:
"""Test AgentThread initialization with message_store."""
store = ChatMessageList()
store = ChatMessageStore()
thread = AgentThread(message_store=store)
assert thread.service_thread_id is None
assert thread.message_store is store
@@ -81,11 +88,11 @@ class TestAgentThread:
assert thread.service_thread_id == service_thread_id
def test_service_thread_id_setter_with_existing_message_store_raises_error(self) -> None:
"""Test that setting service_thread_id when message_store exists raises ValueError."""
store = ChatMessageList()
"""Test that setting service_thread_id when message_store exists raises AgentThreadException."""
store = ChatMessageStore()
thread = AgentThread(message_store=store)
with pytest.raises(ValueError, match="Only the service_thread_id or message_store may be set"):
with pytest.raises(AgentThreadException, match="Only the service_thread_id or message_store may be set"):
thread.service_thread_id = "test-conversation-789"
def test_service_thread_id_setter_with_none_values(self) -> None:
@@ -97,18 +104,18 @@ class TestAgentThread:
def test_message_store_property_setter(self) -> None:
"""Test message_store property setter."""
thread = AgentThread()
store = ChatMessageList()
store = ChatMessageStore()
thread.message_store = store
assert thread.message_store is store
def test_message_store_setter_with_existing_service_thread_id_raises_error(self) -> None:
"""Test that setting message_store when service_thread_id exists raises ValueError."""
"""Test that setting message_store when service_thread_id exists raises AgentThreadException."""
service_thread_id = "test-conversation-999"
thread = AgentThread(service_thread_id=service_thread_id)
store = ChatMessageList()
store = ChatMessageStore()
with pytest.raises(ValueError, match="Only the service_thread_id or message_store may be set"):
with pytest.raises(AgentThreadException, match="Only the service_thread_id or message_store may be set"):
thread.message_store = store
def test_message_store_setter_with_none_values(self) -> None:
@@ -119,7 +126,7 @@ class TestAgentThread:
async def test_get_messages_with_message_store(self, sample_messages: list[ChatMessage]) -> None:
"""Test get_messages when message_store is set."""
store = ChatMessageList(sample_messages)
store = ChatMessageStore(sample_messages)
thread = AgentThread(message_store=store)
assert thread.message_store is not None
@@ -142,19 +149,19 @@ class TestAgentThread:
"""Test _on_new_messages when service_thread_id is set (should do nothing)."""
thread = AgentThread(service_thread_id="test-conv")
await thread_on_new_messages(thread, sample_message)
await thread.on_new_messages(sample_message)
# Should not create a message store
assert thread.message_store is None
async def test_on_new_messages_single_message_creates_store(self, sample_message: ChatMessage) -> None:
"""Test _on_new_messages with single message creates ChatMessageList."""
"""Test _on_new_messages with single message creates ChatMessageStore."""
thread = AgentThread()
await thread_on_new_messages(thread, sample_message)
await thread.on_new_messages(sample_message)
assert thread.message_store is not None
assert isinstance(thread.message_store, ChatMessageList)
assert isinstance(thread.message_store, ChatMessageStore)
messages = await thread.message_store.list_messages()
assert len(messages) == 1
assert messages[0].text == "Test message"
@@ -163,7 +170,7 @@ class TestAgentThread:
"""Test _on_new_messages with multiple messages."""
thread = AgentThread()
await thread_on_new_messages(thread, sample_messages)
await thread.on_new_messages(sample_messages)
assert thread.message_store is not None
messages = await thread.message_store.list_messages()
@@ -172,10 +179,10 @@ class TestAgentThread:
async def test_on_new_messages_with_existing_store(self, sample_message: ChatMessage) -> None:
"""Test _on_new_messages adds to existing message store."""
initial_messages = [ChatMessage(role=Role.USER, text="Initial", message_id="init1")]
store = ChatMessageList(initial_messages)
store = ChatMessageStore(initial_messages)
thread = AgentThread(message_store=store)
await thread_on_new_messages(thread, sample_message)
await thread.on_new_messages(sample_message)
assert thread.message_store is not None
messages = await thread.message_store.list_messages()
@@ -185,32 +192,30 @@ class TestAgentThread:
async def test_deserialize_with_service_thread_id(self) -> None:
"""Test _deserialize with service_thread_id."""
thread = AgentThread()
serialized_data = {"service_thread_id": "test-conv-123", "chat_message_store_state": None}
await deserialize_thread_state(thread, serialized_data)
thread = await AgentThread.deserialize(serialized_data)
assert thread.service_thread_id == "test-conv-123"
assert thread.message_store is None
async def test_deserialize_with_store_state(self, sample_messages: list[ChatMessage]) -> None:
"""Test _deserialize with chat_message_store_state."""
thread = AgentThread()
store_state = {"messages": sample_messages}
serialized_data = {"service_thread_id": None, "chat_message_store_state": store_state}
await deserialize_thread_state(thread, serialized_data)
thread = await AgentThread.deserialize(serialized_data)
assert thread.service_thread_id is None
assert thread.message_store is not None
assert isinstance(thread.message_store, ChatMessageList)
assert isinstance(thread.message_store, ChatMessageStore)
async def test_deserialize_with_no_state(self) -> None:
"""Test _deserialize with no state."""
thread = AgentThread()
serialized_data = {"service_thread_id": None, "chat_message_store_state": None}
await deserialize_thread_state(thread, serialized_data)
await thread.deserialize(serialized_data)
assert thread.service_thread_id is None
assert thread.message_store is None
@@ -221,7 +226,7 @@ class TestAgentThread:
thread = AgentThread(message_store=store)
serialized_data: dict[str, Any] = {"service_thread_id": None, "chat_message_store_state": {"messages": []}}
await deserialize_thread_state(thread, serialized_data)
await thread.update_from_thread_state(serialized_data)
assert store._deserialize_calls == 1 # pyright: ignore[reportPrivateUsage]
@@ -265,31 +270,31 @@ class TestAgentThread:
class TestChatMessageList:
"""Test cases for ChatMessageList class."""
"""Test cases for ChatMessageStore class."""
def test_init_empty(self) -> None:
"""Test ChatMessageList initialization with no messages."""
store = ChatMessageList()
assert len(store) == 0
"""Test ChatMessageStore initialization with no messages."""
store = ChatMessageStore()
assert len(store.messages) == 0
def test_init_with_messages(self, sample_messages: list[ChatMessage]) -> None:
"""Test ChatMessageList initialization with messages."""
store = ChatMessageList(sample_messages)
assert len(store) == 3
"""Test ChatMessageStore initialization with messages."""
store = ChatMessageStore(sample_messages)
assert len(store.messages) == 3
async def test_add_messages(self, sample_messages: list[ChatMessage]) -> None:
"""Test adding messages to the store."""
store = ChatMessageList()
store = ChatMessageStore()
await store.add_messages(sample_messages)
assert len(store) == 3
assert len(store.messages) == 3
messages = await store.list_messages()
assert messages[0].text == "Hello"
async def test_get_messages(self, sample_messages: list[ChatMessage]) -> None:
"""Test getting messages from the store."""
store = ChatMessageList(sample_messages)
store = ChatMessageStore(sample_messages)
messages = await store.list_messages()
@@ -298,28 +303,28 @@ class TestChatMessageList:
async def test_serialize_state(self, sample_messages: list[ChatMessage]) -> None:
"""Test serializing store state."""
store = ChatMessageList(sample_messages)
store = ChatMessageStore(sample_messages)
result = await store.serialize_state()
result = await store.serialize()
assert "messages" in result
assert len(result["messages"]) == 3
async def test_serialize_state_empty(self) -> None:
"""Test serializing empty store state."""
store = ChatMessageList()
store = ChatMessageStore()
result = await store.serialize_state()
result = await store.serialize()
assert "messages" in result
assert len(result["messages"]) == 0
async def test_deserialize_state(self, sample_messages: list[ChatMessage]) -> None:
"""Test deserializing store state."""
store = ChatMessageList()
store = ChatMessageStore()
state_data = {"messages": sample_messages}
await store.deserialize_state(state_data)
await store.update_from_state(state_data)
messages = await store.list_messages()
assert len(messages) == 3
@@ -327,156 +332,67 @@ class TestChatMessageList:
async def test_deserialize_state_none(self) -> None:
"""Test deserializing None state."""
store = ChatMessageList()
store = ChatMessageStore()
await store.deserialize_state(None)
await store.update_from_state(None)
assert len(store) == 0
assert len(store.messages) == 0
async def test_deserialize_state_empty(self) -> None:
"""Test deserializing empty state."""
store = ChatMessageList()
store = ChatMessageStore()
await store.deserialize_state({})
await store.update_from_state({})
assert len(store) == 0
def test_len(self, sample_messages: list[ChatMessage]) -> None:
"""Test __len__ method."""
store = ChatMessageList(sample_messages)
assert len(store) == 3
empty_store = ChatMessageList()
assert len(empty_store) == 0
def test_getitem(self, sample_messages: list[ChatMessage]) -> None:
"""Test __getitem__ method."""
store = ChatMessageList(sample_messages)
assert store[0].text == "Hello"
assert store[1].text == "Hi there!"
assert store[2].text == "How are you?"
def test_setitem(self, sample_messages: list[ChatMessage], sample_message: ChatMessage) -> None:
"""Test __setitem__ method."""
store = ChatMessageList(sample_messages)
store[1] = sample_message
assert store[1].text == "Test message"
assert store[1].message_id == "test1"
def test_append(self, sample_message: ChatMessage) -> None:
"""Test append method."""
store = ChatMessageList()
store.append(sample_message)
assert len(store) == 1
assert store[0].text == "Test message"
def test_clear(self, sample_messages: list[ChatMessage]) -> None:
"""Test clear method."""
store = ChatMessageList(sample_messages)
assert len(store) == 3
store.clear()
assert len(store) == 0
def test_index(self, sample_messages: list[ChatMessage]) -> None:
"""Test index method."""
store = ChatMessageList(sample_messages)
index = store.index(sample_messages[1])
assert index == 1
def test_insert(self, sample_messages: list[ChatMessage], sample_message: ChatMessage) -> None:
"""Test insert method."""
store = ChatMessageList(sample_messages)
store.insert(1, sample_message)
assert len(store) == 4
assert store[1].text == "Test message"
assert store[2].text == "Hi there!" # Original message at index 1 is now at index 2
def test_remove(self, sample_messages: list[ChatMessage]) -> None:
"""Test remove method."""
store = ChatMessageList(sample_messages)
message_to_remove = sample_messages[1]
store.remove(message_to_remove)
assert len(store) == 2
assert store[0].text == "Hello"
assert store[1].text == "How are you?"
def test_pop_default(self, sample_messages: list[ChatMessage]) -> None:
"""Test pop method with default index."""
store = ChatMessageList(sample_messages)
popped_message = store.pop()
assert len(store) == 2
assert popped_message.text == "How are you?" # Last message
def test_pop_with_index(self, sample_messages: list[ChatMessage]) -> None:
"""Test pop method with specific index."""
store = ChatMessageList(sample_messages)
popped_message = store.pop(1)
assert len(store) == 2
assert popped_message.text == "Hi there!"
assert store[0].text == "Hello"
assert store[1].text == "How are you?"
assert len(store.messages) == 0
class TestStoreState:
"""Test cases for StoreState class."""
"""Test cases for ChatMessageStoreState class."""
def test_init(self, sample_messages: list[ChatMessage]) -> None:
"""Test StoreState initialization."""
state = StoreState(messages=sample_messages)
"""Test ChatMessageStoreState initialization."""
state = ChatMessageStoreState(messages=sample_messages)
assert len(state.messages) == 3
assert state.messages[0].text == "Hello"
def test_init_empty(self) -> None:
"""Test StoreState initialization with empty messages."""
state = StoreState(messages=[])
"""Test ChatMessageStoreState initialization with empty messages."""
state = ChatMessageStoreState(messages=[])
assert len(state.messages) == 0
class TestThreadState:
"""Test cases for ThreadState class."""
"""Test cases for AgentThreadState class."""
def test_init_with_service_thread_id(self) -> None:
"""Test ThreadState initialization with service_thread_id."""
state = ThreadState(service_thread_id="test-conv-123")
"""Test AgentThreadState initialization with service_thread_id."""
state = AgentThreadState(service_thread_id="test-conv-123")
assert state.service_thread_id == "test-conv-123"
assert state.chat_message_store_state is None
def test_init_with_chat_message_store_state(self) -> None:
"""Test ThreadState initialization with chat_message_store_state."""
"""Test AgentThreadState initialization with chat_message_store_state."""
store_data: dict[str, Any] = {"messages": []}
state = ThreadState(chat_message_store_state=store_data)
state = AgentThreadState(chat_message_store_state=store_data)
assert state.service_thread_id is None
assert state.chat_message_store_state == store_data
def test_init_with_both(self) -> None:
"""Test ThreadState initialization with both parameters."""
"""Test AgentThreadState initialization with both parameters."""
store_data: dict[str, Any] = {"messages": []}
state = ThreadState(service_thread_id="test-conv-456", chat_message_store_state=store_data)
assert state.service_thread_id == "test-conv-456"
assert state.chat_message_store_state == store_data
with pytest.raises(
AgentThreadException, match="Only one of service_thread_id or chat_message_store_state may be set"
):
AgentThreadState(service_thread_id="test-conv-123", chat_message_store_state=store_data)
def test_init_defaults(self) -> None:
"""Test ThreadState initialization with defaults."""
state = ThreadState()
"""Test AgentThreadState initialization with defaults."""
state = AgentThreadState()
assert state.service_thread_id is None
assert state.chat_message_store_state is None
@@ -678,7 +678,6 @@ def test_chat_response_updates_to_chat_response_multiple_multiple():
assert chat_response.text == "I'm doing well, thank you! More contextFinal part"
@mark.asyncio
async def test_chat_response_from_async_generator():
async def gen() -> AsyncIterable[ChatResponseUpdate]:
yield ChatResponseUpdate(text="Hello", message_id="1")
@@ -688,7 +687,6 @@ async def test_chat_response_from_async_generator():
assert resp.text == "Hello world"
@mark.asyncio
async def test_chat_response_from_async_generator_output_format():
async def gen() -> AsyncIterable[ChatResponseUpdate]:
yield ChatResponseUpdate(text='{ "respon', message_id="1")
@@ -702,7 +700,6 @@ async def test_chat_response_from_async_generator_output_format():
assert resp.value.response == "Hello"
@mark.asyncio
async def test_chat_response_from_async_generator_output_format_in_method():
async def gen() -> AsyncIterable[ChatResponseUpdate]:
yield ChatResponseUpdate(text='{ "respon', message_id="1")
@@ -1199,7 +1196,6 @@ def agent_run_response_async() -> AgentRunResponse:
return AgentRunResponse(messages=[ChatMessage(role="user", text="Hello")])
@mark.asyncio
async def test_agent_run_response_from_async_generator():
async def gen():
yield AgentRunResponseUpdate(contents=[TextContent("A")])
@@ -383,7 +383,6 @@ async def test_openai_assistants_client_prepare_thread_existing_no_run(mock_asyn
mock_async_openai.beta.threads.runs.cancel.assert_not_called()
@pytest.mark.asyncio
async def test_openai_assistants_client_process_stream_events_thread_run_created(mock_async_openai: MagicMock) -> None:
"""Test _process_stream_events with thread.run.created event."""
chat_client = create_test_openai_assistants_client(mock_async_openai)
@@ -417,7 +416,6 @@ async def test_openai_assistants_client_process_stream_events_thread_run_created
assert update.raw_representation == mock_response.data
@pytest.mark.asyncio
async def test_openai_assistants_client_process_stream_events_message_delta_text(mock_async_openai: MagicMock) -> None:
"""Test _process_stream_events with thread.message.delta event containing text."""
chat_client = create_test_openai_assistants_client(mock_async_openai)
@@ -462,7 +460,6 @@ async def test_openai_assistants_client_process_stream_events_message_delta_text
assert update.raw_representation == mock_message_delta
@pytest.mark.asyncio
async def test_openai_assistants_client_process_stream_events_requires_action(mock_async_openai: MagicMock) -> None:
"""Test _process_stream_events with thread.run.requires_action event."""
chat_client = create_test_openai_assistants_client(mock_async_openai)
@@ -506,7 +503,6 @@ async def test_openai_assistants_client_process_stream_events_requires_action(mo
chat_client._create_function_call_contents.assert_called_once_with(mock_run, None) # type: ignore
@pytest.mark.asyncio
async def test_openai_assistants_client_process_stream_events_run_step_created(mock_async_openai: MagicMock) -> None:
"""Test _process_stream_events with thread.run.step.created event."""
@@ -539,7 +535,6 @@ async def test_openai_assistants_client_process_stream_events_run_step_created(m
assert len(updates) == 0
@pytest.mark.asyncio
async def test_openai_assistants_client_process_stream_events_run_completed_with_usage(
mock_async_openai: MagicMock,
) -> None:
@@ -101,7 +101,6 @@ class TimedApproval(RequestInfoMessage):
issued_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
@pytest.mark.asyncio
async def test_rehydrate_falls_back_when_request_type_missing() -> None:
"""Rehydration should succeed even if the original request type cannot be imported.
@@ -144,7 +143,6 @@ async def test_rehydrate_falls_back_when_request_type_missing() -> None:
assert getattr(event.data, "iteration", None) == 2
@pytest.mark.asyncio
async def test_has_pending_request_detects_snapshot() -> None:
request_id = "req-pending"
snapshot = {
@@ -172,7 +170,6 @@ async def test_has_pending_request_detects_snapshot() -> None:
assert await executor.has_pending_request(request_id, ctx)
@pytest.mark.asyncio
async def test_has_pending_request_false_when_snapshot_absent() -> None:
shared_state = SharedState()
runner_ctx = _StubRunnerContext({"pending_requests": {}})
@@ -5,16 +5,20 @@ from collections.abc import MutableSequence, Sequence
from contextlib import AbstractAsyncContextManager
from typing import Any
from agent_framework import ChatMessage, Context, ContextProvider, TextContent
from agent_framework import ChatMessage, Context, ContextProvider
from agent_framework.exceptions import ServiceInitializationError
from mem0 import AsyncMemory, AsyncMemoryClient
from pydantic import PrivateAttr
if sys.version_info >= (3, 11):
from typing import NotRequired, Self, TypedDict # pragma: no cover
else:
from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
# Type aliases for Mem0 search response formats (v1.1 and v2; v1 is deprecated, but matches the type definition for v2)
class MemorySearchResponse_v1_1(TypedDict):
@@ -26,19 +30,11 @@ MemorySearchResponse_v2 = list[dict[str, Any]]
class Mem0Provider(ContextProvider):
mem0_client: AsyncMemory | AsyncMemoryClient
api_key: str | None = None
application_id: str | None = None
agent_id: str | None = None
thread_id: str | None = None
user_id: str | None = None
scope_to_per_operation_thread_id: bool = False
context_prompt: str = ContextProvider.DEFAULT_CONTEXT_PROMPT
_should_close_client: bool = PrivateAttr(default=False) # Track whether we should close client connection
"""Mem0 Context Provider."""
def __init__(
self,
mem0_client: AsyncMemory | AsyncMemoryClient | None = None,
api_key: str | None = None,
application_id: str | None = None,
agent_id: str | None = None,
@@ -46,11 +42,11 @@ class Mem0Provider(ContextProvider):
user_id: str | None = None,
scope_to_per_operation_thread_id: bool = False,
context_prompt: str = ContextProvider.DEFAULT_CONTEXT_PROMPT,
mem0_client: AsyncMemory | AsyncMemoryClient | None = None,
) -> None:
"""Initializes a new instance of the Mem0Provider class.
Args:
mem0_client: A pre-created Mem0 MemoryClient or None to create a default client.
api_key: The API key for authenticating with the Mem0 API. If not
provided, it will attempt to use the MEM0_API_KEY environment variable.
application_id: The application ID for scoping memories or None.
@@ -59,24 +55,20 @@ class Mem0Provider(ContextProvider):
user_id: The user ID for scoping memories or None.
scope_to_per_operation_thread_id: Whether to scope memories to per-operation thread ID.
context_prompt: The prompt to prepend to retrieved memories.
mem0_client: A pre-created Mem0 MemoryClient or None to create a default client.
"""
should_close_client = False
if mem0_client is None:
mem0_client = AsyncMemoryClient(api_key=api_key)
should_close_client = True
super().__init__(
api_key=api_key, # type: ignore[reportCallIssue]
application_id=application_id, # type: ignore[reportCallIssue]
agent_id=agent_id, # type: ignore[reportCallIssue]
thread_id=thread_id, # type: ignore[reportCallIssue]
user_id=user_id, # type: ignore[reportCallIssue]
scope_to_per_operation_thread_id=scope_to_per_operation_thread_id, # type: ignore[reportCallIssue]
context_prompt=context_prompt, # type: ignore[reportCallIssue]
mem0_client=mem0_client, # type: ignore[reportCallIssue]
)
self.api_key = api_key
self.application_id = application_id
self.agent_id = agent_id
self.thread_id = thread_id
self.user_id = user_id
self.scope_to_per_operation_thread_id = scope_to_per_operation_thread_id
self.context_prompt = context_prompt
self.mem0_client = mem0_client
self._per_operation_thread_id: str | None = None
self._should_close_client = should_close_client
@@ -100,18 +92,27 @@ class Mem0Provider(ContextProvider):
self._validate_per_operation_thread_id(thread_id)
self._per_operation_thread_id = self._per_operation_thread_id or thread_id
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
"""Called when a new message is being added to the thread.
Args:
thread_id: The ID of the thread or None.
new_messages: New messages to add.
"""
@override
async def invoked(
self,
request_messages: ChatMessage | Sequence[ChatMessage],
response_messages: ChatMessage | Sequence[ChatMessage] | None = None,
invoke_exception: Exception | None = None,
**kwargs: Any,
) -> None:
self._validate_filters()
self._validate_per_operation_thread_id(thread_id)
self._per_operation_thread_id = self._per_operation_thread_id or thread_id
messages_list = [new_messages] if isinstance(new_messages, ChatMessage) else list(new_messages)
request_messages_list = (
[request_messages] if isinstance(request_messages, ChatMessage) else list(request_messages)
)
response_messages_list = (
[response_messages]
if isinstance(response_messages, ChatMessage)
else list(response_messages)
if response_messages
else []
)
messages_list = [*request_messages_list, *response_messages_list]
messages: list[dict[str, str]] = [
{"role": message.role.value, "content": message.text}
@@ -128,11 +129,13 @@ class Mem0Provider(ContextProvider):
metadata={"application_id": self.application_id},
)
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
@override
async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context:
"""Called before invoking the AI model to provide context.
Args:
messages: List of new messages in the thread.
kwargs: not used at present.
Returns:
Context: Context object containing instructions with memories.
@@ -159,9 +162,11 @@ class Mem0Provider(ContextProvider):
line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories)
content = TextContent(f"{self.context_prompt}\n{line_separated_memories}") if line_separated_memories else None
return Context(contents=[content] if content else None)
return Context(
messages=[ChatMessage(role="user", text=f"{self.context_prompt}\n{line_separated_memories}")]
if line_separated_memories
else None
)
def _validate_filters(self) -> None:
"""Validates that at least one filter is provided.
@@ -4,7 +4,7 @@
from unittest.mock import AsyncMock, patch
import pytest
from agent_framework import ChatMessage, Context, Role, TextContent
from agent_framework import ChatMessage, Context, Role
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.mem0 import Mem0Provider
@@ -173,27 +173,25 @@ class TestMem0ProviderThreadMethods:
assert "can only be used with one thread at a time" in str(exc_info.value)
async def test_messages_adding_sets_per_operation_thread_id(
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
) -> None:
"""Test that messages_adding sets per-operation thread ID."""
async def test_messages_adding_sets_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
"""Test that invoked sets per-operation thread ID."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
await provider.messages_adding("thread123", sample_messages)
await provider.thread_created("thread123")
assert provider._per_operation_thread_id == "thread123"
class TestMem0ProviderMessagesAdding:
"""Test messages_adding method."""
"""Test invoked method."""
async def test_messages_adding_fails_without_filters(self, mock_mem0_client: AsyncMock) -> None:
"""Test that messages_adding fails when no filters are provided."""
"""Test that invoked fails when no filters are provided."""
provider = Mem0Provider(mem0_client=mock_mem0_client)
message = ChatMessage(role=Role.USER, text="Hello!")
with pytest.raises(ServiceInitializationError) as exc_info:
await provider.messages_adding("thread123", message)
await provider.invoked(message)
assert "At least one of the filters" in str(exc_info.value)
@@ -202,7 +200,7 @@ class TestMem0ProviderMessagesAdding:
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
message = ChatMessage(role=Role.USER, text="Hello!")
await provider.messages_adding("thread123", message)
await provider.invoked(message)
mock_mem0_client.add.assert_called_once()
call_args = mock_mem0_client.add.call_args
@@ -215,7 +213,7 @@ class TestMem0ProviderMessagesAdding:
"""Test adding multiple messages."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
await provider.messages_adding("thread123", sample_messages)
await provider.invoked(sample_messages)
mock_mem0_client.add.assert_called_once()
call_args = mock_mem0_client.add.call_args
@@ -232,7 +230,7 @@ class TestMem0ProviderMessagesAdding:
"""Test adding messages with agent_id."""
provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client)
await provider.messages_adding("thread123", sample_messages)
await provider.invoked(sample_messages)
call_args = mock_mem0_client.add.call_args
assert call_args.kwargs["agent_id"] == "agent123"
@@ -244,7 +242,7 @@ class TestMem0ProviderMessagesAdding:
"""Test adding messages with application_id in metadata."""
provider = Mem0Provider(user_id="user123", application_id="app123", mem0_client=mock_mem0_client)
await provider.messages_adding("thread123", sample_messages)
await provider.invoked(sample_messages)
call_args = mock_mem0_client.add.call_args
assert call_args.kwargs["metadata"] == {"application_id": "app123"}
@@ -261,7 +259,8 @@ class TestMem0ProviderMessagesAdding:
)
provider._per_operation_thread_id = "operation_thread"
await provider.messages_adding("operation_thread", sample_messages)
await provider.thread_created(thread_id="operation_thread")
await provider.invoked(sample_messages)
call_args = mock_mem0_client.add.call_args
assert call_args.kwargs["run_id"] == "operation_thread"
@@ -277,7 +276,7 @@ class TestMem0ProviderMessagesAdding:
mem0_client=mock_mem0_client,
)
await provider.messages_adding("operation_thread", sample_messages)
await provider.invoked(sample_messages)
call_args = mock_mem0_client.add.call_args
assert call_args.kwargs["run_id"] == "base_thread"
@@ -291,7 +290,7 @@ class TestMem0ProviderMessagesAdding:
ChatMessage(role=Role.USER, text="Valid message"),
]
await provider.messages_adding("thread123", messages)
await provider.invoked(messages)
call_args = mock_mem0_client.add.call_args
# Should only include the valid message
@@ -305,26 +304,26 @@ class TestMem0ProviderMessagesAdding:
ChatMessage(role=Role.USER, text=" "),
]
await provider.messages_adding("thread123", messages)
await provider.invoked(messages)
mock_mem0_client.add.assert_not_called()
class TestMem0ProviderModelInvoking:
"""Test model_invoking method."""
"""Test invoking method."""
async def test_model_invoking_fails_without_filters(self, mock_mem0_client: AsyncMock) -> None:
"""Test that model_invoking fails when no filters are provided."""
"""Test that invoking fails when no filters are provided."""
provider = Mem0Provider(mem0_client=mock_mem0_client)
message = ChatMessage(role=Role.USER, text="What's the weather?")
with pytest.raises(ServiceInitializationError) as exc_info:
await provider.model_invoking(message)
await provider.invoking(message)
assert "At least one of the filters" in str(exc_info.value)
async def test_model_invoking_single_message(self, mock_mem0_client: AsyncMock) -> None:
"""Test model_invoking with a single message."""
"""Test invoking with a single message."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
message = ChatMessage(role=Role.USER, text="What's the weather?")
@@ -334,7 +333,7 @@ class TestMem0ProviderModelInvoking:
{"memory": "User lives in Seattle"},
]
context = await provider.model_invoking(message)
context = await provider.invoking(message)
mock_mem0_client.search.assert_called_once()
call_args = mock_mem0_client.search.call_args
@@ -347,39 +346,38 @@ class TestMem0ProviderModelInvoking:
"User likes outdoor activities\nUser lives in Seattle"
)
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert context.contents[0].text == expected_instructions
assert context.messages
assert context.messages[0].text == expected_instructions
async def test_model_invoking_multiple_messages(
self, mock_mem0_client: AsyncMock, sample_messages: list[ChatMessage]
) -> None:
"""Test model_invoking with multiple messages."""
"""Test invoking with multiple messages."""
provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client)
mock_mem0_client.search.return_value = [{"memory": "Previous conversation context"}]
await provider.model_invoking(sample_messages)
await provider.invoking(sample_messages)
call_args = mock_mem0_client.search.call_args
expected_query = "Hello, how are you?\nI'm doing well, thank you!\nYou are a helpful assistant"
assert call_args.kwargs["query"] == expected_query
async def test_model_invoking_with_agent_id(self, mock_mem0_client: AsyncMock) -> None:
"""Test model_invoking with agent_id."""
"""Test invoking with agent_id."""
provider = Mem0Provider(agent_id="agent123", mem0_client=mock_mem0_client)
message = ChatMessage(role=Role.USER, text="Hello")
mock_mem0_client.search.return_value = []
await provider.model_invoking(message)
await provider.invoking(message)
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["agent_id"] == "agent123"
assert call_args.kwargs["user_id"] is None
async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None:
"""Test model_invoking with scope_to_per_operation_thread_id enabled."""
"""Test invoking with scope_to_per_operation_thread_id enabled."""
provider = Mem0Provider(
user_id="user123",
thread_id="base_thread",
@@ -391,7 +389,7 @@ class TestMem0ProviderModelInvoking:
mock_mem0_client.search.return_value = []
await provider.model_invoking(message)
await provider.invoking(message)
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["run_id"] == "operation_thread"
@@ -403,10 +401,10 @@ class TestMem0ProviderModelInvoking:
mock_mem0_client.search.return_value = []
context = await provider.model_invoking(message)
context = await provider.invoking(message)
assert isinstance(context, Context)
assert not context.contents
assert not context.messages
async def test_model_invoking_filters_empty_message_text(self, mock_mem0_client: AsyncMock) -> None:
"""Test that empty message text is filtered out from query."""
@@ -419,13 +417,13 @@ class TestMem0ProviderModelInvoking:
mock_mem0_client.search.return_value = []
await provider.model_invoking(messages)
await provider.invoking(messages)
call_args = mock_mem0_client.search.call_args
assert call_args.kwargs["query"] == "Valid message"
async def test_model_invoking_custom_context_prompt(self, mock_mem0_client: AsyncMock) -> None:
"""Test model_invoking with custom context prompt."""
"""Test invoking with custom context prompt."""
custom_prompt = "## Custom Context\nRemember these details:"
provider = Mem0Provider(
user_id="user123",
@@ -436,12 +434,11 @@ class TestMem0ProviderModelInvoking:
mock_mem0_client.search.return_value = [{"memory": "Test memory"}]
context = await provider.model_invoking(message)
context = await provider.invoking(message)
expected_instructions = "## Custom Context\nRemember these details:\nTest memory"
assert context.contents
assert isinstance(context.contents[0], TextContent)
assert context.contents[0].text == expected_instructions
assert context.messages
assert context.messages[0].text == expected_instructions
class TestMem0ProviderValidation:
@@ -22,7 +22,7 @@ class RedisStoreState(AFBaseModel):
class RedisChatMessageStore:
"""Redis-backed implementation of ChatMessageStore using Redis Lists.
"""Redis-backed implementation of ChatMessageStoreProtocol using Redis Lists.
This implementation provides persistent, thread-safe chat message storage using Redis Lists.
Messages are stored as JSON-serialized strings in chronological order, with each conversation
@@ -153,9 +153,9 @@ class RedisChatMessageStore:
await pipe.execute()
async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
"""Add messages to the Redis store (ChatMessageStore protocol method).
"""Add messages to the Redis store (ChatMessageStoreProtocol protocol method).
This method implements the required ChatMessageStore protocol for adding messages.
This method implements the required ChatMessageStoreProtocol protocol for adding messages.
Messages are appended to the Redis list in chronological order, with automatic
trimming if message limits are configured.
@@ -190,9 +190,9 @@ class RedisChatMessageStore:
await self._redis_client.ltrim(self.redis_key, -self.max_messages, -1) # type: ignore[misc]
async def list_messages(self) -> list[ChatMessage]:
"""Get all messages from the store in chronological order (ChatMessageStore protocol method).
"""Get all messages from the store in chronological order (ChatMessageStoreProtocol protocol method).
This method implements the required ChatMessageStore protocol for retrieving messages.
This method implements the required ChatMessageStoreProtocol protocol for retrieving messages.
Returns all messages stored in Redis, ordered from oldest (index 0) to newest (index -1).
Returns:
@@ -220,10 +220,10 @@ class RedisChatMessageStore:
return messages
async def serialize_state(self, **kwargs: Any) -> Any:
"""Serialize the current store state for persistence (ChatMessageStore protocol method).
async def serialize(self, **kwargs: Any) -> Any:
"""Serialize the current store state for persistence (ChatMessageStoreProtocol protocol method).
This method implements the required ChatMessageStore protocol for state serialization.
This method implements the required ChatMessageStoreProtocol protocol for state serialization.
Captures the Redis connection configuration and thread information needed to
reconstruct the store and reconnect to the same conversation data.
@@ -243,10 +243,43 @@ class RedisChatMessageStore:
)
return state.model_dump(**kwargs)
async def deserialize_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
"""Deserialize state data into this store instance (ChatMessageStore protocol method).
@classmethod
async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> RedisChatMessageStore:
"""Deserialize state data into a new store instance (ChatMessageStoreProtocol protocol method).
This method implements the required ChatMessageStore protocol for state deserialization.
This method implements the required ChatMessageStoreProtocol protocol for state deserialization.
Creates a new RedisChatMessageStore instance from previously serialized data,
allowing the store to reconnect to the same conversation data in Redis.
Args:
serialized_store_state: Previously serialized state data from serialize_state().
Should be a dictionary with thread_id, redis_url, etc.
**kwargs: Additional arguments passed to Pydantic model validation.
Returns:
A new RedisChatMessageStore instance configured from the serialized state.
Raises:
ValueError: If required fields are missing or invalid in the serialized state.
"""
if not serialized_store_state:
raise ValueError("serialized_store_state is required for deserialization")
# Validate and parse the serialized state using Pydantic
state = RedisStoreState.model_validate(serialized_store_state, **kwargs)
# Create and return a new store instance with the deserialized configuration
return cls(
redis_url=state.redis_url,
thread_id=state.thread_id,
key_prefix=state.key_prefix,
max_messages=state.max_messages,
)
async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
"""Deserialize state data into this store instance (ChatMessageStoreProtocol protocol method).
This method implements the required ChatMessageStoreProtocol protocol for state deserialization.
Restores the store configuration from previously serialized data, allowing the store
to reconnect to the same conversation data in Redis.
@@ -1,32 +1,34 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import json
import sys
from collections.abc import MutableSequence, Sequence
from functools import reduce
from operator import and_
from typing import Any, Literal, cast
from agent_framework import ChatMessage, Context, ContextProvider, Role, TextContent
import numpy as np
from agent_framework import ChatMessage, Context, ContextProvider, Role
from agent_framework.exceptions import (
AgentException,
ServiceInitializationError,
ServiceInvalidRequestError,
)
from redisvl.index import AsyncSearchIndex
from redisvl.query import FilterQuery, HybridQuery, TextQuery
from redisvl.query.filter import FilterExpression, Tag
from redisvl.utils.token_escaper import TokenEscaper
from redisvl.utils.vectorize import BaseVectorizer
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
import json
import numpy as np
from redisvl.index import AsyncSearchIndex
from redisvl.query import FilterQuery, HybridQuery, TextQuery
from redisvl.query.filter import FilterExpression, Tag
from redisvl.utils.token_escaper import TokenEscaper
from redisvl.utils.vectorize import BaseVectorizer
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
class RedisProvider(ContextProvider):
@@ -36,41 +38,73 @@ class RedisProvider(ContextProvider):
Uses full-text or optional hybrid vector search to ground model responses.
"""
# Connection and indexing
redis_url: str = "redis://localhost:6379"
index_name: str = "context"
prefix: str = "context"
def __init__(
self,
redis_url: str = "redis://localhost:6379",
index_name: str = "context",
prefix: str = "context",
# Redis vectorizer configuration (optional, injected by client)
redis_vectorizer: BaseVectorizer | None = None,
vector_field_name: str | None = None,
vector_algorithm: Literal["flat", "hnsw"] | None = None,
vector_distance_metric: Literal["cosine", "ip", "l2"] | None = None,
# Partition fields (indexed for filtering)
application_id: str | None = None,
agent_id: str | None = None,
user_id: str | None = None,
thread_id: str | None = None,
scope_to_per_operation_thread_id: bool = False,
# Prompt and runtime
context_prompt: str = ContextProvider.DEFAULT_CONTEXT_PROMPT,
redis_index: Any = None,
overwrite_index: bool = False,
):
"""Create a Redis Context Provider.
# Redis vectorizer configuration (optional, injected by client)
redis_vectorizer: BaseVectorizer | None = None
vector_field_name: str | None = None
vector_algorithm: Literal["flat", "hnsw"] | None = None
vector_distance_metric: Literal["cosine", "ip", "l2"] | None = None
Args:
redis_url: The Redis server URL.
index_name: The name of the Redis index.
prefix: The prefix for all keys in the Redis database.
redis_vectorizer: The vectorizer to use for Redis.
vector_field_name: The name of the vector field in Redis.
vector_algorithm: The algorithm to use for vector search.
vector_distance_metric: The distance metric to use for vector search.
application_id: The application ID to scope the context.
agent_id: The agent ID to scope the context.
user_id: The user ID to scope the context.
thread_id: The thread ID to scope the context.
scope_to_per_operation_thread_id: Whether to scope to the per-operation thread ID.
context_prompt: The context prompt to use for the provider.
redis_index: The Redis index to use for the provider.
overwrite_index: Whether to overwrite the existing Redis index.
# Partition fields (indexed for filtering)
application_id: str | None = None
agent_id: str | None = None
user_id: str | None = None
thread_id: str | None = None
scope_to_per_operation_thread_id: bool = False
# Prompt and runtime
context_prompt: str = ContextProvider.DEFAULT_CONTEXT_PROMPT
redis_index: Any = None
overwrite_index: bool = False
_per_operation_thread_id: str | None = None
_token_escaper: TokenEscaper = TokenEscaper()
_conversation_id: str | None = None
_index_initialized: bool = False
_schema_dict: dict[str, Any] | None = None
def model_post_init(self, __context: Any) -> None:
"""Post-initialization hook to set up computed fields after Pydantic initialization.
This is called automatically by Pydantic after the model is initialized.
"""
# Create Redis index using the cached schema_dict property
self.redis_index = AsyncSearchIndex.from_dict(self.schema_dict, redis_url=self.redis_url, validate_on_load=True)
self.redis_url = redis_url
self.index_name = index_name
self.prefix = prefix
if redis_vectorizer is not None and not isinstance(redis_vectorizer, BaseVectorizer):
raise AgentException(
f"The redis vectorizer is not a valid type, got: {type(redis_vectorizer)}, expected: BaseVectorizer."
)
self.redis_vectorizer = redis_vectorizer
self.vector_field_name = vector_field_name
self.vector_algorithm: Literal["flat", "hnsw"] | None = vector_algorithm
self.vector_distance_metric: Literal["cosine", "ip", "l2"] | None = vector_distance_metric
self.application_id = application_id
self.agent_id = agent_id
self.user_id = user_id
self.thread_id = thread_id
self.scope_to_per_operation_thread_id = scope_to_per_operation_thread_id
self.context_prompt = context_prompt
self.overwrite_index = overwrite_index
self._per_operation_thread_id: str | None = None
self._token_escaper: TokenEscaper = TokenEscaper()
self._conversation_id: str | None = None
self._index_initialized: bool = False
self._schema_dict: dict[str, Any] | None = None
self.redis_index = redis_index or AsyncSearchIndex.from_dict(
self.schema_dict, redis_url=self.redis_url, validate_on_load=True
)
@property
def schema_dict(self) -> dict[str, Any]:
@@ -429,6 +463,7 @@ class RedisProvider(ContextProvider):
"""
return self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id
@override
async def thread_created(self, thread_id: str | None) -> None:
"""Called when a new thread is created.
@@ -442,20 +477,27 @@ class RedisProvider(ContextProvider):
# Track current conversation id (Agent passes conversation_id here)
self._conversation_id = thread_id or self._conversation_id
async def messages_adding(self, thread_id: str | None, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
"""Called when a new message is being added to the thread.
Validates scope, normalizes allowed roles, and persists messages to Redis via add().
Args:
thread_id: The ID of the thread or None.
new_messages: New messages to add.
"""
@override
async def invoked(
self,
request_messages: ChatMessage | Sequence[ChatMessage],
response_messages: ChatMessage | Sequence[ChatMessage] | None = None,
invoke_exception: Exception | None = None,
**kwargs: Any,
) -> None:
self._validate_filters()
self._validate_per_operation_thread_id(thread_id)
self._per_operation_thread_id = self._per_operation_thread_id or thread_id
messages_list = [new_messages] if isinstance(new_messages, ChatMessage) else list(new_messages)
request_messages_list = (
[request_messages] if isinstance(request_messages, ChatMessage) else list(request_messages)
)
response_messages_list = (
[response_messages]
if isinstance(response_messages, ChatMessage)
else list(response_messages)
if response_messages
else []
)
messages_list = [*request_messages_list, *response_messages_list]
messages: list[dict[str, Any]] = []
for message in messages_list:
@@ -475,7 +517,8 @@ class RedisProvider(ContextProvider):
if messages:
await self._add(data=messages)
async def model_invoking(self, messages: ChatMessage | MutableSequence[ChatMessage]) -> Context:
@override
async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context:
"""Called before invoking the model to provide scoped context.
Concatenates recent messages into a query, fetches matching memories from Redis.
@@ -483,6 +526,7 @@ class RedisProvider(ContextProvider):
Args:
messages: List of new messages in the thread.
kwargs: not used at present at present.
Returns:
Context: Context object containing instructions with memories.
@@ -495,8 +539,12 @@ class RedisProvider(ContextProvider):
line_separated_memories = "\n".join(
str(memory.get("content", "")) for memory in memories if memory.get("content")
)
content = TextContent(f"{self.context_prompt}\n{line_separated_memories}") if line_separated_memories else None
return Context(contents=[content] if content else None)
return Context(
messages=[ChatMessage(role="user", text=f"{self.context_prompt}\n{line_separated_memories}")]
if line_separated_memories
else None
)
async def __aenter__(self) -> Self:
"""Async context manager entry.
@@ -239,7 +239,7 @@ class TestRedisChatMessageStore:
async def test_serialize_state(self, redis_store):
"""Test state serialization."""
state = await redis_store.serialize_state()
state = await redis_store.serialize()
expected_state = {
"thread_id": "test_thread_123",
@@ -259,7 +259,7 @@ class TestRedisChatMessageStore:
"max_messages": 50,
}
await redis_store.deserialize_state(serialized_state)
await redis_store.update_from_state(serialized_state)
assert redis_store.thread_id == "restored_thread_456"
assert redis_store.redis_url == "redis://localhost:6380"
@@ -270,7 +270,7 @@ class TestRedisChatMessageStore:
"""Test deserializing empty state doesn't change anything."""
original_thread_id = redis_store.thread_id
await redis_store.deserialize_state(None)
await redis_store.update_from_state(None)
assert redis_store.thread_id == original_thread_id
@@ -6,8 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from agent_framework import ChatMessage, Role
from agent_framework.exceptions import ServiceInitializationError
from pydantic import ValidationError
from agent_framework.exceptions import AgentException, ServiceInitializationError
from redisvl.utils.vectorize import CustomTextVectorizer
from agent_framework_redis import RedisProvider
@@ -121,21 +120,18 @@ class TestRedisProviderMessages:
ChatMessage(role=Role.SYSTEM, text="You are a helpful assistant"),
]
@pytest.mark.asyncio
# Writes require at least one scoping filter to avoid unbounded operations
async def test_messages_adding_requires_filters(self, patch_index_from_dict): # noqa: ARG002
provider = RedisProvider()
with pytest.raises(ServiceInitializationError):
await provider.messages_adding("thread123", ChatMessage(role=Role.USER, text="Hello"))
await provider.invoked("thread123", ChatMessage(role=Role.USER, text="Hello"))
@pytest.mark.asyncio
# Captures the per-operation thread id when provided
async def test_thread_created_sets_per_operation_id(self, patch_index_from_dict): # noqa: ARG002
provider = RedisProvider(user_id="u1")
await provider.thread_created("t1")
assert provider._per_operation_thread_id == "t1"
@pytest.mark.asyncio
# Enforces single-thread usage when scope_to_per_operation_thread_id is True
async def test_thread_created_conflict_when_scoped(self, patch_index_from_dict): # noqa: ARG002
provider = RedisProvider(user_id="u1", scope_to_per_operation_thread_id=True)
@@ -144,7 +140,6 @@ class TestRedisProviderMessages:
await provider.thread_created("t2")
assert "only be used with one thread" in str(exc.value)
@pytest.mark.asyncio
# Aggregates all results from the async paginator into a flat list
async def test_search_all_paginates(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002
async def gen(_q, page_size: int = 200): # noqa: ARG001, ANN001
@@ -158,14 +153,12 @@ class TestRedisProviderMessages:
class TestRedisProviderModelInvoking:
@pytest.mark.asyncio
# Reads require at least one scoping filter to avoid unbounded operations
async def test_model_invoking_requires_filters(self, patch_index_from_dict): # noqa: ARG002
provider = RedisProvider()
with pytest.raises(ServiceInitializationError):
await provider.model_invoking(ChatMessage(role=Role.USER, text="Hi"))
await provider.invoking(ChatMessage(role=Role.USER, text="Hi"))
@pytest.mark.asyncio
# Ensures text-only search path is used and context is composed from hits
async def test_textquery_path_and_context_contents(
self, mock_index: AsyncMock, patch_index_from_dict, patch_queries
@@ -175,7 +168,7 @@ class TestRedisProviderModelInvoking:
provider = RedisProvider(user_id="u1")
# Act
ctx = await provider.model_invoking([ChatMessage(role=Role.USER, text="q1")])
ctx = await provider.invoking([ChatMessage(role=Role.USER, text="q1")])
# Assert: TextQuery used (not HybridQuery), filter_expression included
assert patch_queries["TextQuery"].call_count == 1
@@ -187,27 +180,25 @@ class TestRedisProviderModelInvoking:
assert "filter_expression" in kwargs
# Context contains memories joined after the default prompt
assert ctx.contents is not None and len(ctx.contents) == 1
text = ctx.contents[0].text
assert ctx.messages is not None and len(ctx.messages) == 1
text = ctx.messages[0].text
assert text.endswith("A\nB")
@pytest.mark.asyncio
# When no results are returned, Context should have no contents
async def test_model_invoking_empty_results_returns_empty_context(
self, mock_index: AsyncMock, patch_index_from_dict, patch_queries
): # noqa: ARG002
mock_index.query = AsyncMock(return_value=[])
provider = RedisProvider(user_id="u1")
ctx = await provider.model_invoking([ChatMessage(role=Role.USER, text="any")])
assert ctx.contents is None
ctx = await provider.invoking([ChatMessage(role=Role.USER, text="any")])
assert ctx.messages == []
@pytest.mark.asyncio
# Ensures hybrid vector-text search is used when a vectorizer and vector field are configured
async def test_hybridquery_path_with_vectorizer(self, mock_index: AsyncMock, patch_index_from_dict, patch_queries): # noqa: ARG002
mock_index.query = AsyncMock(return_value=[{"content": "Hit"}])
provider = RedisProvider(user_id="u1", redis_vectorizer=CUSTOM_VECTORIZER, vector_field_name="vec")
ctx = await provider.model_invoking([ChatMessage(role=Role.USER, text="hello")])
ctx = await provider.invoking([ChatMessage(role=Role.USER, text="hello")])
# Assert: HybridQuery used with vector and vector field
assert patch_queries["HybridQuery"].call_count == 1
@@ -220,18 +211,16 @@ class TestRedisProviderModelInvoking:
assert "filter_expression" in k
# Context assembled from returned memories
assert ctx.contents and "Hit" in ctx.contents[0].text
assert ctx.messages and "Hit" in ctx.messages[0].text
class TestRedisProviderContextManager:
@pytest.mark.asyncio
# Verifies async context manager returns self for chaining
async def test_async_context_manager_returns_self(self, patch_index_from_dict): # noqa: ARG002
provider = RedisProvider(user_id="u1")
async with provider as ctx:
assert ctx is provider
@pytest.mark.asyncio
# Exit should be a no-op and not raise
async def test_aexit_noop(self, patch_index_from_dict): # noqa: ARG002
provider = RedisProvider(user_id="u1")
@@ -239,7 +228,6 @@ class TestRedisProviderContextManager:
class TestMessagesAddingBehavior:
@pytest.mark.asyncio
# Adds messages while injecting partition defaults and preserving allowed roles
async def test_messages_adding_adds_partition_defaults_and_roles(
self, mock_index: AsyncMock, patch_index_from_dict
@@ -257,7 +245,7 @@ class TestMessagesAddingBehavior:
ChatMessage(role=Role.SYSTEM, text="s"),
]
await provider.messages_adding("t1", msgs)
await provider.invoked(msgs)
# Ensure load invoked with shaped docs containing defaults
assert mock_index.load.await_count == 1
@@ -270,9 +258,7 @@ class TestMessagesAddingBehavior:
assert d["application_id"] == "app"
assert d["agent_id"] == "agent"
assert d["user_id"] == "u1"
assert d["thread_id"] == "t1" # scoped via per-operation thread id
@pytest.mark.asyncio
# Skips blank text and disallowed roles (e.g., TOOL) when adding messages
async def test_messages_adding_ignores_blank_and_disallowed_roles(
self, mock_index: AsyncMock, patch_index_from_dict
@@ -282,37 +268,34 @@ class TestMessagesAddingBehavior:
ChatMessage(role=Role.USER, text=" "),
ChatMessage(role=Role.TOOL, text="tool output"),
]
await provider.messages_adding("tid", msgs)
await provider.invoked(msgs)
# No valid messages -> no load
assert mock_index.load.await_count == 0
class TestIndexCreationPublicCalls:
@pytest.mark.asyncio
# Ensures index is created only once when drop=True on first public write call
async def test_messages_adding_triggers_index_create_once_when_drop_true(
self, mock_index: AsyncMock, patch_index_from_dict
): # noqa: ARG002
provider = RedisProvider(user_id="u1", drop_redis_index=True)
await provider.messages_adding("t1", ChatMessage(role=Role.USER, text="m1"))
await provider.messages_adding("t1", ChatMessage(role=Role.USER, text="m2"))
provider = RedisProvider(user_id="u1")
await provider.invoked(ChatMessage(role=Role.USER, text="m1"))
await provider.invoked(ChatMessage(role=Role.USER, text="m2"))
# create only on first call
assert mock_index.create.await_count == 1
@pytest.mark.asyncio
# Ensures index is created when drop=False and the index does not exist on first read
async def test_model_invoking_triggers_create_when_drop_false_and_not_exists(
self, mock_index: AsyncMock, patch_index_from_dict
): # noqa: ARG002
mock_index.exists = AsyncMock(return_value=False)
provider = RedisProvider(user_id="u1", drop_redis_index=False)
provider = RedisProvider(user_id="u1")
mock_index.query = AsyncMock(return_value=[{"content": "C"}])
await provider.model_invoking([ChatMessage(role=Role.USER, text="q")])
await provider.invoking([ChatMessage(role=Role.USER, text="q")])
assert mock_index.create.await_count == 1
class TestThreadCreatedAdditional:
@pytest.mark.asyncio
# Allows None or same thread id repeatedly; different id raises when scoped
async def test_thread_created_allows_none_and_same_id(self, patch_index_from_dict): # noqa: ARG002
provider = RedisProvider(user_id="u1", scope_to_per_operation_thread_id=True)
@@ -327,8 +310,7 @@ class TestThreadCreatedAdditional:
class TestVectorPopulation:
@pytest.mark.asyncio
# When vectorizer configured, messages_adding should embed content and populate the vector field
# When vectorizer configured, invoked should embed content and populate the vector field
async def test_messages_adding_populates_vector_field_when_vectorizer_present(
self, mock_index: AsyncMock, patch_index_from_dict
): # noqa: ARG002
@@ -339,7 +321,7 @@ class TestVectorPopulation:
vector_field_name="vec",
)
await provider.messages_adding("t1", ChatMessage(role=Role.USER, text="hello"))
await provider.invoked(ChatMessage(role=Role.USER, text="hello"))
assert mock_index.load.await_count == 1
(loaded_args, _kwargs) = mock_index.load.call_args
docs = loaded_args[0]
@@ -365,12 +347,11 @@ class TestRedisProviderSchemaVectors:
class DummyVectorizer:
pass
with pytest.raises(ValidationError):
with pytest.raises(AgentException):
RedisProvider(user_id="u1", redis_vectorizer=DummyVectorizer(), vector_field_name="vec")
class TestEnsureIndex:
@pytest.mark.asyncio
# Creates index once and marks _index_initialized to prevent duplicate calls
async def test_ensure_index_creates_once(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002
# Mock index doesn't exist, so it will be created
@@ -386,7 +367,6 @@ class TestEnsureIndex:
await provider._ensure_index()
assert mock_index.create.await_count == 1
@pytest.mark.asyncio
# Creates index with overwrite=True when overwrite_index=True
async def test_ensure_index_with_overwrite_true(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002
mock_index.exists = AsyncMock(return_value=True)
@@ -397,7 +377,6 @@ class TestEnsureIndex:
# Should call create with overwrite=True, drop=False
mock_index.create.assert_called_once_with(overwrite=True, drop=False)
@pytest.mark.asyncio
# Creates index with overwrite=False when index doesn't exist
async def test_ensure_index_create_if_missing(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002
mock_index.exists = AsyncMock(return_value=False)
@@ -408,7 +387,6 @@ class TestEnsureIndex:
# Should call create with overwrite=False, drop=False
mock_index.create.assert_called_once_with(overwrite=False, drop=False)
@pytest.mark.asyncio
# Validates schema compatibility when index exists and overwrite=False
async def test_ensure_index_schema_validation_success(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002
mock_index.exists = AsyncMock(return_value=True)
@@ -424,7 +402,6 @@ class TestEnsureIndex:
patch_index_from_dict.from_existing.assert_called_once_with("context", redis_url="redis://localhost:6379")
mock_index.create.assert_called_once_with(overwrite=False, drop=False)
@pytest.mark.asyncio
# Raises ServiceInitializationError when schemas don't match
async def test_ensure_index_schema_validation_failure(self, mock_index: AsyncMock, patch_index_from_dict): # noqa: ARG002
mock_index.exists = AsyncMock(return_value=True)