# Copyright (c) Microsoft. All rights reserved.

"""Unit tests for DurableAIAgent shim and DurableAgentProvider.

Focuses on critical message normalization, delegation, and protocol compliance.

Run with: pytest tests/test_shim.py -v
"""

import uuid
from typing import Any
from unittest.mock import Mock

import pytest
from agent_framework import Message, SupportsAgentRun
from pydantic import BaseModel

from agent_framework_durabletask import DurableAgentSession
from agent_framework_durabletask._executors import DurableAgentExecutor
from agent_framework_durabletask._models import RunRequest
from agent_framework_durabletask._shim import DurableAgentProvider, DurableAIAgent


class ResponseFormatModel(BaseModel):
    """Test Pydantic model for response format testing."""

    result: str


@pytest.fixture
def mock_executor() -> Mock:
    """Create a mock executor for testing.

    ``run_durable_agent`` returns ``None`` and ``get_new_session`` returns a fresh
    ``DurableAgentSession``. ``get_run_request`` builds real ``RunRequest`` objects
    so tests can inspect exactly which fields the shim passed through.
    """
    mock = Mock(spec=DurableAgentExecutor)
    mock.run_durable_agent = Mock(return_value=None)
    mock.get_new_session = Mock(return_value=DurableAgentSession())

    # Mock get_run_request to create actual RunRequest objects
    def create_run_request(
        message: str,
        options: dict[str, Any] | None = None,
    ) -> RunRequest:
        """Build a RunRequest from ``message`` and an optional options dict.

        ``response_format`` (default None), ``enable_tool_calls`` (default True) and
        ``wait_for_response`` (default True) are lifted out of ``options`` into their
        dedicated RunRequest fields; any remaining keys are forwarded as ``options``.
        A fresh random ``correlation_id`` is generated for every call.
        """
        # Copy first so popping keys never mutates the dict the caller passed in.
        opts = dict(options) if options else {}
        response_format = opts.pop("response_format", None)
        enable_tool_calls = opts.pop("enable_tool_calls", True)
        wait_for_response = opts.pop("wait_for_response", True)
        return RunRequest(
            message=message,
            correlation_id=str(uuid.uuid4()),
            response_format=response_format,
            enable_tool_calls=enable_tool_calls,
            wait_for_response=wait_for_response,
            options=opts,
        )

    mock.get_run_request = Mock(side_effect=create_run_request)
    return mock


@pytest.fixture
def test_agent(mock_executor: Mock) -> DurableAIAgent[Any]:
    """Create a test agent with mock executor."""
    return DurableAIAgent(mock_executor, "test_agent")


