# Copyright (c) Microsoft. All rights reserved.

"""Unit tests for orchestration support (DurableAIAgent)."""

from typing import Any
from unittest.mock import Mock

import pytest
from agent_framework import AgentRunResponse, AgentThread, ChatMessage
from azure.durable_functions.models.Task import TaskBase, TaskState

from agent_framework_azurefunctions import AgentFunctionApp, DurableAIAgent
from agent_framework_azurefunctions._models import AgentSessionId, DurableAgentThread
from agent_framework_azurefunctions._orchestration import AgentTask


def _app_with_registered_agents(*agent_names: str) -> AgentFunctionApp:
    """Build an AgentFunctionApp with one mock agent registered per given name.

    Health-check and HTTP endpoints are disabled so only agent registration
    is exercised.
    """
    app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=False)
    for name in agent_names:
        agent = Mock()
        agent.name = name
        app.add_agent(agent)
    return app


class _FakeTask(TaskBase):
    """Concrete TaskBase for testing AgentTask wiring."""

    def __init__(self, task_id: int = 1):
        """Create a task with no actions, marked unscheduled and RUNNING."""
        super().__init__(task_id, [])
        self._set_is_scheduled(False)
        self.action_repr = []
        self.state = TaskState.RUNNING


def _create_entity_task(task_id: int = 1) -> TaskBase:
    """Create a minimal TaskBase instance for AgentTask tests."""
    return _FakeTask(task_id)


class TestAgentResponseHelpers:
    """Tests for helper utilities that prepare AgentRunResponse values."""

    @staticmethod
    def _create_agent_task() -> AgentTask:
        """Return an AgentTask wrapping a fake entity task, with no response format."""
        entity_task = _create_entity_task()
        return AgentTask(entity_task, None, "correlation-id")

    def test_load_agent_response_from_instance(self) -> None:
        """An AgentRunResponse instance is returned as the same object, value untouched."""
        task = self._create_agent_task()
        response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"foo": "bar"}')])

        loaded = task._load_agent_response(response)

        assert loaded is response
        assert loaded.value is None

    def test_load_agent_response_from_serialized(self) -> None:
        """A serialized (dict) response is loaded back, keeping its ``value`` payload."""
        task = self._create_agent_task()
        serialized = AgentRunResponse(messages=[ChatMessage(role="assistant", text="structured")]).to_dict()
        serialized["value"] = {"answer": 42}

        loaded = task._load_agent_response(serialized)

        assert loaded is not None
        assert loaded.value == {"answer": 42}
        loaded_dict = loaded.to_dict()
        assert loaded_dict["type"] == "agent_run_response"

    def test_load_agent_response_rejects_none(self) -> None:
        """Passing None raises ValueError."""
        task = self._create_agent_task()

        with pytest.raises(ValueError):
            task._load_agent_response(None)

    def test_load_agent_response_rejects_unsupported_type(self) -> None:
        """Passing a non-response, non-dict value raises TypeError."""
        task = self._create_agent_task()

        with pytest.raises(TypeError, match="Unsupported type"):
            task._load_agent_response(["invalid", "list"])  # type: ignore[arg-type]

    def test_try_set_value_success(self) -> None:
        """Test try_set_value correctly processes successful task completion."""
        entity_task = _create_entity_task()
        task = AgentTask(entity_task, None, "correlation-id")

        # Simulate successful entity task completion
        entity_task.state = TaskState.SUCCEEDED
        entity_task.result = AgentRunResponse(messages=[ChatMessage(role="assistant", text="Test response")]).to_dict()

        # Clear pending_tasks to simulate that parent has processed the child
        task.pending_tasks.clear()

        # Call try_set_value
        task.try_set_value(entity_task)

        # Verify task completed successfully with AgentRunResponse
        assert task.state == TaskState.SUCCEEDED
        assert isinstance(task.result, AgentRunResponse)
        assert task.result.text == "Test response"

    def test_try_set_value_failure(self) -> None:
        """Test try_set_value correctly handles failed task completion."""
        entity_task = _create_entity_task()
        task = AgentTask(entity_task, None, "correlation-id")

        # Simulate failed entity task
        entity_task.state = TaskState.FAILED
        entity_task.result = Exception("Entity call failed")

        # Call try_set_value
        task.try_set_value(entity_task)

        # Verify task failed with the error
        assert task.state == TaskState.FAILED
        assert isinstance(task.result, Exception)
        assert str(task.result) == "Entity call failed"

    def test_try_set_value_with_response_format(self) -> None:
        """Test try_set_value parses structured output when response_format is provided."""
        from pydantic import BaseModel

        class TestSchema(BaseModel):
            answer: str

        entity_task = _create_entity_task()
        task = AgentTask(entity_task, TestSchema, "correlation-id")

        # Simulate successful entity task with JSON response
        entity_task.state = TaskState.SUCCEEDED
        entity_task.result = AgentRunResponse(
            messages=[ChatMessage(role="assistant", text='{"answer": "42"}')]
        ).to_dict()

        # Clear pending_tasks to simulate that parent has processed the child
        task.pending_tasks.clear()

        # Call try_set_value
        task.try_set_value(entity_task)

        # Verify task completed and value was parsed
        assert task.state == TaskState.SUCCEEDED
        assert isinstance(task.result, AgentRunResponse)
        assert isinstance(task.result.value, TestSchema)
        assert task.result.value.answer == "42"

    def test_ensure_response_format_parses_value(self) -> None:
        """Test _ensure_response_format correctly parses response value."""
        from pydantic import BaseModel

        class SampleSchema(BaseModel):
            name: str

        task = self._create_agent_task()
        response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"name": "test"}')])

        # Value should be None initially
        assert response.value is None

        # Parse the value
        task._ensure_response_format(SampleSchema, "test-correlation", response)

        # Value should now be parsed
        assert isinstance(response.value, SampleSchema)
        assert response.value.name == "test"

    def test_ensure_response_format_skips_if_already_parsed(self) -> None:
        """Test _ensure_response_format does not re-parse if value already matches format."""
        from pydantic import BaseModel

        class SampleSchema(BaseModel):
            name: str

        task = self._create_agent_task()
        existing_value = SampleSchema(name="existing")
        response = AgentRunResponse(
            messages=[ChatMessage(role="assistant", text='{"name": "new"}')],
            value=existing_value,
        )

        # Call _ensure_response_format
        task._ensure_response_format(SampleSchema, "test-correlation", response)

        # Value should remain unchanged (not re-parsed)
        assert response.value is existing_value
        assert response.value.name == "existing"


