Files
agent-framework/python/packages/durabletask/tests/test_durable_entities.py
Eduard van Valkenburg b1b528e4a8 Python: [BREAKING] Remove deprecated kwargs compatibility paths (#4858)
* [BREAKING] Remove deprecated kwargs compatibility paths

Remove the deprecated kwargs compatibility shims across core agents, clients, tools, middleware, and telemetry.

Keep workflow kwargs behavior intact in this branch and follow up separately in #4850.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix PR CI fallout for kwargs removal

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address PR review feedback

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* updates

* Fix Azure AI CI fallout

Remove the stale _get_current_conversation_id override from the Azure AI client after the OpenAI base helper was deleted.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fixed new classes

* Fix Assistants deprecated import gating

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix integration replay regressions

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Switch multi-agent hosting samples to Azure chat completions

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Simplify Azure multi-agent sample config

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-27 21:00:12 +00:00

767 lines
28 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for AgentEntity.
Run with: pytest tests/test_entities.py -v
"""
from collections.abc import AsyncIterator
from datetime import datetime
from typing import Any, TypeVar
from unittest.mock import AsyncMock, Mock
import pytest
from agent_framework import AgentResponse, AgentResponseUpdate, Content, Message, ResponseStream
from pydantic import BaseModel
from agent_framework_durabletask import (
AgentEntity,
AgentEntityStateProviderMixin,
DurableAgentState,
DurableAgentStateData,
DurableAgentStateMessage,
DurableAgentStateRequest,
DurableAgentStateResponse,
DurableAgentStateTextContent,
DurableAgentStateTextReasoningContent,
RunRequest,
)
from agent_framework_durabletask._entities import DurableTaskEntityStateProvider
StateT = TypeVar("StateT")
class MockEntityContext:
"""Minimal durabletask EntityContext shim for tests."""
def __init__(self, initial_state: Any = None) -> None:
self._state = initial_state
def get_state(
self,
intended_type: type[StateT] | None = None,
default: StateT | None = None,
) -> Any:
del intended_type
if self._state is None:
return default
return self._state
def set_state(self, new_state: Any) -> None:
self._state = new_state
class _InMemoryStateProvider(AgentEntityStateProviderMixin):
"""Test-only state provider for AgentEntity."""
def __init__(self, *, thread_id: str, initial_state: dict[str, Any] | None = None) -> None:
self._thread_id = thread_id
self._state_dict: dict[str, Any] = initial_state or {}
def _get_state_dict(self) -> dict[str, Any]:
return self._state_dict
def _set_state_dict(self, state: dict[str, Any]) -> None:
self._state_dict = state
def _get_thread_id_from_entity(self) -> str:
return self._thread_id
def _make_entity(agent: Any, callback: Any = None, *, thread_id: str = "test-thread") -> AgentEntity:
return AgentEntity(agent, callback=callback, state_provider=_InMemoryStateProvider(thread_id=thread_id))
def _role_value(chat_message: DurableAgentStateMessage) -> str:
"""Helper to extract the string role from a Message."""
role = getattr(chat_message, "role", None)
role_value = getattr(role, "value", role)
if role_value is None:
return ""
return str(role_value)
def _agent_response(text: str | None) -> AgentResponse:
"""Create an AgentResponse with a single assistant message."""
message = Message(role="assistant", text=text) if text is not None else Message(role="assistant", text="")
return AgentResponse(messages=[message], created_at="2024-01-01T00:00:00Z")
def _create_mock_run(response: AgentResponse | None = None, side_effect: Exception | None = None):
"""Create a mock run function that handles stream parameter correctly.
The durabletask entity code tries run(stream=True) first, then falls back to run(stream=False).
This helper creates a mock that raises TypeError for streaming (to trigger fallback) and
returns the response or raises the side_effect for non-streaming.
"""
async def mock_run(*args, stream=False, **kwargs):
if stream:
# Simulate "streaming not supported" to trigger fallback
raise TypeError("streaming not supported")
if side_effect:
raise side_effect
return response
return mock_run
class RecordingCallback:
"""Callback implementation capturing streaming and final responses for assertions."""
def __init__(self):
self.stream_mock = AsyncMock()
self.response_mock = AsyncMock()
async def on_streaming_response_update(
self,
update: AgentResponseUpdate,
context: Any,
) -> None:
await self.stream_mock(update, context)
async def on_agent_response(self, response: AgentResponse, context: Any) -> None:
await self.response_mock(response, context)
class EntityStructuredResponse(BaseModel):
answer: float
class TestAgentEntityInit:
"""Test suite for AgentEntity initialization."""
def test_init_creates_entity(self) -> None:
"""Test that AgentEntity initializes correctly."""
mock_agent = Mock()
entity = _make_entity(mock_agent)
assert entity.agent == mock_agent
assert len(entity.state.data.conversation_history) == 0
assert entity.state.data.extension_data is None
assert entity.state.schema_version == DurableAgentState.SCHEMA_VERSION
def test_init_stores_agent_reference(self) -> None:
"""Test that the agent reference is stored correctly."""
mock_agent = Mock()
mock_agent.name = "TestAgent"
entity = _make_entity(mock_agent)
assert entity.agent.name == "TestAgent"
def test_init_with_different_agent_types(self) -> None:
"""Test initialization with different agent types."""
agent1 = Mock()
agent1.__class__.__name__ = "AzureOpenAIAgent"
agent2 = Mock()
agent2.__class__.__name__ = "CustomAgent"
entity1 = _make_entity(agent1)
entity2 = _make_entity(agent2)
assert entity1.agent.__class__.__name__ == "AzureOpenAIAgent"
assert entity2.agent.__class__.__name__ == "CustomAgent"
class TestDurableTaskEntityStateProvider:
"""Tests for DurableTaskEntityStateProvider wrapper behavior and persistence wiring."""
def _make_durabletask_entity_provider(
self,
agent: Any,
*,
initial_state: dict[str, Any] | None = None,
) -> tuple[DurableTaskEntityStateProvider, MockEntityContext]:
"""Create a DurableTaskEntityStateProvider wired to an in-memory durabletask context."""
entity = DurableTaskEntityStateProvider()
ctx = MockEntityContext(initial_state)
# DurableEntity provides this hook; required for get_state/set_state to work in unit tests.
entity._initialize_entity_context(ctx) # type: ignore[attr-defined]
return entity, ctx
def test_reset_persists_cleared_state(self) -> None:
mock_agent = Mock()
existing_state = {
"schemaVersion": "1.0.0",
"data": {
"conversationHistory": [
{
"$type": "request",
"correlationId": "corr-existing-1",
"createdAt": "2024-01-01T00:00:00Z",
"messages": [{"role": "user", "contents": [{"$type": "text", "text": "msg1"}]}],
}
]
},
}
entity, ctx = self._make_durabletask_entity_provider(mock_agent, initial_state=existing_state)
entity.reset()
persisted = ctx.get_state(dict, default={})
assert isinstance(persisted, dict)
assert persisted["data"]["conversationHistory"] == []
class TestAgentEntityRunAgent:
"""Test suite for the run_agent operation."""
async def test_run_executes_agent(self) -> None:
"""Test that run executes the agent."""
mock_agent = Mock()
mock_response = _agent_response("Test response")
# Mock run() to return response for non-streaming, raise for streaming (to test fallback)
async def mock_run(*args, stream=False, **kwargs):
if stream:
raise TypeError("streaming not supported")
return mock_response
mock_agent.run = mock_run
entity = _make_entity(mock_agent)
result = await entity.run({
"message": "Test message",
"correlationId": "corr-entity-1",
})
# Verify result
assert isinstance(result, AgentResponse)
assert result.text == "Test response"
async def test_run_agent_streaming_callbacks_invoked(self) -> None:
"""Ensure streaming updates trigger callbacks when using run(stream=True)."""
updates = [
AgentResponseUpdate(contents=[Content.from_text(text="Hello")]),
AgentResponseUpdate(contents=[Content.from_text(text=" world")]),
]
async def update_generator() -> AsyncIterator[AgentResponseUpdate]:
for update in updates:
yield update
mock_agent = Mock()
mock_agent.name = "StreamingAgent"
# Mock run() to return ResponseStream when stream=True
def mock_run(*args, stream=False, **kwargs):
if stream:
return ResponseStream(
update_generator(),
finalizer=AgentResponse.from_updates,
)
raise AssertionError("run(stream=False) should not be called when streaming succeeds")
mock_agent.run = mock_run
callback = RecordingCallback()
entity = _make_entity(mock_agent, callback=callback, thread_id="session-1")
result = await entity.run(
{
"message": "Tell me something",
"correlationId": "corr-stream-1",
},
)
assert isinstance(result, AgentResponse)
assert "Hello" in result.text
assert callback.stream_mock.await_count == len(updates)
assert callback.response_mock.await_count == 1
# Validate callback arguments
stream_calls = callback.stream_mock.await_args_list
for expected_update, recorded_call in zip(updates, stream_calls, strict=True):
assert recorded_call.args[0] is expected_update
context = recorded_call.args[1]
assert context.agent_name == "StreamingAgent"
assert context.correlation_id == "corr-stream-1"
assert context.thread_id == "session-1"
assert context.request_message == "Tell me something"
final_call = callback.response_mock.await_args
assert final_call is not None
final_response, final_context = final_call.args
assert final_context.agent_name == "StreamingAgent"
assert final_context.correlation_id == "corr-stream-1"
assert final_context.thread_id == "session-1"
assert final_context.request_message == "Tell me something"
assert getattr(final_response, "text", "").strip()
async def test_run_agent_final_callback_without_streaming(self) -> None:
"""Ensure the final callback fires even when streaming is unavailable."""
mock_agent = Mock()
mock_agent.name = "NonStreamingAgent"
agent_response = _agent_response("Final response")
mock_agent.run = _create_mock_run(response=agent_response)
callback = RecordingCallback()
entity = _make_entity(mock_agent, callback=callback, thread_id="session-2")
result = await entity.run(
{
"message": "Hi",
"correlationId": "corr-final-1",
},
)
assert isinstance(result, AgentResponse)
assert result.text == "Final response"
assert callback.stream_mock.await_count == 0
assert callback.response_mock.await_count == 1
final_call = callback.response_mock.await_args
assert final_call is not None
assert final_call.args[0] is agent_response
final_context = final_call.args[1]
assert final_context.agent_name == "NonStreamingAgent"
assert final_context.correlation_id == "corr-final-1"
assert final_context.thread_id == "session-2"
assert final_context.request_message == "Hi"
async def test_run_agent_updates_conversation_history(self) -> None:
"""Test that run_agent updates the conversation history."""
mock_agent = Mock()
mock_response = _agent_response("Agent response")
mock_agent.run = _create_mock_run(response=mock_response)
entity = _make_entity(mock_agent)
await entity.run({"message": "User message", "correlationId": "corr-entity-2"})
# Should have 2 entries: user message + assistant response
user_history = entity.state.data.conversation_history[0].messages
assistant_history = entity.state.data.conversation_history[1].messages
assert len(user_history) == 1
user_msg = user_history[0]
assert _role_value(user_msg) == "user"
assert user_msg.text == "User message"
assistant_msg = assistant_history[0]
assert _role_value(assistant_msg) == "assistant"
assert assistant_msg.text == "Agent response"
async def test_run_agent_increments_message_count(self) -> None:
"""Test that run_agent increments the message count."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
entity = _make_entity(mock_agent)
assert len(entity.state.data.conversation_history) == 0
await entity.run({"message": "Message 1", "correlationId": "corr-entity-3a"})
assert len(entity.state.data.conversation_history) == 2
await entity.run({"message": "Message 2", "correlationId": "corr-entity-3b"})
assert len(entity.state.data.conversation_history) == 4
await entity.run({"message": "Message 3", "correlationId": "corr-entity-3c"})
assert len(entity.state.data.conversation_history) == 6
async def test_run_requires_entity_thread_id(self) -> None:
"""Test that AgentEntity.run rejects missing entity thread identifiers."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
entity = _make_entity(mock_agent, thread_id="")
with pytest.raises(ValueError, match="thread_id"):
await entity.run({"message": "Message", "correlationId": "corr-entity-5"})
async def test_run_agent_multiple_conversations(self) -> None:
"""Test that run_agent maintains history across multiple messages."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
entity = _make_entity(mock_agent)
# Send multiple messages
await entity.run({"message": "Message 1", "correlationId": "corr-entity-8a"})
await entity.run({"message": "Message 2", "correlationId": "corr-entity-8b"})
await entity.run({"message": "Message 3", "correlationId": "corr-entity-8c"})
history = entity.state.data.conversation_history
assert len(history) == 6
assert entity.state.message_count == 6
async def test_run_filters_reasoning_content_from_replayed_history(self) -> None:
"""Replayed durable history should not include reasoning-only content items."""
captured_messages: list[Message] = []
async def mock_run(*args, stream=False, **kwargs):
if stream:
raise TypeError("streaming not supported")
captured_messages.extend(kwargs["messages"])
return _agent_response("Response")
mock_agent = Mock()
mock_agent.run = mock_run
entity = _make_entity(mock_agent)
entity.state.data = DurableAgentStateData(
conversation_history=[
DurableAgentStateRequest(
correlation_id="corr-entity-prev-request",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="user",
contents=[DurableAgentStateTextContent(text="Hi")],
)
],
),
DurableAgentStateResponse(
correlation_id="corr-entity-prev-response",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="assistant",
contents=[
DurableAgentStateTextReasoningContent(text="Let me think."),
DurableAgentStateTextContent(text="Hello there."),
],
)
],
),
]
)
await entity.run({"message": "What next?", "correlationId": "corr-entity-replay"})
assert captured_messages
assert all(content.type != "reasoning" for message in captured_messages for content in message.contents)
assert [message.text for message in captured_messages] == ["Hi", "Hello there.", "What next?"]
class TestAgentEntityReset:
"""Test suite for the reset operation."""
def test_reset_clears_conversation_history(self) -> None:
"""Test that reset clears the conversation history."""
mock_agent = Mock()
entity = _make_entity(mock_agent)
# Add some history with proper DurableAgentStateEntry objects
entity.state.data.conversation_history = [
DurableAgentStateRequest(
correlation_id="test-1",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="user",
contents=[DurableAgentStateTextContent(text="msg1")],
)
],
),
]
entity.reset()
assert entity.state.data.conversation_history == []
def test_reset_with_extension_data(self) -> None:
"""Test that reset works when entity has extension data."""
mock_agent = Mock()
entity = _make_entity(mock_agent)
# Set up some initial state with conversation history
entity.state.data = DurableAgentStateData(conversation_history=[], extension_data={"some_key": "some_value"})
entity.reset()
assert len(entity.state.data.conversation_history) == 0
def test_reset_clears_message_count(self) -> None:
"""Test that reset clears the message count."""
mock_agent = Mock()
entity = _make_entity(mock_agent)
entity.reset()
assert len(entity.state.data.conversation_history) == 0
async def test_reset_after_conversation(self) -> None:
"""Test reset after a full conversation."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
entity = _make_entity(mock_agent)
# Have a conversation
await entity.run({"message": "Message 1", "correlationId": "corr-entity-10a"})
await entity.run({"message": "Message 2", "correlationId": "corr-entity-10b"})
# Verify state before reset
assert entity.state.message_count == 4
assert len(entity.state.data.conversation_history) == 4
# Reset
entity.reset()
# Verify state after reset
assert entity.state.message_count == 0
assert len(entity.state.data.conversation_history) == 0
class TestErrorHandling:
"""Test suite for error handling in entities."""
async def test_run_agent_handles_agent_exception(self) -> None:
"""Test that run_agent handles agent exceptions."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(side_effect=Exception("Agent failed"))
entity = _make_entity(mock_agent)
result = await entity.run({"message": "Message", "correlationId": "corr-entity-error-1"})
assert isinstance(result, AgentResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
assert isinstance(content, Content)
assert "Agent failed" in (content.message or "")
assert content.error_code == "Exception"
async def test_run_agent_handles_value_error(self) -> None:
"""Test that run_agent handles ValueError instances."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(side_effect=ValueError("Invalid input"))
entity = _make_entity(mock_agent)
result = await entity.run({"message": "Message", "correlationId": "corr-entity-error-2"})
assert isinstance(result, AgentResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
assert isinstance(content, Content)
assert content.error_code == "ValueError"
assert "Invalid input" in str(content.message)
async def test_run_agent_handles_timeout_error(self) -> None:
"""Test that run_agent handles TimeoutError instances."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(side_effect=TimeoutError("Request timeout"))
entity = _make_entity(mock_agent)
result = await entity.run({"message": "Message", "correlationId": "corr-entity-error-3"})
assert isinstance(result, AgentResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
assert isinstance(content, Content)
assert content.error_code == "TimeoutError"
async def test_run_agent_preserves_message_on_error(self) -> None:
"""Test that run_agent preserves message information on error."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(side_effect=Exception("Error"))
entity = _make_entity(mock_agent)
result = await entity.run(
{"message": "Test message", "correlationId": "corr-entity-error-4"},
)
# Even on error, message info should be preserved
assert isinstance(result, AgentResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
assert isinstance(content, Content)
class TestConversationHistory:
"""Test suite for conversation history tracking."""
async def test_conversation_history_has_timestamps(self) -> None:
"""Test that conversation history entries include timestamps."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
entity = _make_entity(mock_agent)
await entity.run({"message": "Message", "correlationId": "corr-entity-history-1"})
# Check both user and assistant messages have timestamps
for entry in entity.state.data.conversation_history:
timestamp = entry.created_at
assert timestamp is not None
# Verify timestamp is in ISO format
datetime.fromisoformat(str(timestamp))
async def test_conversation_history_ordering(self) -> None:
"""Test that conversation history maintains the correct order."""
mock_agent = Mock()
entity = _make_entity(mock_agent)
# Send multiple messages with different responses
mock_agent.run = _create_mock_run(response=_agent_response("Response 1"))
await entity.run(
{"message": "Message 1", "correlationId": "corr-entity-history-2a"},
)
mock_agent.run = _create_mock_run(response=_agent_response("Response 2"))
await entity.run(
{"message": "Message 2", "correlationId": "corr-entity-history-2b"},
)
mock_agent.run = _create_mock_run(response=_agent_response("Response 3"))
await entity.run(
{"message": "Message 3", "correlationId": "corr-entity-history-2c"},
)
# Verify order
history = entity.state.data.conversation_history
# Each conversation turn creates 2 entries: request and response
assert history[0].messages[0].text == "Message 1" # Request 1
assert history[1].messages[0].text == "Response 1" # Response 1
assert history[2].messages[0].text == "Message 2" # Request 2
assert history[3].messages[0].text == "Response 2" # Response 2
assert history[4].messages[0].text == "Message 3" # Request 3
assert history[5].messages[0].text == "Response 3" # Response 3
async def test_conversation_history_role_alternation(self) -> None:
"""Test that conversation history alternates between user and assistant roles."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
entity = _make_entity(mock_agent)
await entity.run(
{"message": "Message 1", "correlationId": "corr-entity-history-3a"},
)
await entity.run(
{"message": "Message 2", "correlationId": "corr-entity-history-3b"},
)
# Check role alternation
history = entity.state.data.conversation_history
# Each conversation turn creates 2 entries: request and response
assert history[0].messages[0].role == "user" # Request 1
assert history[1].messages[0].role == "assistant" # Response 1
assert history[2].messages[0].role == "user" # Request 2
assert history[3].messages[0].role == "assistant" # Response 2
class TestRunRequestSupport:
"""Test suite for RunRequest support in entities."""
async def test_run_agent_with_run_request_object(self) -> None:
"""Test run_agent with a RunRequest object."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
entity = _make_entity(mock_agent)
request = RunRequest(
message="Test message",
role="user",
enable_tool_calls=True,
correlation_id="corr-runreq-1",
)
result = await entity.run(request)
assert isinstance(result, AgentResponse)
assert result.text == "Response"
async def test_run_agent_with_dict_request(self) -> None:
"""Test run_agent with a dictionary request."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
entity = _make_entity(mock_agent)
request_dict = {
"message": "Test message",
"role": "system",
"enable_tool_calls": False,
"correlationId": "corr-runreq-2",
}
result = await entity.run(request_dict)
assert isinstance(result, AgentResponse)
assert result.text == "Response"
async def test_run_agent_with_string_raises_without_correlation(self) -> None:
"""Test that run_agent rejects legacy string input without correlation ID."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
entity = _make_entity(mock_agent)
with pytest.raises(ValueError):
await entity.run("Simple message")
async def test_run_agent_stores_role_in_history(self) -> None:
"""Test that run_agent stores the role in conversation history."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
entity = _make_entity(mock_agent)
# Send as system role
request = RunRequest(
message="System message",
role="system",
correlation_id="corr-runreq-3",
)
await entity.run(request)
# Check that system role was stored
history = entity.state.data.conversation_history
assert history[0].messages[0].role == "system"
assert history[0].messages[0].text == "System message"
async def test_run_agent_with_response_format(self) -> None:
"""Test run_agent with a JSON response format."""
mock_agent = Mock()
# Return JSON response
mock_agent.run = _create_mock_run(response=_agent_response('{"answer": 42}'))
entity = _make_entity(mock_agent)
request = RunRequest(
message="What is the answer?",
response_format=EntityStructuredResponse,
correlation_id="corr-runreq-4",
)
result = await entity.run(request)
assert isinstance(result, AgentResponse)
assert result.text == '{"answer": 42}'
assert result.value is None
async def test_run_agent_disable_tool_calls(self) -> None:
"""Test run_agent with tool calls disabled."""
mock_agent = Mock()
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
entity = _make_entity(mock_agent)
request = RunRequest(message="Test", enable_tool_calls=False, correlation_id="corr-runreq-5")
result = await entity.run(request)
assert isinstance(result, AgentResponse)
# Agent should have been called (tool disabling is framework-dependent)
assert result.text == "Response"
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])