class TestDurableAIAgentMessageNormalization:
    """Test that DurableAIAgent properly normalizes various message input types."""

    def test_run_accepts_string_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
        """Verify run accepts and normalizes string messages."""
        test_agent.run("Hello, world!")

        mock_executor.run_durable_agent.assert_called_once()
        # Verify agent_name and run_request were passed correctly as kwargs
        _, kwargs = mock_executor.run_durable_agent.call_args
        assert kwargs["agent_name"] == "test_agent"
        assert kwargs["run_request"].message == "Hello, world!"

    def test_run_accepts_chat_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
        """Verify run accepts and normalizes Message objects."""
        chat_msg = Message(role="user", text="Test message")
        test_agent.run(chat_msg)

        mock_executor.run_durable_agent.assert_called_once()
        _, kwargs = mock_executor.run_durable_agent.call_args
        assert kwargs["run_request"].message == "Test message"

    def test_run_accepts_list_of_strings(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
        """Verify run accepts and joins list of strings."""
        test_agent.run(["First message", "Second message"])

        mock_executor.run_durable_agent.assert_called_once()
        _, kwargs = mock_executor.run_durable_agent.call_args
        assert kwargs["run_request"].message == "First message\nSecond message"

    def test_run_accepts_list_of_chat_messages(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
        """Verify run accepts and joins list of Message objects."""
        messages = [
            Message(role="user", text="Message 1"),
            Message(role="assistant", text="Message 2"),
        ]
        test_agent.run(messages)

        mock_executor.run_durable_agent.assert_called_once()
        _, kwargs = mock_executor.run_durable_agent.call_args
        assert kwargs["run_request"].message == "Message 1\nMessage 2"

    def test_run_handles_none_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
        """Verify run handles None message gracefully."""
        test_agent.run(None)

        mock_executor.run_durable_agent.assert_called_once()
        _, kwargs = mock_executor.run_durable_agent.call_args
        assert kwargs["run_request"].message == ""

    def test_run_handles_empty_list(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
        """Verify run handles empty list gracefully."""
        test_agent.run([])

        mock_executor.run_durable_agent.assert_called_once()
        _, kwargs = mock_executor.run_durable_agent.call_args
        assert kwargs["run_request"].message == ""


class TestDurableAIAgentParameterFlow:
    """Test that parameters flow correctly through the shim to executor."""

    def test_run_forwards_session_parameter(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
        """Verify run forwards session parameter to executor."""
        session = DurableAgentSession(service_session_id="test-session")

        test_agent.run("message", session=session)

        mock_executor.run_durable_agent.assert_called_once()
        _, kwargs = mock_executor.run_durable_agent.call_args
        # Identity, not equality: "forwards" means the very same object is passed on.
        assert kwargs["session"] is session

    def test_run_forwards_response_format(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
        """Verify run forwards response_format parameter to executor."""
        test_agent.run("message", options={"response_format": ResponseFormatModel})

        mock_executor.run_durable_agent.assert_called_once()
        _, kwargs = mock_executor.run_durable_agent.call_args
        # Classes are compared by identity.
        assert kwargs["run_request"].response_format is ResponseFormatModel


class TestDurableAISupportsAgentRunCompliance:
    """Test that DurableAIAgent implements SupportsAgentRun correctly."""

    def test_agent_implements_protocol(self, test_agent: DurableAIAgent[Any]) -> None:
        """Verify DurableAIAgent implements SupportsAgentRun."""
        assert isinstance(test_agent, SupportsAgentRun)

    def test_agent_has_required_properties(self, test_agent: DurableAIAgent[Any]) -> None:
        """Verify DurableAIAgent has all required SupportsAgentRun properties."""
        assert hasattr(test_agent, "id")
        assert hasattr(test_agent, "name")
        assert hasattr(test_agent, "display_name")
        assert hasattr(test_agent, "description")

    def test_agent_id_defaults_to_name(self, mock_executor: Mock) -> None:
        """Verify agent id defaults to name when not provided."""
        agent: DurableAIAgent[Any] = DurableAIAgent(mock_executor, "my_agent")

        assert agent.id == "my_agent"
        assert agent.name == "my_agent"

    def test_agent_id_can_be_customized(self, mock_executor: Mock) -> None:
        """Verify agent id can be set independently from name."""
        agent: DurableAIAgent[Any] = DurableAIAgent(mock_executor, "my_agent", agent_id="custom-id")

        assert agent.id == "custom-id"
        assert agent.name == "my_agent"


class TestDurableAIAgentSessionManagement:
    """Test session creation and management."""

    def test_create_session_delegates_to_executor(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
        """Verify create_session delegates to executor."""
        mock_session = DurableAgentSession()
        mock_executor.get_new_session.return_value = mock_session

        session = test_agent.create_session()

        mock_executor.get_new_session.assert_called_once_with("test_agent")
        # Delegation should hand back exactly the executor's session object.
        assert session is mock_session

    def test_get_session_forwards_service_session_id(
        self, test_agent: DurableAIAgent[Any], mock_executor: Mock
    ) -> None:
        """Verify get_session forwards service_session_id and session_id to executor."""
        mock_session = DurableAgentSession(service_session_id="svc-123")
        mock_executor.get_new_session.return_value = mock_session

        session = test_agent.get_session("svc-123", session_id="local-456")

        mock_executor.get_new_session.assert_called_once_with(
            "test_agent", service_session_id="svc-123", session_id="local-456"
        )
        assert session.service_session_id == "svc-123"

    def test_get_session_without_session_id(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
        """Verify get_session works with only service_session_id (session_id defaults to None)."""
        mock_session = DurableAgentSession(service_session_id="svc-789")
        mock_executor.get_new_session.return_value = mock_session

        session = test_agent.get_session("svc-789")

        mock_executor.get_new_session.assert_called_once_with(
            "test_agent", service_session_id="svc-789", session_id=None
        )
        assert session.service_session_id == "svc-789"


class TestDurableAgentProviderInterface:
    """Test that DurableAgentProvider defines the correct interface."""

    def test_provider_cannot_be_instantiated(self) -> None:
        """Verify DurableAgentProvider is abstract and cannot be instantiated."""
        with pytest.raises(TypeError):
            DurableAgentProvider()  # type: ignore[abstract]

    def test_provider_defines_get_agent_method(self) -> None:
        """Verify DurableAgentProvider defines get_agent abstract method."""
        assert hasattr(DurableAgentProvider, "get_agent")
        # hasattr alone would also pass for a concrete method; @abstractmethod
        # marks the function with __isabstractmethod__ = True.
        assert getattr(DurableAgentProvider.get_agent, "__isabstractmethod__", False)


if __name__ == "__main__":
    # Propagate pytest's exit code so a direct run reports failures to the shell/CI.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))