class TestDurableAIAgent:
    """Test suite for DurableAIAgent wrapper."""

    def test_init(self) -> None:
        """Test DurableAIAgent initialization."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-123"

        agent = DurableAIAgent(mock_context, "TestAgent")

        assert agent.context == mock_context
        assert agent.agent_name == "TestAgent"

    def test_implements_agent_protocol(self) -> None:
        """Test that DurableAIAgent implements AgentProtocol."""
        from agent_framework import AgentProtocol

        mock_context = Mock()
        agent = DurableAIAgent(mock_context, "TestAgent")

        # Check that agent satisfies AgentProtocol
        assert isinstance(agent, AgentProtocol)

    def test_has_agent_protocol_properties(self) -> None:
        """Test that DurableAIAgent has AgentProtocol properties."""
        mock_context = Mock()
        agent = DurableAIAgent(mock_context, "TestAgent")

        # AgentProtocol properties
        assert hasattr(agent, "id")
        assert hasattr(agent, "name")
        assert hasattr(agent, "description")
        assert hasattr(agent, "display_name")

        # Verify values
        assert agent.name == "TestAgent"
        assert agent.description == "Durable agent proxy for TestAgent"
        assert agent.display_name == "TestAgent"
        assert agent.id is not None  # Auto-generated UUID

    def test_get_new_thread(self) -> None:
        """Test creating a new agent thread."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-456"
        mock_context.new_uuid = Mock(return_value="test-guid-456")

        agent = DurableAIAgent(mock_context, "WriterAgent")
        thread = agent.get_new_thread()

        assert isinstance(thread, DurableAgentThread)
        assert thread.session_id is not None
        session_id = thread.session_id
        assert isinstance(session_id, AgentSessionId)
        assert session_id.name == "WriterAgent"
        assert session_id.key == "test-guid-456"
        mock_context.new_uuid.assert_called_once()

    def test_get_new_thread_deterministic(self) -> None:
        """Test that get_new_thread creates deterministic session IDs."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-789"
        mock_context.new_uuid = Mock(side_effect=["session-guid-1", "session-guid-2"])

        agent = DurableAIAgent(mock_context, "EditorAgent")

        # Create multiple threads - they should have unique session IDs
        thread1 = agent.get_new_thread()
        thread2 = agent.get_new_thread()

        assert isinstance(thread1, DurableAgentThread)
        assert isinstance(thread2, DurableAgentThread)

        session_id1 = thread1.session_id
        session_id2 = thread2.session_id
        assert session_id1 is not None and session_id2 is not None
        assert isinstance(session_id1, AgentSessionId)
        assert isinstance(session_id2, AgentSessionId)
        assert session_id1.name == "EditorAgent"
        assert session_id2.name == "EditorAgent"
        assert session_id1.key == "session-guid-1"
        assert session_id2.key == "session-guid-2"
        assert mock_context.new_uuid.call_count == 2

    def test_run_creates_entity_call(self) -> None:
        """Test that run() creates proper entity call and returns a Task."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-001"
        mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])

        entity_task = _create_entity_task()
        mock_context.call_entity = Mock(return_value=entity_task)

        agent = DurableAIAgent(mock_context, "TestAgent")

        # Create thread
        thread = agent.get_new_thread()

        # Call run() - returns AgentTask directly
        task = agent.run(messages="Test message", thread=thread, enable_tool_calls=True)

        assert isinstance(task, AgentTask)
        assert task.children[0] == entity_task

        # Verify call_entity was called with correct parameters
        assert mock_context.call_entity.called
        call_args = mock_context.call_entity.call_args
        # The entity id is positional arg 0; it is not inspected in this test.
        _entity_id, operation, request = call_args[0]

        assert operation == "run"
        assert request["message"] == "Test message"
        assert request["enable_tool_calls"] is True
        assert "correlationId" in request
        assert request["correlationId"] == "correlation-guid"
        assert "thread_id" in request
        assert request["thread_id"] == "thread-guid"
        # Verify orchestration ID is set from context.instance_id
        assert "orchestrationId" in request
        assert request["orchestrationId"] == "test-instance-001"

    def test_run_sets_orchestration_id(self) -> None:
        """Test that run() sets the orchestration_id from context.instance_id."""
        mock_context = Mock()
        mock_context.instance_id = "my-orchestration-123"
        mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])

        entity_task = _create_entity_task()
        mock_context.call_entity = Mock(return_value=entity_task)

        agent = DurableAIAgent(mock_context, "TestAgent")
        thread = agent.get_new_thread()

        agent.run(messages="Test", thread=thread)

        call_args = mock_context.call_entity.call_args
        request = call_args[0][2]

        assert request["orchestrationId"] == "my-orchestration-123"

    def test_run_without_thread(self) -> None:
        """Test that run() works without explicit thread (creates unique session key)."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-002"
        mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"])

        entity_task = _create_entity_task()
        mock_context.call_entity = Mock(return_value=entity_task)

        agent = DurableAIAgent(mock_context, "TestAgent")

        # Call without thread
        task = agent.run(messages="Test message")

        assert isinstance(task, AgentTask)
        assert task.children[0] == entity_task

        # Verify the entity ID uses the auto-generated GUID with dafx- prefix
        call_args = mock_context.call_entity.call_args
        entity_id = call_args[0][0]
        assert entity_id.name == "dafx-TestAgent"
        assert entity_id.key == "auto-generated-guid"
        # Should be called twice: once for session_key, once for correlationId
        assert mock_context.new_uuid.call_count == 2

    def test_run_with_response_format(self) -> None:
        """Test that run() passes response format correctly."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-003"

        entity_task = _create_entity_task()
        mock_context.call_entity = Mock(return_value=entity_task)

        agent = DurableAIAgent(mock_context, "TestAgent")

        from pydantic import BaseModel

        class SampleSchema(BaseModel):
            key: str

        # Create thread and call
        thread = agent.get_new_thread()

        task = agent.run(messages="Test message", thread=thread, response_format=SampleSchema)

        assert isinstance(task, AgentTask)
        assert task.children[0] == entity_task

        # Verify schema was passed in the call_entity arguments
        call_args = mock_context.call_entity.call_args
        input_data = call_args[0][2]  # Third argument is input_data
        assert "response_format" in input_data
        assert input_data["response_format"]["__response_schema_type__"] == "pydantic_model"
        assert input_data["response_format"]["module"] == SampleSchema.__module__
        assert input_data["response_format"]["qualname"] == SampleSchema.__qualname__

    def test_messages_to_string(self) -> None:
        """Test converting ChatMessage list to string."""
        mock_context = Mock()
        agent = DurableAIAgent(mock_context, "TestAgent")

        messages = [
            ChatMessage(role="user", text="Hello"),
            ChatMessage(role="assistant", text="Hi there"),
            ChatMessage(role="user", text="How are you?"),
        ]

        result = agent._messages_to_string(messages)

        assert result == "Hello\nHi there\nHow are you?"

    def test_run_with_chat_message(self) -> None:
        """Test that run() handles ChatMessage input."""
        mock_context = Mock()
        mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
        entity_task = _create_entity_task()
        mock_context.call_entity = Mock(return_value=entity_task)

        agent = DurableAIAgent(mock_context, "TestAgent")
        thread = agent.get_new_thread()

        # Call with ChatMessage
        msg = ChatMessage(role="user", text="Hello")
        task = agent.run(messages=msg, thread=thread)

        assert isinstance(task, AgentTask)
        assert task.children[0] == entity_task

        # Verify message was converted to string
        call_args = mock_context.call_entity.call_args
        request = call_args[0][2]
        assert request["message"] == "Hello"

    def test_run_stream_raises_not_implemented(self) -> None:
        """Test that run_stream() method raises NotImplementedError."""
        mock_context = Mock()
        agent = DurableAIAgent(mock_context, "TestAgent")

        with pytest.raises(NotImplementedError) as exc_info:
            agent.run_stream("Test message")

        error_msg = str(exc_info.value)
        assert "Streaming is not supported" in error_msg

    def test_entity_id_format(self) -> None:
        """Test that EntityId is created with correct format (name, key)."""
        from azure.durable_functions import EntityId

        mock_context = Mock()
        mock_context.new_uuid = Mock(return_value="test-guid-789")
        mock_context.call_entity = Mock(return_value=_create_entity_task())

        agent = DurableAIAgent(mock_context, "WriterAgent")
        thread = agent.get_new_thread()

        # Call run() to trigger entity ID creation
        agent.run("Test", thread=thread)

        # Verify call_entity was called with correct EntityId
        call_args = mock_context.call_entity.call_args
        entity_id = call_args[0][0]

        # EntityId should be EntityId(name="dafx-WriterAgent", key="test-guid-789")
        # Which formats as "@dafx-writeragent@test-guid-789"
        assert isinstance(entity_id, EntityId)
        assert entity_id.name == "dafx-WriterAgent"
        assert entity_id.key == "test-guid-789"
        assert str(entity_id) == "@dafx-writeragent@test-guid-789"


class TestAgentFunctionAppGetAgent:
    """Test suite for AgentFunctionApp.get_agent."""

    def test_get_agent_method(self) -> None:
        """Test get_agent method creates DurableAIAgent for registered agent."""
        app = _app_with_registered_agents("MyAgent")
        mock_context = Mock()
        mock_context.instance_id = "test-instance-100"

        agent = app.get_agent(mock_context, "MyAgent")

        assert isinstance(agent, DurableAIAgent)
        assert agent.agent_name == "MyAgent"
        assert agent.context == mock_context

    def test_get_agent_raises_for_unregistered_agent(self) -> None:
        """Test get_agent raises ValueError when agent is not registered."""
        app = _app_with_registered_agents("KnownAgent")

        with pytest.raises(ValueError, match=r"Agent 'MissingAgent' is not registered with this app\."):
            app.get_agent(Mock(), "MissingAgent")


class TestOrchestrationIntegration:
    """Integration tests for orchestration scenarios."""

    def test_sequential_agent_calls_simulation(self) -> None:
        """Simulate sequential agent calls in an orchestration."""
        mock_context = Mock()
        mock_context.instance_id = "test-orchestration-001"
        # new_uuid will be called 3 times:
        # 1. thread creation
        # 2. correlationId for first call
        # 3. correlationId for second call
        mock_context.new_uuid = Mock(side_effect=["deterministic-guid-001", "corr-1", "corr-2"])

        # Track entity calls
        entity_calls: list[dict[str, Any]] = []

        def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> TaskBase:
            entity_calls.append({"entity_id": str(entity_id), "operation": operation, "input": input_data})
            return _create_entity_task()

        mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)

        app = _app_with_registered_agents("WriterAgent")
        agent = app.get_agent(mock_context, "WriterAgent")

        # Create thread
        thread = agent.get_new_thread()

        # First call - returns AgentTask
        task1 = agent.run("Write something", thread=thread)
        assert isinstance(task1, AgentTask)

        # Second call - returns AgentTask
        task2 = agent.run("Improve: something", thread=thread)
        assert isinstance(task2, AgentTask)

        # Verify both calls used the same entity (same session key)
        assert len(entity_calls) == 2
        assert entity_calls[0]["entity_id"] == entity_calls[1]["entity_id"]
        # EntityId format is @dafx-writeragent@deterministic-guid-001
        assert entity_calls[0]["entity_id"] == "@dafx-writeragent@deterministic-guid-001"
        # new_uuid called 3 times: thread + 2 correlation IDs
        assert mock_context.new_uuid.call_count == 3

    def test_multiple_agents_in_orchestration(self) -> None:
        """Test using multiple different agents in one orchestration."""
        mock_context = Mock()
        mock_context.instance_id = "test-orchestration-002"
        # Mock new_uuid to return different GUIDs for each call
        # Order: writer thread, editor thread, writer correlation, editor correlation
        mock_context.new_uuid = Mock(side_effect=["writer-guid-001", "editor-guid-002", "writer-corr", "editor-corr"])

        entity_calls: list[str] = []

        def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> TaskBase:
            entity_calls.append(str(entity_id))
            return _create_entity_task()

        mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)

        app = _app_with_registered_agents("WriterAgent", "EditorAgent")
        writer = app.get_agent(mock_context, "WriterAgent")
        editor = app.get_agent(mock_context, "EditorAgent")

        writer_thread = writer.get_new_thread()
        editor_thread = editor.get_new_thread()

        # Call both agents - returns AgentTasks
        writer_task = writer.run("Write", thread=writer_thread)
        editor_task = editor.run("Edit", thread=editor_thread)

        assert isinstance(writer_task, AgentTask)
        assert isinstance(editor_task, AgentTask)

        # Verify different entity IDs were used
        assert len(entity_calls) == 2
        # EntityId format is @dafx-agentname@guid (lowercased agent name with dafx- prefix)
        assert entity_calls[0] == "@dafx-writeragent@writer-guid-001"
        assert entity_calls[1] == "@dafx-editoragent@editor-guid-002"


class TestAgentThreadSerialization:
    """Test that AgentThread can be serialized for orchestration state."""

    async def test_agent_thread_serialize(self) -> None:
        """Test that AgentThread can be serialized."""
        thread = AgentThread()

        # Serialize
        serialized = await thread.serialize()

        assert isinstance(serialized, dict)
        assert "service_thread_id" in serialized

    async def test_agent_thread_deserialize(self) -> None:
        """Test that AgentThread can be deserialized."""
        thread = AgentThread()
        serialized = await thread.serialize()

        # Deserialize
        restored = await AgentThread.deserialize(serialized)

        assert isinstance(restored, AgentThread)
        assert restored.service_thread_id == thread.service_thread_id

    async def test_durable_agent_thread_serialization(self) -> None:
        """Test that DurableAgentThread persists session metadata during serialization."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-999"
        mock_context.new_uuid = Mock(return_value="test-guid-999")

        agent = DurableAIAgent(mock_context, "TestAgent")
        thread = agent.get_new_thread()

        assert isinstance(thread, DurableAgentThread)
        # Verify custom attribute and property exist
        assert thread.session_id is not None
        session_id = thread.session_id
        assert isinstance(session_id, AgentSessionId)
        assert session_id.name == "TestAgent"
        assert session_id.key == "test-guid-999"

        # Standard serialization should still work
        serialized = await thread.serialize()
        assert isinstance(serialized, dict)
        assert serialized.get("durable_session_id") == str(session_id)

        # After deserialization, we'd need to restore the custom attribute
        # This would be handled by the orchestration framework
        restored = await DurableAgentThread.deserialize(serialized)
        assert isinstance(restored, DurableAgentThread)
        assert restored.session_id == session_id


if __name__ == "__main__":
    # Propagate pytest's exit code so a direct run reports failures to the shell.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))