From 754491cdd3d4029d886e43b1bff8526215280c6e Mon Sep 17 00:00:00 2001 From: Laveesh Rohra Date: Fri, 7 Nov 2025 08:29:43 -0800 Subject: [PATCH] Python: Add Unit tests for Azurefunctions package (#1976) * Add Unit tests for Azurefunctions * remove duplicate import --- .../agent_framework_azurefunctions/_models.py | 5 +- .../packages/azurefunctions/tests/test_app.py | 614 ++++++++++++ .../azurefunctions/tests/test_entities.py | 904 ++++++++++++++++++ .../azurefunctions/tests/test_models.py | 478 +++++++++ .../azurefunctions/tests/test_multi_agent.py | 150 +++ .../tests/test_orchestration.py | 425 ++++++++ .../azurefunctions/tests/test_state.py | 110 +++ 7 files changed, 2684 insertions(+), 2 deletions(-) create mode 100644 python/packages/azurefunctions/tests/test_app.py create mode 100644 python/packages/azurefunctions/tests/test_entities.py create mode 100644 python/packages/azurefunctions/tests/test_models.py create mode 100644 python/packages/azurefunctions/tests/test_multi_agent.py create mode 100644 python/packages/azurefunctions/tests/test_orchestration.py create mode 100644 python/packages/azurefunctions/tests/test_state.py diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py index 15f24666bb..32d0f101e3 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py @@ -197,9 +197,10 @@ class DurableAgentThread(AgentThread): **kwargs: Any, ) -> "DurableAgentThread": """Restores a durable thread, rehydrating the stored session identifier.""" - session_id_value = serialized_thread_state.get(cls._SERIALIZED_SESSION_ID_KEY) + state_payload = dict(serialized_thread_state) + session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None) thread = await super().deserialize( - serialized_thread_state, + state_payload, message_store=message_store, **kwargs, ) diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py new file mode 100644 index 0000000000..d7ed8fef78 --- /dev/null +++ b/python/packages/azurefunctions/tests/test_app.py @@ -0,0 +1,614 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for AgentFunctionApp.""" + +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar +from unittest.mock import ANY, AsyncMock, Mock, patch + +import azure.durable_functions as df +import azure.functions as func +import pytest +from agent_framework import AgentRunResponse, ChatMessage + +from agent_framework_azurefunctions import AgentFunctionApp +from agent_framework_azurefunctions._entities import AgentEntity, AgentState, create_agent_entity +from agent_framework_azurefunctions._errors import IncomingRequestError + +TFunc = TypeVar("TFunc", bound=Callable[..., Any]) + + +class TestAgentFunctionAppInit: + """Test suite for AgentFunctionApp initialization.""" + + def test_init_with_defaults(self) -> None: + """Test initialization with default parameters.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + app = AgentFunctionApp(agents=[mock_agent]) + + assert len(app.agents) == 1 + assert "TestAgent" in app.agents + assert app.enable_health_check is True + + def test_init_with_custom_auth_level(self) -> None: + """Test initialization with custom auth level.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + app = AgentFunctionApp(agents=[mock_agent], http_auth_level=func.AuthLevel.FUNCTION) + + # App should be created successfully + assert "TestAgent" in app.agents + + def test_init_with_health_check_disabled(self) -> None: + """Test initialization with health check disabled.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + app = AgentFunctionApp(agents=[mock_agent], enable_health_check=False) + + assert app.enable_health_check is False + + def test_init_with_http_endpoints_disabled(self) -> None: + """Test initialization with HTTP endpoints disabled.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + app = AgentFunctionApp(agents=[mock_agent], enable_http_endpoints=False) + + assert app.enable_http_endpoints is False + + def test_init_stores_agent_reference(self) -> None: + """Test that agent reference is stored correctly.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + app = AgentFunctionApp(agents=[mock_agent]) + + assert app.agents["TestAgent"].name == "TestAgent" + + def test_add_agent_uses_specific_callback(self) -> None: + """Verify that a per-agent callback overrides the default.""" + + mock_agent = Mock() + mock_agent.name = "CallbackAgent" + specific_callback = Mock() + + with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock: + app = AgentFunctionApp(default_callback=Mock()) + app.add_agent(mock_agent, callback=specific_callback) + + setup_mock.assert_called_once() + _, _, passed_callback, enable_http_endpoint = setup_mock.call_args[0] + assert passed_callback is specific_callback + assert enable_http_endpoint is True + + def test_default_callback_applied_when_no_specific(self) -> None: + """Ensure the default callback is supplied when add_agent lacks override.""" + + mock_agent = Mock() + mock_agent.name = "DefaultAgent" + default_callback = Mock() + + with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock: + app = AgentFunctionApp(default_callback=default_callback) + app.add_agent(mock_agent) + + setup_mock.assert_called_once() + _, _, passed_callback, enable_http_endpoint = setup_mock.call_args[0] + assert passed_callback is default_callback + assert enable_http_endpoint is True + + def test_init_with_agents_uses_default_callback(self) -> None: + """Agents provided in __init__ should receive the default callback.""" + + mock_agent = Mock() + mock_agent.name = "InitAgent" + default_callback = Mock() + + with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock: + AgentFunctionApp(agents=[mock_agent], default_callback=default_callback) + + setup_mock.assert_called_once() + _, _, passed_callback, enable_http_endpoint = setup_mock.call_args[0] + assert passed_callback is default_callback + assert enable_http_endpoint is True + + +class TestAgentFunctionAppSetup: + """Test suite for AgentFunctionApp setup and configuration.""" + + def test_app_is_dfapp_instance(self) -> None: + """Test that AgentFunctionApp is a DFApp instance.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + app = AgentFunctionApp(agents=[mock_agent]) + + assert isinstance(app, df.DFApp) + + def test_setup_creates_http_trigger(self) -> None: + """Test that setup creates an HTTP trigger.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + def passthrough_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]: + def decorator(func: TFunc) -> TFunc: + return func + + return decorator + + with ( + patch.object(AgentFunctionApp, "route", new=passthrough_decorator), + patch.object(AgentFunctionApp, "durable_client_input", new=passthrough_decorator), + patch.object(AgentFunctionApp, "entity_trigger", new=passthrough_decorator), + ): + app = AgentFunctionApp(agents=[mock_agent]) + + # Verify agent is registered + assert "TestAgent" in app.agents + + def test_setup_skips_http_trigger_when_disabled(self) -> None: + """Test that HTTP trigger is not created when disabled.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + captured_routes: list[str | None] = [] + + def capture_route(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]: + def decorator(func: TFunc) -> TFunc: + route_key = kwargs.get("route") if kwargs else None + captured_routes.append(route_key) + return func + + return decorator + + def passthrough_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]: + def decorator(func: TFunc) -> TFunc: + return func + + return decorator + + with ( + patch.object(AgentFunctionApp, "function_name", new=passthrough_decorator), + patch.object(AgentFunctionApp, "route", new=capture_route), + patch.object(AgentFunctionApp, "durable_client_input", new=passthrough_decorator), + patch.object(AgentFunctionApp, "entity_trigger", new=passthrough_decorator), + ): + app = AgentFunctionApp(agents=[mock_agent], enable_http_endpoints=False) + + # Verify agent is registered + assert "TestAgent" in app.agents + + # Verify that no HTTP run route was created + run_route = f"agents/{mock_agent.name}/run" + assert run_route not in captured_routes + + def test_agent_override_enables_http_route_when_app_disabled(self) -> None: + """Agent-level override should enable HTTP route even when app disables it.""" + + mock_agent = Mock() + mock_agent.name = "OverrideAgent" + + with ( + patch.object(AgentFunctionApp, "_setup_http_run_route") as http_route_mock, + patch.object(AgentFunctionApp, "_setup_agent_entity") as agent_entity_mock, + ): + app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=False) + app.add_agent(mock_agent, enable_http_endpoint=True) + + http_route_mock.assert_called_once_with("OverrideAgent") + agent_entity_mock.assert_called_once_with(mock_agent, "OverrideAgent", ANY) + assert app.agent_http_endpoint_flags["OverrideAgent"] is True + + def test_agent_override_disables_http_route_when_app_enabled(self) -> None: + """Agent-level override should disable HTTP route even when app enables it.""" + + mock_agent = Mock() + mock_agent.name = "DisabledOverride" + + with ( + patch.object(AgentFunctionApp, "_setup_http_run_route") as http_route_mock, + patch.object(AgentFunctionApp, "_setup_agent_entity") as agent_entity_mock, + ): + app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=True) + app.add_agent(mock_agent, enable_http_endpoint=False) + + http_route_mock.assert_not_called() + agent_entity_mock.assert_called_once_with(mock_agent, "DisabledOverride", ANY) + assert app.agent_http_endpoint_flags["DisabledOverride"] is False + + def test_multiple_apps_independent(self) -> None: + """Test that multiple AgentFunctionApp instances are independent.""" + agent1 = Mock() + agent1.name = "Agent1" + agent2 = Mock() + agent2.name = "Agent2" + + app1 = AgentFunctionApp(agents=[agent1]) + app2 = AgentFunctionApp(agents=[agent2]) + + assert app1.agents["Agent1"].name == "Agent1" + assert app2.agents["Agent2"].name == "Agent2" + assert "Agent1" in app1.agents + assert "Agent2" in app2.agents + + +class TestWaitForCompletionAndCorrelationId: + """Tests for wait_for_completion flag and correlation ID handling.""" + + def _create_app(self) -> AgentFunctionApp: + mock_agent = Mock() + mock_agent.__class__.__name__ = "MockAgent" + mock_agent.name = "MockAgent" + return AgentFunctionApp(agents=[mock_agent], enable_health_check=False) + + def _make_request( + self, + headers: dict[str, str] | None = None, + params: dict[str, str] | None = None, + ) -> Mock: + request = Mock() + request.headers = headers or {} + request.params = params or {} + return request + + def test_wait_for_completion_header_true(self) -> None: + """Test that the wait-for-completion header is honored.""" + app = self._create_app() + request = self._make_request(headers={"X-Wait-For-Completion": "true"}) + + assert app._should_wait_for_completion(request, {}) is True + + def test_wait_for_completion_body_variants(self) -> None: + """Test that multiple payload spellings are accepted.""" + app = self._create_app() + request = self._make_request() + + assert app._should_wait_for_completion(request, {"wait_for_completion": "true"}) is True + assert app._should_wait_for_completion(request, {"waitForCompletion": "1"}) is True + assert app._should_wait_for_completion(request, {"WaitForCompletion": "no"}) is False + + +class TestAgentEntityOperations: + """Test suite for entity operations.""" + + async def test_entity_run_agent_operation(self) -> None: + """Test that entity can run agent operation.""" + mock_agent = Mock() + mock_agent.run = AsyncMock( + return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Test response")]) + ) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + result = await entity.run_agent( + mock_context, + {"message": "Test message", "conversation_id": "test-conv-123", "correlation_id": "corr-app-entity-1"}, + ) + + assert result["status"] == "success" + assert result["response"] == "Test response" + assert result["message"] == "Test message" + assert result["conversation_id"] == "test-conv-123" + assert entity.state.message_count == 1 + + async def test_entity_stores_conversation_history(self) -> None: + """Test that the entity stores conversation history.""" + mock_agent = Mock() + mock_agent.run = AsyncMock( + return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response 1")]) + ) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + # Send first message + await entity.run_agent( + mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-app-entity-2"} + ) + + history = entity.state.conversation_history + assert len(history) == 2 # User + assistant + + user_msg = history[0] + user_role = getattr(user_msg.role, "value", user_msg.role) + assert user_role == "user" + assert user_msg.text == "Message 1" + + assistant_msg = history[1] + assistant_role = getattr(assistant_msg.role, "value", assistant_msg.role) + assert assistant_role == "assistant" + assert assistant_msg.text == "Response 1" + + async def test_entity_increments_message_count(self) -> None: + """Test that the entity increments the message count.""" + mock_agent = Mock() + mock_agent.run = AsyncMock( + return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")]) + ) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + assert entity.state.message_count == 0 + + await entity.run_agent( + mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-app-entity-3a"} + ) + assert entity.state.message_count == 1 + + await entity.run_agent( + mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-app-entity-3b"} + ) + assert entity.state.message_count == 2 + + def test_entity_reset(self) -> None: + """Test that entity reset clears state.""" + mock_agent = Mock() + entity = AgentEntity(mock_agent) + + # Set some state + entity.state.message_count = 10 + entity.state.last_response = "Some response" + entity.state.conversation_history = [ + ChatMessage(role="user", text="test", additional_properties={"timestamp": "2024-01-01T00:00:00Z"}) + ] + + # Reset + mock_context = Mock() + entity.reset(mock_context) + + assert entity.state.message_count == 0 + assert entity.state.last_response is None + assert len(entity.state.conversation_history) == 0 + + +class TestAgentEntityFactory: + """Test suite for the entity factory function.""" + + def test_create_agent_entity_returns_function(self) -> None: + """Test that create_agent_entity returns a function.""" + mock_agent = Mock() + entity_function = create_agent_entity(mock_agent) + + assert callable(entity_function) + + def test_entity_function_handles_run_agent_operation(self) -> None: + """Test that the entity function handles the run_agent operation.""" + mock_agent = Mock() + mock_agent.run = AsyncMock( + return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")]) + ) + + entity_function = create_agent_entity(mock_agent) + + # Mock context + mock_context = Mock() + mock_context.operation_name = "run_agent" + mock_context.get_input.return_value = { + "message": "Test message", + "conversation_id": "conv-123", + "correlation_id": "corr-app-factory-1", + } + mock_context.get_state.return_value = None + + # Execute entity function + entity_function(mock_context) + + # Verify result was set + assert mock_context.set_result.called + assert mock_context.set_state.called + + def test_entity_function_handles_reset_operation(self) -> None: + """Test that the entity function handles the reset operation.""" + mock_agent = Mock() + entity_function = create_agent_entity(mock_agent) + + # Mock context + mock_context = Mock() + mock_context.operation_name = "reset" + mock_context.get_state.return_value = { + "message_count": 5, + "conversation_history": [{"role": "user", "content": "test"}], + "last_response": "Test", + } + + # Execute entity function + entity_function(mock_context) + + # Verify result was set + assert mock_context.set_result.called + result_call = mock_context.set_result.call_args[0][0] + assert result_call["status"] == "reset" + + def test_entity_function_handles_unknown_operation(self) -> None: + """Test that the entity function handles an unknown operation.""" + mock_agent = Mock() + entity_function = create_agent_entity(mock_agent) + + # Mock context with unknown operation + mock_context = Mock() + mock_context.operation_name = "unknown_operation" + mock_context.get_state.return_value = None + + # Execute entity function + entity_function(mock_context) + + # Verify error result was set + assert mock_context.set_result.called + result_call = mock_context.set_result.call_args[0][0] + assert "error" in result_call + assert "unknown_operation" in result_call["error"] + + def test_entity_function_restores_state(self) -> None: + """Test that the entity function restores state from the context.""" + mock_agent = Mock() + entity_function = create_agent_entity(mock_agent) + + # Mock context with existing state + existing_state = { + "message_count": 3, + "conversation_history": [{"role": "user", "content": "msg1"}, {"role": "assistant", "content": "resp1"}], + "last_response": "resp1", + } + + mock_context = Mock() + mock_context.operation_name = "reset" + mock_context.get_state.return_value = existing_state + + with patch.object(AgentState, "restore_state") as restore_state_mock: + entity_function(mock_context) + + restore_state_mock.assert_called_once_with(existing_state) + + +class TestErrorHandling: + """Test suite for error handling.""" + + async def test_entity_handles_agent_error(self) -> None: + """Test that the entity handles agent execution errors.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(side_effect=Exception("Agent error")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + result = await entity.run_agent( + mock_context, {"message": "Test message", "conversation_id": "conv-1", "correlation_id": "corr-app-error-1"} + ) + + assert result["status"] == "error" + assert "error" in result + assert "Agent error" in result["error"] + assert result["error_type"] == "Exception" + + def test_entity_function_handles_exception(self) -> None: + """Test that the entity function handles exceptions gracefully.""" + mock_agent = Mock() + # Force an exception by making get_input fail + mock_agent.run = AsyncMock(side_effect=Exception("Test error")) + + entity_function = create_agent_entity(mock_agent) + + mock_context = Mock() + mock_context.operation_name = "run_agent" + mock_context.get_input.side_effect = Exception("Input error") + mock_context.get_state.return_value = None + + # Execute entity function - should not raise + entity_function(mock_context) + + # Verify error result was set + assert mock_context.set_result.called + result_call = mock_context.set_result.call_args[0][0] + assert "error" in result_call + + +class TestIncomingRequestParsing: + """Tests for parsing run requests with JSON and plain text bodies.""" + + def _create_app(self) -> AgentFunctionApp: + mock_agent = Mock() + mock_agent.name = "ParserAgent" + return AgentFunctionApp(agents=[mock_agent], enable_health_check=False) + + def test_parse_plain_text_body(self) -> None: + """Test parsing a plain-text request body.""" + app = self._create_app() + + request = Mock() + request.get_json.side_effect = ValueError("Invalid JSON") + request.get_body.return_value = b"Plain text message" + + req_body, message = app._parse_incoming_request(request) + + assert req_body == {} + assert message == "Plain text message" + + def test_parse_plain_text_requires_content(self) -> None: + """Test that plain-text requests require message content.""" + app = self._create_app() + + request = Mock() + request.get_json.side_effect = ValueError("Invalid JSON") + request.get_body.return_value = b" " + + with pytest.raises(IncomingRequestError) as exc_info: + app._parse_incoming_request(request) + + assert "Message is required" in str(exc_info.value) + + def test_extract_session_key_from_query_params(self) -> None: + """Test session key extraction from query parameters.""" + app = self._create_app() + + request = Mock() + request.params = {"sessionId": "query-session"} + req_body = {} + + session_key = app._resolve_session_key(request, req_body) + + assert session_key == "query-session" + + +class TestHttpRunRoute: + """Tests for the HTTP run route behavior.""" + + async def test_http_run_accepts_plain_text(self) -> None: + """Test that the HTTP handler accepts plain-text requests.""" + mock_agent = Mock() + mock_agent.name = "HttpAgent" + + captured_handlers: dict[str | None, Callable[..., Awaitable[func.HttpResponse]]] = {} + + def capture_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]: + def decorator(func: TFunc) -> TFunc: + return func + + return decorator + + def capture_route(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]: + def decorator(func: TFunc) -> TFunc: + route_key = kwargs.get("route") if kwargs else None + captured_handlers[route_key] = func + return func + + return decorator + + with ( + patch.object(AgentFunctionApp, "function_name", new=capture_decorator), + patch.object(AgentFunctionApp, "route", new=capture_route), + patch.object(AgentFunctionApp, "durable_client_input", new=capture_decorator), + patch.object(AgentFunctionApp, "entity_trigger", new=capture_decorator), + ): + AgentFunctionApp(agents=[mock_agent], enable_health_check=False) + + run_route = f"agents/{mock_agent.name}/run" + handler = captured_handlers[run_route] + + request = Mock() + request.headers = {} + request.params = {} + request.route_params = {} + request.get_json.side_effect = ValueError("Invalid JSON") + request.get_body.return_value = b"Plain text via HTTP" + + client = AsyncMock() + + response = await handler(request, client) + + assert response.status_code == 202 + + signal_args = client.signal_entity.call_args[0] + run_request = signal_args[2] + + assert run_request["message"] == "Plain text via HTTP" + assert run_request["role"] == "user" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py new file mode 100644 index 0000000000..cc67842a2b --- /dev/null +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -0,0 +1,904 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for AgentEntity and entity operations. + +Run with: pytest tests/test_entities.py -v +""" + +import asyncio +from collections.abc import AsyncIterator, Callable +from datetime import datetime +from typing import Any, TypeVar +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage +from pydantic import BaseModel + +from agent_framework_azurefunctions._entities import AgentEntity, create_agent_entity +from agent_framework_azurefunctions._models import ChatRole, RunRequest +from agent_framework_azurefunctions._state import AgentState + +TFunc = TypeVar("TFunc", bound=Callable[..., Any]) + + +def _role_value(chat_message: ChatMessage) -> str: + """Helper to extract the string role from a ChatMessage.""" + role = getattr(chat_message, "role", None) + role_value = getattr(role, "value", role) + if role_value is None: + return "" + return str(role_value) + + +def _agent_response(text: str | None) -> AgentRunResponse: + """Create an AgentRunResponse with a single assistant message.""" + message = ( + ChatMessage(role="assistant", text=text) if text is not None else ChatMessage(role="assistant", contents=[]) + ) + return AgentRunResponse(messages=[message]) + + +class RecordingCallback: + """Callback implementation capturing streaming and final responses for assertions.""" + + def __init__(self): + self.stream_mock = AsyncMock() + self.response_mock = AsyncMock() + + async def on_streaming_response_update( + self, + update: AgentRunResponseUpdate, + context: Any, + ) -> None: + await self.stream_mock(update, context) + + async def on_agent_response(self, response: AgentRunResponse, context: Any) -> None: + await self.response_mock(response, context) + + +class EntityStructuredResponse(BaseModel): + answer: float + + +class TestAgentEntityInit: + """Test suite for AgentEntity initialization.""" + + def test_init_creates_entity(self) -> None: + """Test that AgentEntity initializes correctly.""" + mock_agent = Mock() + + entity = AgentEntity(mock_agent) + + assert entity.agent == mock_agent + assert entity.state.conversation_history == [] + assert entity.state.last_response is None + assert entity.state.message_count == 0 + + def test_init_stores_agent_reference(self) -> None: + """Test that the agent reference is stored correctly.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + entity = AgentEntity(mock_agent) + + assert entity.agent.name == "TestAgent" + + def test_init_with_different_agent_types(self) -> None: + """Test initialization with different agent types.""" + agent1 = Mock() + agent1.__class__.__name__ = "AzureOpenAIAgent" + + agent2 = Mock() + agent2.__class__.__name__ = "CustomAgent" + + entity1 = AgentEntity(agent1) + entity2 = AgentEntity(agent2) + + assert entity1.agent.__class__.__name__ == "AzureOpenAIAgent" + assert entity2.agent.__class__.__name__ == "CustomAgent" + + +class TestAgentEntityRunAgent: + """Test suite for the run_agent operation.""" + + async def test_run_agent_executes_agent(self) -> None: + """Test that run_agent executes the agent.""" + mock_agent = Mock() + mock_response = _agent_response("Test response") + mock_agent.run = AsyncMock(return_value=mock_response) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + result = await entity.run_agent( + mock_context, {"message": "Test message", "conversation_id": "conv-123", "correlation_id": "corr-entity-1"} + ) + + # Verify agent.run was called + mock_agent.run.assert_called_once() + _, kwargs = mock_agent.run.call_args + sent_messages = kwargs.get("messages") + assert isinstance(sent_messages, list) + assert len(sent_messages) == 1 + sent_message = sent_messages[0] + assert isinstance(sent_message, ChatMessage) + assert sent_message.text == "Test message" + assert _role_value(sent_message) == "user" + + # Verify result + assert result["status"] == "success" + assert result["response"] == "Test response" + assert result["message"] == "Test message" + assert result["conversation_id"] == "conv-123" + + async def test_run_agent_streaming_callbacks_invoked(self) -> None: + """Ensure streaming updates trigger callbacks and run() is not used.""" + + updates = [ + AgentRunResponseUpdate(text="Hello"), + AgentRunResponseUpdate(text=" world"), + ] + + async def update_generator() -> AsyncIterator[AgentRunResponseUpdate]: + for update in updates: + yield update + + mock_agent = Mock() + mock_agent.name = "StreamingAgent" + mock_agent.run_stream = Mock(return_value=update_generator()) + mock_agent.run = AsyncMock(side_effect=AssertionError("run() should not be called when streaming succeeds")) + + callback = RecordingCallback() + entity = AgentEntity(mock_agent, callback=callback) + mock_context = Mock() + + result = await entity.run_agent( + mock_context, + { + "message": "Tell me something", + "conversation_id": "session-1", + "correlation_id": "corr-stream-1", + }, + ) + + assert result["status"] == "success" + assert "Hello" in result.get("response", "") + assert callback.stream_mock.await_count == len(updates) + assert callback.response_mock.await_count == 1 + mock_agent.run.assert_not_called() + + # Validate callback arguments + stream_calls = callback.stream_mock.await_args_list + for expected_update, recorded_call in zip(updates, stream_calls, strict=True): + assert recorded_call.args[0] is expected_update + context = recorded_call.args[1] + assert context.agent_name == "StreamingAgent" + assert context.correlation_id == "corr-stream-1" + assert context.conversation_id == "session-1" + assert context.request_message == "Tell me something" + + final_call = callback.response_mock.await_args + assert final_call is not None + final_response, final_context = final_call.args + assert final_context.agent_name == "StreamingAgent" + assert final_context.correlation_id == "corr-stream-1" + assert final_context.conversation_id == "session-1" + assert final_context.request_message == "Tell me something" + assert getattr(final_response, "text", "").strip() + + async def test_run_agent_final_callback_without_streaming(self) -> None: + """Ensure the final callback fires even when streaming is unavailable.""" + + mock_agent = Mock() + mock_agent.name = "NonStreamingAgent" + mock_agent.run_stream = None + agent_response = _agent_response("Final response") + mock_agent.run = AsyncMock(return_value=agent_response) + + callback = RecordingCallback() + entity = AgentEntity(mock_agent, callback=callback) + mock_context = Mock() + + result = await entity.run_agent( + mock_context, + { + "message": "Hi", + "conversation_id": "session-2", + "correlation_id": "corr-final-1", + }, + ) + + assert result["status"] == "success" + assert result.get("response") == "Final response" + assert callback.stream_mock.await_count == 0 + assert callback.response_mock.await_count == 1 + + final_call = callback.response_mock.await_args + assert final_call is not None + assert final_call.args[0] is agent_response + final_context = final_call.args[1] + assert final_context.agent_name == "NonStreamingAgent" + assert final_context.correlation_id == "corr-final-1" + assert final_context.conversation_id == "session-2" + assert final_context.request_message == "Hi" + + async def test_run_agent_updates_conversation_history(self) -> None: + """Test that run_agent updates the conversation history.""" + mock_agent = Mock() + mock_response = _agent_response("Agent response") + mock_agent.run = AsyncMock(return_value=mock_response) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + await entity.run_agent( + mock_context, {"message": "User message", "conversation_id": "conv-1", "correlation_id": "corr-entity-2"} + ) + + # Should have 2 entries: user message + assistant response + history = entity.state.conversation_history + + assert len(history) == 2 + + user_msg = history[0] + assert _role_value(user_msg) == "user" + assert user_msg.text == "User message" + + assistant_msg = history[1] + assert _role_value(assistant_msg) == "assistant" + assert assistant_msg.text == "Agent response" + + async def test_run_agent_increments_message_count(self) -> None: + """Test that run_agent increments the message count.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + assert entity.state.message_count == 0 + + await entity.run_agent( + mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-3a"} + ) + assert entity.state.message_count == 1 + + await entity.run_agent( + mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-3b"} + ) + assert entity.state.message_count == 2 + + await entity.run_agent( + mock_context, {"message": "Message 3", "conversation_id": "conv-1", "correlation_id": "corr-entity-3c"} + ) + assert entity.state.message_count == 3 + + async def test_run_agent_stores_last_response(self) -> None: + """Test that run_agent stores the last response.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response 1")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + await entity.run_agent( + mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-4a"} + ) + assert entity.state.last_response == "Response 1" + + mock_agent.run = AsyncMock(return_value=_agent_response("Response 2")) + await entity.run_agent( + mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-4b"} + ) + assert entity.state.last_response == "Response 2" + + async def test_run_agent_with_none_conversation_id(self) -> None: + """Test run_agent with a None conversation identifier.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + with pytest.raises(ValueError, match="conversation_id"): + await entity.run_agent( + mock_context, {"message": "Message", "conversation_id": None, "correlation_id": "corr-entity-5"} + ) + + async def test_run_agent_handles_response_without_text_attribute(self) -> None: + """Test that run_agent handles responses without a text attribute.""" + mock_agent = Mock() + + class NoTextResponse(AgentRunResponse): + @property + def text(self) -> str: # type: ignore[override] + raise AttributeError("text attribute missing") + + mock_response = NoTextResponse(messages=[ChatMessage(role="assistant", text="ignored")]) + mock_agent.run = AsyncMock(return_value=mock_response) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + result = await entity.run_agent( + mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-6"} + ) + + # Should handle gracefully + assert result["status"] == "success" + assert result["response"] == "Error extracting response" + + async def test_run_agent_handles_none_response_text(self) -> None: + """Test that run_agent handles responses with None text.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response(None)) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + result = await entity.run_agent( + mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-7"} + ) + + assert result["status"] == "success" + assert result["response"] == "No response" + + async def test_run_agent_multiple_conversations(self) -> None: + """Test that run_agent maintains history across multiple messages.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + # Send multiple messages + await entity.run_agent( + mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-8a"} + ) + await entity.run_agent( + mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-8b"} + ) + await entity.run_agent( + mock_context, {"message": "Message 3", "conversation_id": "conv-1", "correlation_id": "corr-entity-8c"} + ) + + history = entity.state.conversation_history + assert len(history) == 6 + assert entity.state.message_count == 3 + + +class TestAgentEntityReset: + """Test suite for the reset operation.""" + + def test_reset_clears_conversation_history(self) -> None: + """Test that reset clears the conversation history.""" + mock_agent = Mock() + entity = AgentEntity(mock_agent) + + # Add some history + entity.state.conversation_history = [ + ChatMessage(role="user", text="msg1"), + ChatMessage(role="assistant", text="resp1"), + ] + + mock_context = Mock() + entity.reset(mock_context) + + assert entity.state.conversation_history == [] + + def test_reset_clears_last_response(self) -> None: + """Test that reset clears the last response.""" + mock_agent = Mock() + entity = AgentEntity(mock_agent) + + entity.state.last_response = "Some response" + + mock_context = Mock() + entity.reset(mock_context) + + assert entity.state.last_response is None + + def test_reset_clears_message_count(self) -> None: + """Test that reset clears the message count.""" + mock_agent = Mock() + entity = AgentEntity(mock_agent) + + entity.state.message_count = 10 + + mock_context = Mock() + entity.reset(mock_context) + + assert entity.state.message_count == 0 + + async def test_reset_after_conversation(self) -> None: + """Test reset after a full conversation.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + # Have a conversation + await entity.run_agent( + mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-10a"} + ) + await entity.run_agent( + mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-10b"} + ) + + # Verify state before reset + assert entity.state.message_count == 2 + assert len(entity.state.conversation_history) == 4 + + # Reset + entity.reset(mock_context) + + # Verify state after reset + assert entity.state.message_count == 0 + assert len(entity.state.conversation_history) == 0 + assert entity.state.last_response is None + + +class TestCreateAgentEntity: + """Test suite for the create_agent_entity factory function.""" + + def test_create_agent_entity_returns_callable(self) -> None: + """Test that create_agent_entity returns a callable.""" + mock_agent = Mock() + + entity_function = create_agent_entity(mock_agent) + + assert callable(entity_function) + + def test_entity_function_handles_run_agent(self) -> None: + """Test that the entity function handles the run_agent operation.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity_function = create_agent_entity(mock_agent) + + # Mock context + mock_context = Mock() + mock_context.operation_name = "run_agent" + mock_context.get_input.return_value = { + "message": "Test message", + "conversation_id": "conv-123", + "correlation_id": "corr-entity-factory", + } + mock_context.get_state.return_value = None + + # Execute + entity_function(mock_context) + + # Verify result and state were set + assert mock_context.set_result.called + assert mock_context.set_state.called + + def test_entity_function_handles_reset(self) -> None: + """Test that the entity function handles the reset operation.""" + mock_agent = Mock() + + entity_function = create_agent_entity(mock_agent) + + # Mock context with existing state + mock_context = Mock() + mock_context.operation_name = "reset" + mock_context.get_state.return_value = { + "message_count": 5, + "conversation_history": [ + ChatMessage( + role="user", text="test", additional_properties={"timestamp": "2024-01-01T00:00:00Z"} + ).to_dict() + ], + "last_response": "Test", + } + + # Execute + entity_function(mock_context) + + # Verify reset result + assert mock_context.set_result.called + result = mock_context.set_result.call_args[0][0] + assert result["status"] == "reset" + + # Verify state was cleared + assert mock_context.set_state.called + state = mock_context.set_state.call_args[0][0] + assert state["message_count"] == 0 + assert state["conversation_history"] == [] + assert state["last_response"] is None + + def test_entity_function_handles_unknown_operation(self) -> None: + """Test that the entity function handles unknown operations.""" + mock_agent = Mock() + + entity_function = create_agent_entity(mock_agent) + + mock_context = Mock() + mock_context.operation_name = "invalid_operation" + mock_context.get_state.return_value = None + + # Execute + entity_function(mock_context) + + # Verify error result + assert mock_context.set_result.called + result = mock_context.set_result.call_args[0][0] + assert "error" in result + assert "invalid_operation" in result["error"].lower() + + def test_entity_function_creates_new_entity_on_first_call(self) -> None: + """Test that the entity function creates a new entity when no state exists.""" + mock_agent = Mock() + mock_agent.__class__.__name__ = "Agent" + + entity_function = create_agent_entity(mock_agent) + mock_context = Mock() + mock_context.operation_name = "reset" + mock_context.get_state.return_value = None # No existing state + + # Execute + entity_function(mock_context) + + # Verify new entity state was created + assert mock_context.set_result.called + result = mock_context.set_result.call_args[0][0] + assert result["status"] == "reset" + assert mock_context.set_state.called + state = mock_context.set_state.call_args[0][0] + assert state["message_count"] == 0 + assert state["conversation_history"] == [] + + def test_entity_function_restores_existing_state(self) -> None: + """Test that the entity function restores existing state.""" + mock_agent = Mock() + + entity_function = create_agent_entity(mock_agent) + + existing_state = { + "message_count": 5, + "conversation_history": [ + ChatMessage( + role="user", text="msg1", additional_properties={"timestamp": "2024-01-01T00:00:00Z"} + ).to_dict(), + ChatMessage( + role="assistant", text="resp1", additional_properties={"timestamp": "2024-01-01T00:05:00Z"} + ).to_dict(), + ], + "last_response": "resp1", + } + + mock_context = Mock() + mock_context.operation_name = "reset" + mock_context.get_state.return_value = existing_state + + with patch.object(AgentState, "restore_state") as restore_state_mock: + entity_function(mock_context) + + restore_state_mock.assert_called_once_with(existing_state) + + +class TestErrorHandling: + """Test suite for error handling in entities.""" + + async def test_run_agent_handles_agent_exception(self) -> None: + """Test that run_agent handles agent exceptions.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(side_effect=Exception("Agent failed")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + result = await entity.run_agent( + mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-error-1"} + ) + + assert result["status"] == "error" + assert "error" in result + assert "Agent failed" in result["error"] + assert result["error_type"] == "Exception" + + async def test_run_agent_handles_value_error(self) -> None: + """Test that run_agent handles ValueError instances.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(side_effect=ValueError("Invalid input")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + result = await entity.run_agent( + mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-error-2"} + ) + + assert result["status"] == "error" + assert result["error_type"] == "ValueError" + assert "Invalid input" in result["error"] + + async def test_run_agent_handles_timeout_error(self) -> None: + """Test that run_agent handles TimeoutError instances.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(side_effect=TimeoutError("Request timeout")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + result = await entity.run_agent( + mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-error-3"} + ) + + assert result["status"] == "error" + assert result["error_type"] == "TimeoutError" + + def test_entity_function_handles_exception_in_operation(self) -> None: + """Test that the entity function handles exceptions gracefully.""" + mock_agent = Mock() + + entity_function = create_agent_entity(mock_agent) + + mock_context = Mock() + mock_context.operation_name = "run_agent" + mock_context.get_input.side_effect = Exception("Input error") + mock_context.get_state.return_value = None + + # Execute - should not raise + entity_function(mock_context) + + # Verify error was set + assert mock_context.set_result.called + result = mock_context.set_result.call_args[0][0] + assert "error" in result + + async def test_run_agent_preserves_message_on_error(self) -> None: + """Test that run_agent preserves message information on error.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(side_effect=Exception("Error")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + result = await entity.run_agent( + mock_context, + {"message": "Test message", "conversation_id": "conv-123", "correlation_id": "corr-entity-error-4"}, + ) + + # Even on error, message info should be preserved + assert result["message"] == "Test message" + assert result["conversation_id"] == "conv-123" + assert result["status"] == "error" + + +class TestConversationHistory: + """Test suite for conversation history tracking.""" + + async def test_conversation_history_has_timestamps(self) -> None: + """Test that conversation history entries include timestamps.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + await entity.run_agent( + mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-1"} + ) + + # Check both user and assistant messages have timestamps + for entry in entity.state.conversation_history: + timestamp = entry.additional_properties.get("timestamp") + assert timestamp is not None + # Verify timestamp is in ISO format + datetime.fromisoformat(timestamp) + + async def test_conversation_history_ordering(self) -> None: + """Test that conversation history maintains the correct order.""" + mock_agent = Mock() + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + # Send multiple messages with different responses + mock_agent.run = AsyncMock(return_value=_agent_response("Response 1")) + await entity.run_agent( + mock_context, + {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-2a"}, + ) + + mock_agent.run = AsyncMock(return_value=_agent_response("Response 2")) + await entity.run_agent( + mock_context, + {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-2b"}, + ) + + mock_agent.run = AsyncMock(return_value=_agent_response("Response 3")) + await entity.run_agent( + mock_context, + {"message": "Message 3", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-2c"}, + ) + + # Verify order + history = entity.state.conversation_history + assert history[0].text == "Message 1" + assert history[1].text == "Response 1" + assert history[2].text == "Message 2" + assert history[3].text == "Response 2" + assert history[4].text == "Message 3" + assert history[5].text == "Response 3" + + async def test_conversation_history_role_alternation(self) -> None: + """Test that conversation history alternates between user and assistant roles.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + await entity.run_agent( + mock_context, + {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-3a"}, + ) + await entity.run_agent( + mock_context, + {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-3b"}, + ) + + # Check role alternation + history = entity.state.conversation_history + assert _role_value(history[0]) == "user" + assert _role_value(history[1]) == "assistant" + assert _role_value(history[2]) == "user" + assert _role_value(history[3]) == "assistant" + + +class TestRunRequestSupport: + """Test suite for RunRequest support in entities.""" + + async def test_run_agent_with_run_request_object(self) -> None: + """Test run_agent with a RunRequest object.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + request = RunRequest( + message="Test message", + conversation_id="conv-123", + role=ChatRole.USER, + enable_tool_calls=True, + correlation_id="corr-runreq-1", + ) + + result = await entity.run_agent(mock_context, request) + + assert result["status"] == "success" + assert result["response"] == "Response" + assert result["message"] == "Test message" + assert result["conversation_id"] == "conv-123" + + async def test_run_agent_with_dict_request(self) -> None: + """Test run_agent with a dictionary request.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + request_dict = { + "message": "Test message", + "conversation_id": "conv-456", + "role": "system", + "enable_tool_calls": False, + "correlation_id": "corr-runreq-2", + } + + result = await entity.run_agent(mock_context, request_dict) + + assert result["status"] == "success" + assert result["message"] == "Test message" + assert result["conversation_id"] == "conv-456" + + async def test_run_agent_with_string_raises_without_correlation(self) -> None: + """Test that run_agent rejects legacy string input without correlation ID.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + with pytest.raises(ValueError): + await entity.run_agent(mock_context, "Simple message") + + async def test_run_agent_stores_role_in_history(self) -> None: + """Test that run_agent stores the role in conversation history.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + # Send as system role + request = RunRequest( + message="System message", + conversation_id="conv-runreq-3", + role=ChatRole.SYSTEM, + correlation_id="corr-runreq-3", + ) + + await entity.run_agent(mock_context, request) + + # Check that system role was stored + history = entity.state.conversation_history + assert _role_value(history[0]) == "system" + assert history[0].text == "System message" + + async def test_run_agent_with_response_format(self) -> None: + """Test run_agent with a JSON response format.""" + mock_agent = Mock() + # Return JSON response + mock_agent.run = AsyncMock(return_value=_agent_response('{"answer": 42}')) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + request = RunRequest( + message="What is the answer?", + conversation_id="conv-runreq-4", + response_format=EntityStructuredResponse, + correlation_id="corr-runreq-4", + ) + + result = await entity.run_agent(mock_context, request) + + assert result["status"] == "success" + # Should have structured_response + if "structured_response" in result: + assert result["structured_response"]["answer"] == 42 + + async def test_run_agent_disable_tool_calls(self) -> None: + """Test run_agent with tool calls disabled.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity = AgentEntity(mock_agent) + mock_context = Mock() + + request = RunRequest( + message="Test", conversation_id="conv-runreq-5", enable_tool_calls=False, correlation_id="corr-runreq-5" + ) + + result = await entity.run_agent(mock_context, request) + + assert result["status"] == "success" + # Agent should have been called (tool disabling is framework-dependent) + mock_agent.run.assert_called_once() + + async def test_entity_function_with_run_request_dict(self) -> None: + """Test that the entity function handles the RunRequest dict format.""" + mock_agent = Mock() + mock_agent.run = AsyncMock(return_value=_agent_response("Response")) + + entity_function = create_agent_entity(mock_agent) + + mock_context = Mock() + mock_context.operation_name = "run_agent" + mock_context.get_input.return_value = { + "message": "Test message", + "conversation_id": "conv-789", + "role": "user", + "enable_tool_calls": True, + "correlation_id": "corr-runreq-6", + } + mock_context.get_state.return_value = None + + await asyncio.to_thread(entity_function, mock_context) + + # Verify result was set + assert mock_context.set_result.called + result = mock_context.set_result.call_args[0][0] + assert result["status"] == "success" + assert result["message"] == "Test message" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_models.py b/python/packages/azurefunctions/tests/test_models.py new file mode 100644 index 0000000000..916bdb358a --- /dev/null +++ b/python/packages/azurefunctions/tests/test_models.py @@ -0,0 +1,478 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for data models (AgentSessionId, RunRequest, AgentResponse, ChatRole).""" + +import azure.durable_functions as df +import pytest +from pydantic import BaseModel + +from agent_framework_azurefunctions._models import AgentResponse, AgentSessionId, ChatRole, RunRequest + + +class ModuleStructuredResponse(BaseModel): + value: int + + +class TestChatRole: + """Test suite for ChatRole enum.""" + + def test_chat_role_values(self) -> None: + """Test that ChatRole has correct values.""" + assert ChatRole.USER == "user" + assert ChatRole.SYSTEM == "system" + assert ChatRole.ASSISTANT == "assistant" + + def test_chat_role_is_string(self) -> None: + """Test that ChatRole values are strings.""" + assert isinstance(ChatRole.USER.value, str) + assert isinstance(ChatRole.SYSTEM.value, str) + assert isinstance(ChatRole.ASSISTANT.value, str) + + +class TestAgentSessionId: + """Test suite for AgentSessionId.""" + + def test_init_creates_session_id(self) -> None: + """Test that AgentSessionId initializes correctly.""" + session_id = AgentSessionId(name="AgentEntity", key="test-key-123") + + assert session_id.name == "AgentEntity" + assert session_id.key == "test-key-123" + + def test_with_random_key_generates_guid(self) -> None: + """Test that with_random_key generates a GUID.""" + session_id = AgentSessionId.with_random_key(name="AgentEntity") + + assert session_id.name == "AgentEntity" + assert len(session_id.key) == 32 # UUID hex is 32 chars + # Verify it's a valid hex string + int(session_id.key, 16) + + def test_with_random_key_unique_keys(self) -> None: + """Test that with_random_key generates unique keys.""" + session_id1 = AgentSessionId.with_random_key(name="AgentEntity") + session_id2 = AgentSessionId.with_random_key(name="AgentEntity") + + assert session_id1.key != session_id2.key + + def test_to_entity_id_conversion(self) -> None: + """Test conversion to EntityId.""" + session_id = AgentSessionId(name="AgentEntity", key="test-key") + entity_id = session_id.to_entity_id() + + assert isinstance(entity_id, df.EntityId) + assert entity_id.name == "dafx-AgentEntity" + assert entity_id.key == "test-key" + + def test_from_entity_id_conversion(self) -> None: + """Test creation from EntityId.""" + entity_id = df.EntityId(name="dafx-AgentEntity", key="test-key") + session_id = AgentSessionId.from_entity_id(entity_id) + + assert isinstance(session_id, AgentSessionId) + assert session_id.name == "AgentEntity" + assert session_id.key == "test-key" + + def test_round_trip_entity_id_conversion(self) -> None: + """Test round-trip conversion to and from EntityId.""" + original = AgentSessionId(name="AgentEntity", key="test-key") + entity_id = original.to_entity_id() + restored = AgentSessionId.from_entity_id(entity_id) + + assert restored.name == original.name + assert restored.key == original.key + + def test_str_representation(self) -> None: + """Test string representation.""" + session_id = AgentSessionId(name="AgentEntity", key="test-key-123") + str_repr = str(session_id) + + assert str_repr == "@AgentEntity@test-key-123" + + def test_repr_representation(self) -> None: + """Test repr representation.""" + session_id = AgentSessionId(name="AgentEntity", key="test-key") + repr_str = repr(session_id) + + assert "AgentSessionId" in repr_str + assert "AgentEntity" in repr_str + assert "test-key" in repr_str + + def test_parse_valid_session_id(self) -> None: + """Test parsing valid session ID string.""" + session_id = AgentSessionId.parse("@AgentEntity@test-key-123") + + assert session_id.name == "AgentEntity" + assert session_id.key == "test-key-123" + + def test_parse_invalid_format_no_prefix(self) -> None: + """Test parsing invalid format without @ prefix.""" + with pytest.raises(ValueError) as exc_info: + AgentSessionId.parse("AgentEntity@test-key") + + assert "Invalid agent session ID format" in str(exc_info.value) + + def test_parse_invalid_format_single_part(self) -> None: + """Test parsing invalid format with single part.""" + with pytest.raises(ValueError) as exc_info: + AgentSessionId.parse("@AgentEntity") + + assert "Invalid agent session ID format" in str(exc_info.value) + + def test_parse_with_multiple_at_signs_in_key(self) -> None: + """Test parsing with @ signs in the key.""" + session_id = AgentSessionId.parse("@AgentEntity@key-with@symbols") + + assert session_id.name == "AgentEntity" + assert session_id.key == "key-with@symbols" + + def test_parse_round_trip(self) -> None: + """Test round-trip parse and string conversion.""" + original = AgentSessionId(name="AgentEntity", key="test-key") + str_repr = str(original) + parsed = AgentSessionId.parse(str_repr) + + assert parsed.name == original.name + assert parsed.key == original.key + + def test_to_entity_name_adds_prefix(self) -> None: + """Test that to_entity_name adds the dafx- prefix.""" + entity_name = AgentSessionId.to_entity_name("TestAgent") + assert entity_name == "dafx-TestAgent" + + def test_from_entity_id_strips_prefix(self) -> None: + """Test that from_entity_id strips the dafx- prefix.""" + entity_id = df.EntityId(name="dafx-TestAgent", key="key123") + session_id = AgentSessionId.from_entity_id(entity_id) + + assert session_id.name == "TestAgent" + assert session_id.key == "key123" + + def test_from_entity_id_raises_without_prefix(self) -> None: + """Test that from_entity_id raises ValueError when entity name lacks the prefix.""" + entity_id = df.EntityId(name="TestAgent", key="key123") + + with pytest.raises(ValueError) as exc_info: + AgentSessionId.from_entity_id(entity_id) + + assert "not a valid agent session ID" in str(exc_info.value) + assert "dafx-" in str(exc_info.value) + + +class TestRunRequest: + """Test suite for RunRequest.""" + + def test_init_with_defaults(self) -> None: + """Test RunRequest initialization with defaults.""" + request = RunRequest(message="Hello", conversation_id="conv-default") + + assert request.message == "Hello" + assert request.role == ChatRole.USER + assert request.response_format is None + assert request.enable_tool_calls is True + assert request.conversation_id == "conv-default" + + def test_init_with_all_fields(self) -> None: + """Test RunRequest initialization with all fields.""" + schema = ModuleStructuredResponse + request = RunRequest( + message="Hello", + conversation_id="conv-123", + role=ChatRole.SYSTEM, + response_format=schema, + enable_tool_calls=False, + ) + + assert request.message == "Hello" + assert request.role == ChatRole.SYSTEM + assert request.response_format is schema + assert request.enable_tool_calls is False + assert request.conversation_id == "conv-123" + + def test_to_dict_with_defaults(self) -> None: + """Test to_dict with default values.""" + request = RunRequest(message="Test message", conversation_id="conv-to-dict") + data = request.to_dict() + + assert data["message"] == "Test message" + assert data["enable_tool_calls"] is True + assert data["role"] == "user" + assert "response_format" not in data or data["response_format"] is None + assert data["conversation_id"] == "conv-to-dict" + + def test_to_dict_with_all_fields(self) -> None: + """Test to_dict with all fields.""" + schema = ModuleStructuredResponse + request = RunRequest( + message="Hello", + conversation_id="conv-456", + role=ChatRole.ASSISTANT, + response_format=schema, + enable_tool_calls=False, + ) + data = request.to_dict() + + assert data["message"] == "Hello" + assert data["role"] == "assistant" + assert data["response_format"]["__response_schema_type__"] == "pydantic_model" + assert data["response_format"]["module"] == schema.__module__ + assert data["response_format"]["qualname"] == schema.__qualname__ + assert data["enable_tool_calls"] is False + assert data["conversation_id"] == "conv-456" + + def test_from_dict_with_defaults(self) -> None: + """Test from_dict with minimal data.""" + data = {"message": "Hello", "conversation_id": "conv-from-dict"} + request = RunRequest.from_dict(data) + + assert request.message == "Hello" + assert request.role == ChatRole.USER + assert request.enable_tool_calls is True + assert request.conversation_id == "conv-from-dict" + + def test_from_dict_with_all_fields(self) -> None: + """Test from_dict with all fields.""" + data = { + "message": "Test", + "role": "system", + "response_format": { + "__response_schema_type__": "pydantic_model", + "module": ModuleStructuredResponse.__module__, + "qualname": ModuleStructuredResponse.__qualname__, + }, + "enable_tool_calls": False, + "conversation_id": "conv-789", + } + request = RunRequest.from_dict(data) + + assert request.message == "Test" + assert request.role == ChatRole.SYSTEM + assert request.response_format is ModuleStructuredResponse + assert request.enable_tool_calls is False + assert request.conversation_id == "conv-789" + + def test_from_dict_invalid_role_defaults_to_user(self) -> None: + """Test from_dict with invalid role defaults to USER.""" + data = {"message": "Test", "role": "invalid_role", "conversation_id": "conv-invalid-role"} + request = RunRequest.from_dict(data) + + assert request.role == ChatRole.USER + + def test_from_dict_empty_message(self) -> None: + """Test from_dict with empty message.""" + data = {"conversation_id": "conv-empty"} + request = RunRequest.from_dict(data) + + assert request.message == "" + assert request.role == ChatRole.USER + assert request.conversation_id == "conv-empty" + + def test_round_trip_dict_conversion(self) -> None: + """Test round-trip to_dict and from_dict.""" + original = RunRequest( + message="Test message", + conversation_id="conv-123", + role=ChatRole.SYSTEM, + response_format=ModuleStructuredResponse, + enable_tool_calls=False, + ) + + data = original.to_dict() + restored = RunRequest.from_dict(data) + + assert restored.message == original.message + assert restored.role == original.role + assert restored.response_format is ModuleStructuredResponse + assert restored.enable_tool_calls == original.enable_tool_calls + assert restored.conversation_id == original.conversation_id + + def test_round_trip_with_pydantic_response_format(self) -> None: + """Ensure Pydantic response formats serialize and deserialize properly.""" + original = RunRequest( + message="Structured", + conversation_id="conv-pydantic", + response_format=ModuleStructuredResponse, + ) + + data = original.to_dict() + + assert data["response_format"]["__response_schema_type__"] == "pydantic_model" + assert data["response_format"]["module"] == ModuleStructuredResponse.__module__ + assert data["response_format"]["qualname"] == ModuleStructuredResponse.__qualname__ + + restored = RunRequest.from_dict(data) + assert restored.response_format is ModuleStructuredResponse + + def test_init_with_correlation_id(self) -> None: + """Test RunRequest initialization with correlation_id.""" + request = RunRequest(message="Test message", conversation_id="conv-corr-init", correlation_id="corr-123") + + assert request.message == "Test message" + assert request.correlation_id == "corr-123" + + def test_to_dict_with_correlation_id(self) -> None: + """Test to_dict includes correlation_id.""" + request = RunRequest(message="Test", conversation_id="conv-corr-to-dict", correlation_id="corr-456") + data = request.to_dict() + + assert data["message"] == "Test" + assert data["correlation_id"] == "corr-456" + + def test_from_dict_with_correlation_id(self) -> None: + """Test from_dict with correlation_id.""" + data = {"message": "Test", "correlation_id": "corr-789", "conversation_id": "conv-corr-from-dict"} + request = RunRequest.from_dict(data) + + assert request.message == "Test" + assert request.correlation_id == "corr-789" + assert request.conversation_id == "conv-corr-from-dict" + + def test_round_trip_with_correlation_id(self) -> None: + """Test round-trip to_dict and from_dict with correlation_id.""" + original = RunRequest( + message="Test message", + conversation_id="conv-123", + role=ChatRole.SYSTEM, + correlation_id="corr-123", + ) + + data = original.to_dict() + restored = RunRequest.from_dict(data) + + assert restored.message == original.message + assert restored.role == original.role + assert restored.correlation_id == original.correlation_id + assert restored.conversation_id == original.conversation_id + + +class TestAgentResponse: + """Test suite for AgentResponse.""" + + def test_init_with_required_fields(self) -> None: + """Test AgentResponse initialization with required fields.""" + response = AgentResponse( + response="Test response", message="Test message", conversation_id="conv-123", status="success" + ) + + assert response.response == "Test response" + assert response.message == "Test message" + assert response.conversation_id == "conv-123" + assert response.status == "success" + assert response.message_count == 0 + assert response.error is None + assert response.error_type is None + assert response.structured_response is None + + def test_init_with_all_fields(self) -> None: + """Test AgentResponse initialization with all fields.""" + structured = {"answer": "42"} + response = AgentResponse( + response=None, + message="What is the answer?", + conversation_id="conv-456", + status="success", + message_count=5, + error=None, + error_type=None, + structured_response=structured, + ) + + assert response.response is None + assert response.structured_response == structured + assert response.message_count == 5 + + def test_to_dict_with_text_response(self) -> None: + """Test to_dict with text response.""" + response = AgentResponse( + response="Text response", message="Message", conversation_id="conv-1", status="success", message_count=3 + ) + data = response.to_dict() + + assert data["response"] == "Text response" + assert data["message"] == "Message" + assert data["conversation_id"] == "conv-1" + assert data["status"] == "success" + assert data["message_count"] == 3 + assert "structured_response" not in data + assert "error" not in data + assert "error_type" not in data + + def test_to_dict_with_structured_response(self) -> None: + """Test to_dict with structured response.""" + structured = {"answer": 42, "confidence": 0.95} + response = AgentResponse( + response=None, + message="Question", + conversation_id="conv-2", + status="success", + structured_response=structured, + ) + data = response.to_dict() + + assert data["structured_response"] == structured + assert "response" not in data + + def test_to_dict_with_error(self) -> None: + """Test to_dict with error.""" + response = AgentResponse( + response=None, + message="Failed message", + conversation_id="conv-3", + status="error", + error="Something went wrong", + error_type="ValueError", + ) + data = response.to_dict() + + assert data["status"] == "error" + assert data["error"] == "Something went wrong" + assert data["error_type"] == "ValueError" + + def test_to_dict_prefers_structured_over_text(self) -> None: + """Test to_dict prefers structured_response over response.""" + structured = {"result": "structured"} + response = AgentResponse( + response="Text response", + message="Message", + conversation_id="conv-4", + status="success", + structured_response=structured, + ) + data = response.to_dict() + + assert "structured_response" in data + assert data["structured_response"] == structured + # Text response should not be included when structured is present + assert "response" not in data + + +class TestModelIntegration: + """Test suite for integration between models.""" + + def test_run_request_with_session_id(self) -> None: + """Test using RunRequest with AgentSessionId.""" + session_id = AgentSessionId.with_random_key("AgentEntity") + request = RunRequest(message="Test message", conversation_id=str(session_id)) + + assert request.conversation_id is not None + assert request.conversation_id == str(session_id) + assert request.conversation_id.startswith("@AgentEntity@") + + def test_response_from_run_request(self) -> None: + """Test creating AgentResponse from RunRequest.""" + request = RunRequest(message="What is 2+2?", conversation_id="conv-123", role=ChatRole.USER) + + response = AgentResponse( + response="4", + message=request.message, + conversation_id=request.conversation_id, + status="success", + message_count=1, + ) + + assert response.message == request.message + assert response.conversation_id == request.conversation_id + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_multi_agent.py b/python/packages/azurefunctions/tests/test_multi_agent.py new file mode 100644 index 0000000000..0c0be7f35d --- /dev/null +++ b/python/packages/azurefunctions/tests/test_multi_agent.py @@ -0,0 +1,150 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for multi-agent support in AgentFunctionApp.""" + +from unittest.mock import Mock + +import pytest + +from agent_framework_azurefunctions import AgentFunctionApp + + +class TestMultiAgentInit: + """Test suite for multi-agent initialization.""" + + def test_init_with_agents_list(self) -> None: + """Test initialization with list of agents.""" + agent1 = Mock() + agent1.name = "Agent1" + agent2 = Mock() + agent2.name = "Agent2" + + app = AgentFunctionApp(agents=[agent1, agent2]) + + assert len(app.agents) == 2 + assert "Agent1" in app.agents + assert "Agent2" in app.agents + assert app.agents["Agent1"] == agent1 + assert app.agents["Agent2"] == agent2 + + def test_init_with_empty_agents_list(self) -> None: + """Test initialization with empty list of agents.""" + app = AgentFunctionApp(agents=[]) + + assert len(app.agents) == 0 + + def test_init_with_no_agents(self) -> None: + """Test initialization without any agents.""" + app = AgentFunctionApp() + + assert len(app.agents) == 0 + + def test_init_with_duplicate_agent_names(self) -> None: + """Test initialization with agents having the same name raises error.""" + agent1 = Mock() + agent1.name = "TestAgent" + agent2 = Mock() + agent2.name = "TestAgent" + + with pytest.raises(ValueError, match="already registered"): + AgentFunctionApp(agents=[agent1, agent2]) + + def test_init_with_agent_without_name(self) -> None: + """Test initialization with agent missing name attribute raises error.""" + agent1 = Mock() + agent1.name = "Agent1" + agent2 = Mock(spec=[]) # Mock without name attribute + + with pytest.raises(ValueError, match="does not have a 'name' attribute"): + AgentFunctionApp(agents=[agent1, agent2]) + + +class TestAddAgentMethod: + """Test suite for add_agent() method.""" + + def test_add_agent_to_empty_app(self) -> None: + """Test adding agent to app initialized without agents.""" + app = AgentFunctionApp() + + agent = Mock() + agent.name = "NewAgent" + + app.add_agent(agent) + + assert len(app.agents) == 1 + assert "NewAgent" in app.agents + assert app.agents["NewAgent"] == agent + + def test_add_multiple_agents(self) -> None: + """Test adding multiple agents sequentially.""" + app = AgentFunctionApp() + + agent1 = Mock() + agent1.name = "Agent1" + agent2 = Mock() + agent2.name = "Agent2" + + app.add_agent(agent1) + app.add_agent(agent2) + + assert len(app.agents) == 2 + assert "Agent1" in app.agents + assert "Agent2" in app.agents + + def test_add_agent_with_duplicate_name_raises_error(self) -> None: + """Test that adding agent with duplicate name raises ValueError.""" + agent1 = Mock() + agent1.name = "MyAgent" + agent2 = Mock() + agent2.name = "MyAgent" + + app = AgentFunctionApp(agents=[agent1]) + + # Try to add another agent with the same name + with pytest.raises(ValueError, match="already registered"): + app.add_agent(agent2) + + def test_add_agent_to_app_with_existing_agents(self) -> None: + """Test adding agent to app that already has agents.""" + agent1 = Mock() + agent1.name = "Agent1" + agent2 = Mock() + agent2.name = "Agent2" + + app = AgentFunctionApp(agents=[agent1]) + app.add_agent(agent2) + + assert len(app.agents) == 2 + assert "Agent1" in app.agents + assert "Agent2" in app.agents + + def test_add_agent_without_name_raises_error(self) -> None: + """Test that adding agent without name attribute raises error.""" + app = AgentFunctionApp() + + agent = Mock(spec=[]) # Mock without name attribute + + with pytest.raises(ValueError, match="does not have a 'name' attribute"): + app.add_agent(agent) + + +class TestHealthCheckWithMultipleAgents: + """Test suite for health check with multiple agents.""" + + def test_health_check_returns_all_agents(self) -> None: + """Test that health check returns information about all agents.""" + agent1 = Mock() + agent1.name = "Agent1" + agent2 = Mock() + agent2.name = "Agent2" + + app = AgentFunctionApp(agents=[agent1, agent2]) + + # Note: We can't easily test the actual health check endpoint without running the app + # But we can verify the agents dictionary is properly populated + assert len(app.agents) == 2 + assert app.enable_health_check is True + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py new file mode 100644 index 0000000000..cebea00220 --- /dev/null +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -0,0 +1,425 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for orchestration support (DurableAIAgent).""" + +from typing import Any +from unittest.mock import Mock + +import pytest +from agent_framework import AgentThread + +from agent_framework_azurefunctions import DurableAIAgent, get_agent +from agent_framework_azurefunctions._models import AgentSessionId, DurableAgentThread + + +class TestDurableAIAgent: + """Test suite for DurableAIAgent wrapper.""" + + def test_init(self) -> None: + """Test DurableAIAgent initialization.""" + mock_context = Mock() + mock_context.instance_id = "test-instance-123" + + agent = DurableAIAgent(mock_context, "TestAgent") + + assert agent.context == mock_context + assert agent.agent_name == "TestAgent" + + def test_implements_agent_protocol(self) -> None: + """Test that DurableAIAgent implements AgentProtocol.""" + from agent_framework import AgentProtocol + + mock_context = Mock() + agent = DurableAIAgent(mock_context, "TestAgent") + + # Check that agent satisfies AgentProtocol + assert isinstance(agent, AgentProtocol) + + def test_has_agent_protocol_properties(self) -> None: + """Test that DurableAIAgent has AgentProtocol properties.""" + mock_context = Mock() + agent = DurableAIAgent(mock_context, "TestAgent") + + # AgentProtocol properties + assert hasattr(agent, "id") + assert hasattr(agent, "name") + assert hasattr(agent, "description") + assert hasattr(agent, "display_name") + + # Verify values + assert agent.name == "TestAgent" + assert agent.description == "Durable agent proxy for TestAgent" + assert agent.display_name == "TestAgent" + assert agent.id is not None # Auto-generated UUID + + def test_get_new_thread(self) -> None: + """Test creating a new agent thread.""" + mock_context = Mock() + mock_context.instance_id = "test-instance-456" + mock_context.new_uuid = Mock(return_value="test-guid-456") + + agent = DurableAIAgent(mock_context, "WriterAgent") + thread = agent.get_new_thread() + + assert isinstance(thread, DurableAgentThread) + assert thread.session_id is not None + session_id = thread.session_id + assert isinstance(session_id, AgentSessionId) + assert session_id.name == "WriterAgent" + assert session_id.key == "test-guid-456" + mock_context.new_uuid.assert_called_once() + + def test_get_new_thread_deterministic(self) -> None: + """Test that get_new_thread creates deterministic session IDs.""" + + mock_context = Mock() + mock_context.instance_id = "test-instance-789" + mock_context.new_uuid = Mock(side_effect=["session-guid-1", "session-guid-2"]) + + agent = DurableAIAgent(mock_context, "EditorAgent") + + # Create multiple threads - they should have unique session IDs + thread1 = agent.get_new_thread() + thread2 = agent.get_new_thread() + + assert isinstance(thread1, DurableAgentThread) + assert isinstance(thread2, DurableAgentThread) + + session_id1 = thread1.session_id + session_id2 = thread2.session_id + assert session_id1 is not None and session_id2 is not None + assert isinstance(session_id1, AgentSessionId) + assert isinstance(session_id2, AgentSessionId) + assert session_id1.name == "EditorAgent" + assert session_id2.name == "EditorAgent" + assert session_id1.key == "session-guid-1" + assert session_id2.key == "session-guid-2" + assert mock_context.new_uuid.call_count == 2 + + def test_run_creates_entity_call(self) -> None: + """Test that run() creates proper entity call and returns a Task.""" + mock_context = Mock() + mock_context.instance_id = "test-instance-001" + mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"]) + + # Mock call_entity to return a Task-like object + mock_task = Mock() + mock_task._is_scheduled = False # Task attribute that orchestration checks + + mock_context.call_entity = Mock(return_value=mock_task) + + agent = DurableAIAgent(mock_context, "TestAgent") + + # Create thread + thread = agent.get_new_thread() + + # Call run() - it should return the Task directly + task = agent.run(messages="Test message", thread=thread, enable_tool_calls=True) + + # Verify run() returns the Task from call_entity + assert task == mock_task + + # Verify call_entity was called with correct parameters + assert mock_context.call_entity.called + call_args = mock_context.call_entity.call_args + entity_id, operation, request = call_args[0] + + assert operation == "run_agent" + assert request["message"] == "Test message" + assert request["enable_tool_calls"] is True + assert "correlation_id" in request + assert request["correlation_id"] == "correlation-guid" + assert "conversation_id" in request + assert request["conversation_id"] == "thread-guid" + + def test_run_without_thread(self) -> None: + """Test that run() works without explicit thread (creates unique session key).""" + mock_context = Mock() + mock_context.instance_id = "test-instance-002" + # Two calls to new_uuid: one for session_key, one for correlation_id + mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"]) + + mock_task = Mock() + mock_task._is_scheduled = False + mock_context.call_entity = Mock(return_value=mock_task) + + agent = DurableAIAgent(mock_context, "TestAgent") + + # Call without thread + task = agent.run(messages="Test message") + + assert task == mock_task + + # Verify the entity ID uses the auto-generated GUID with dafx- prefix + call_args = mock_context.call_entity.call_args + entity_id = call_args[0][0] + assert entity_id.name == "dafx-TestAgent" + assert entity_id.key == "auto-generated-guid" + # Should be called twice: once for session_key, once for correlation_id + assert mock_context.new_uuid.call_count == 2 + + def test_run_with_response_format(self) -> None: + """Test that run() passes response format correctly.""" + mock_context = Mock() + mock_context.instance_id = "test-instance-003" + + mock_task = Mock() + mock_task._is_scheduled = False + mock_context.call_entity = Mock(return_value=mock_task) + + agent = DurableAIAgent(mock_context, "TestAgent") + + from pydantic import BaseModel + + class SampleSchema(BaseModel): + key: str + + # Create thread and call + thread = agent.get_new_thread() + + task = agent.run(messages="Test message", thread=thread, response_format=SampleSchema) + + assert task == mock_task + + # Verify schema was passed in the call_entity arguments + call_args = mock_context.call_entity.call_args + input_data = call_args[0][2] # Third argument is input_data + assert "response_format" in input_data + assert input_data["response_format"]["__response_schema_type__"] == "pydantic_model" + assert input_data["response_format"]["module"] == SampleSchema.__module__ + assert input_data["response_format"]["qualname"] == SampleSchema.__qualname__ + + def test_messages_to_string(self) -> None: + """Test converting ChatMessage list to string.""" + from agent_framework import ChatMessage + + mock_context = Mock() + agent = DurableAIAgent(mock_context, "TestAgent") + + messages = [ + ChatMessage(role="user", text="Hello"), + ChatMessage(role="assistant", text="Hi there"), + ChatMessage(role="user", text="How are you?"), + ] + + result = agent._messages_to_string(messages) + + assert result == "Hello\nHi there\nHow are you?" + + def test_run_with_chat_message(self) -> None: + """Test that run() handles ChatMessage input.""" + from agent_framework import ChatMessage + + mock_context = Mock() + mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"]) + mock_task = Mock() + mock_context.call_entity = Mock(return_value=mock_task) + + agent = DurableAIAgent(mock_context, "TestAgent") + thread = agent.get_new_thread() + + # Call with ChatMessage + msg = ChatMessage(role="user", text="Hello") + task = agent.run(messages=msg, thread=thread) + + assert task == mock_task + + # Verify message was converted to string + call_args = mock_context.call_entity.call_args + request = call_args[0][2] + assert request["message"] == "Hello" + + def test_run_stream_raises_not_implemented(self) -> None: + """Test that run_stream() method raises NotImplementedError.""" + mock_context = Mock() + agent = DurableAIAgent(mock_context, "TestAgent") + + with pytest.raises(NotImplementedError) as exc_info: + agent.run_stream("Test message") + + error_msg = str(exc_info.value) + assert "Streaming is not supported" in error_msg + + def test_entity_id_format(self) -> None: + """Test that EntityId is created with correct format (name, key).""" + from azure.durable_functions import EntityId + + mock_context = Mock() + mock_context.new_uuid = Mock(return_value="test-guid-789") + mock_context.call_entity = Mock(return_value=Mock()) + + agent = DurableAIAgent(mock_context, "WriterAgent") + thread = agent.get_new_thread() + + # Call run() to trigger entity ID creation + agent.run("Test", thread=thread) + + # Verify call_entity was called with correct EntityId + call_args = mock_context.call_entity.call_args + entity_id = call_args[0][0] + + # EntityId should be EntityId(name="dafx-WriterAgent", key="test-guid-789") + # Which formats as "@dafx-writeragent@test-guid-789" + assert isinstance(entity_id, EntityId) + assert entity_id.name == "dafx-WriterAgent" + assert entity_id.key == "test-guid-789" + assert str(entity_id) == "@dafx-writeragent@test-guid-789" + + +class TestGetAgentHelper: + """Test suite for the get_agent helper function.""" + + def test_get_agent_function(self) -> None: + """Test get_agent function creates DurableAIAgent.""" + mock_context = Mock() + mock_context.instance_id = "test-instance-100" + + agent = get_agent(mock_context, "MyAgent") + + assert isinstance(agent, DurableAIAgent) + assert agent.agent_name == "MyAgent" + assert agent.context == mock_context + + +class TestOrchestrationIntegration: + """Integration tests for orchestration scenarios.""" + + def test_sequential_agent_calls_simulation(self) -> None: + """Simulate sequential agent calls in an orchestration.""" + mock_context = Mock() + mock_context.instance_id = "test-orchestration-001" + # new_uuid will be called 3 times: + # 1. thread creation + # 2. correlation_id for first call + # 3. correlation_id for second call + mock_context.new_uuid = Mock(side_effect=["deterministic-guid-001", "corr-1", "corr-2"]) + + # Track entity calls + entity_calls: list[dict[str, Any]] = [] + + def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock: + entity_calls.append({"entity_id": str(entity_id), "operation": operation, "input": input_data}) + + # Return a mock Task + mock_task = Mock() + mock_task._is_scheduled = False + return mock_task + + mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect) + + # Create agent + agent = get_agent(mock_context, "WriterAgent") + + # Create thread + thread = agent.get_new_thread() + + # First call - returns Task + task1 = agent.run("Write something", thread=thread) + assert hasattr(task1, "_is_scheduled") + + # Second call - returns Task + task2 = agent.run("Improve: something", thread=thread) + assert hasattr(task2, "_is_scheduled") + + # Verify both calls used the same entity (same session key) + assert len(entity_calls) == 2 + assert entity_calls[0]["entity_id"] == entity_calls[1]["entity_id"] + # EntityId format is @dafx-writeragent@deterministic-guid-001 + assert entity_calls[0]["entity_id"] == "@dafx-writeragent@deterministic-guid-001" + # new_uuid called 3 times: thread + 2 correlation IDs + assert mock_context.new_uuid.call_count == 3 + + def test_multiple_agents_in_orchestration(self) -> None: + """Test using multiple different agents in one orchestration.""" + mock_context = Mock() + mock_context.instance_id = "test-orchestration-002" + # Mock new_uuid to return different GUIDs for each call + # Order: writer thread, editor thread, writer correlation, editor correlation + mock_context.new_uuid = Mock(side_effect=["writer-guid-001", "editor-guid-002", "writer-corr", "editor-corr"]) + + entity_calls: list[str] = [] + + def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock: + entity_calls.append(str(entity_id)) + mock_task = Mock() + mock_task._is_scheduled = False + return mock_task + + mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect) + + # Create multiple agents + writer = get_agent(mock_context, "WriterAgent") + editor = get_agent(mock_context, "EditorAgent") + + writer_thread = writer.get_new_thread() + editor_thread = editor.get_new_thread() + + # Call both agents - returns Tasks + writer_task = writer.run("Write", thread=writer_thread) + editor_task = editor.run("Edit", thread=editor_thread) + + assert hasattr(writer_task, "_is_scheduled") + assert hasattr(editor_task, "_is_scheduled") + + # Verify different entity IDs were used + assert len(entity_calls) == 2 + # EntityId format is @dafx-agentname@guid (lowercased agent name with dafx- prefix) + assert entity_calls[0] == "@dafx-writeragent@writer-guid-001" + assert entity_calls[1] == "@dafx-editoragent@editor-guid-002" + + +class TestAgentThreadSerialization: + """Test that AgentThread can be serialized for orchestration state.""" + + async def test_agent_thread_serialize(self) -> None: + """Test that AgentThread can be serialized.""" + thread = AgentThread() + + # Serialize + serialized = await thread.serialize() + + assert isinstance(serialized, dict) + assert "service_thread_id" in serialized + + async def test_agent_thread_deserialize(self) -> None: + """Test that AgentThread can be deserialized.""" + thread = AgentThread() + serialized = await thread.serialize() + + # Deserialize + restored = await AgentThread.deserialize(serialized) + + assert isinstance(restored, AgentThread) + assert restored.service_thread_id == thread.service_thread_id + + async def test_durable_agent_thread_serialization(self) -> None: + """Test that DurableAgentThread persists session metadata during serialization.""" + mock_context = Mock() + mock_context.instance_id = "test-instance-999" + mock_context.new_uuid = Mock(return_value="test-guid-999") + + agent = DurableAIAgent(mock_context, "TestAgent") + thread = agent.get_new_thread() + + assert isinstance(thread, DurableAgentThread) + # Verify custom attribute and property exist + assert thread.session_id is not None + session_id = thread.session_id + assert isinstance(session_id, AgentSessionId) + assert session_id.name == "TestAgent" + assert session_id.key == "test-guid-999" + + # Standard serialization should still work + serialized = await thread.serialize() + assert isinstance(serialized, dict) + assert serialized.get("durable_session_id") == str(session_id) + + # After deserialization, we'd need to restore the custom attribute + # This would be handled by the orchestration framework + restored = await DurableAgentThread.deserialize(serialized) + assert isinstance(restored, DurableAgentThread) + assert restored.session_id == session_id + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/azurefunctions/tests/test_state.py b/python/packages/azurefunctions/tests/test_state.py new file mode 100644 index 0000000000..52aa7458f0 --- /dev/null +++ b/python/packages/azurefunctions/tests/test_state.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for AgentState correlation ID tracking.""" + +from unittest.mock import Mock + +import pytest +from agent_framework import AgentRunResponse + +from agent_framework_azurefunctions._state import AgentState + + +class TestAgentStateCorrelationId: + """Test suite for AgentState correlation ID tracking.""" + + def _create_mock_response(self, text: str = "Response") -> Mock: + """Create a mock AgentRunResponse with the provided text.""" + mock_response = Mock(spec=AgentRunResponse) + mock_response.to_dict.return_value = {"text": text, "messages": []} + return mock_response + + def test_add_assistant_message_with_correlation_id(self) -> None: + state = AgentState() + state.add_user_message("Hello", correlation_id="corr-123-request") + state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123") + message_metadata = state.conversation_history[-1].additional_properties or {} + assert message_metadata.get("correlation_id") == "corr-123" + + response_data = state.try_get_agent_response("corr-123") + assert response_data is not None + assert response_data["content"] == "Response" + assert response_data["agent_response"] == {"text": "Response", "messages": []} + + def test_try_get_agent_response_returns_response(self) -> None: + state = AgentState() + state.add_user_message("Hello", correlation_id="corr-200-request") + state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-456") + + response_data = state.try_get_agent_response("corr-456") + + assert response_data is not None + assert response_data["content"] == "Response" + + def test_try_get_agent_response_returns_none_for_missing_id(self) -> None: + state = AgentState() + state.add_user_message("Hello", correlation_id="corr-300-request") + state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123") + + assert state.try_get_agent_response("non-existent") is None + + def test_multiple_responses_tracked_separately(self) -> None: + state = AgentState() + + for index in range(3): + state.add_user_message(f"Message {index}", correlation_id=f"corr-{index}-request") + state.add_assistant_message( + f"Response {index}", + self._create_mock_response(text=f"Response {index}"), + correlation_id=f"corr-{index}", + ) + + for index in range(3): + payload = state.try_get_agent_response(f"corr-{index}") + assert payload is not None + assert payload["content"] == f"Response {index}" + + def test_add_assistant_message_without_correlation_id(self) -> None: + state = AgentState() + state.add_user_message("Hello", correlation_id="corr-400-request") + state.add_assistant_message("Response", self._create_mock_response()) + + assert state.try_get_agent_response("missing") is None + assert state.last_response == "Response" + + def test_to_dict_does_not_duplicate_agent_responses(self) -> None: + state = AgentState() + state.add_user_message("Hello", correlation_id="corr-500-request") + state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123") + + state_snapshot = state.to_dict() + + assert "agent_responses" not in state_snapshot + metadata = state_snapshot["conversation_history"][-1]["additional_properties"] + assert metadata["correlation_id"] == "corr-123" + + def test_restore_state_preserves_agent_response_lookup(self) -> None: + state = AgentState() + state.add_user_message("Hello", correlation_id="corr-600-request") + state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123") + + restored_state = AgentState() + restored_state.restore_state(state.to_dict()) + + payload = restored_state.try_get_agent_response("corr-123") + assert payload is not None + assert payload["content"] == "Response" + + def test_reset_clears_conversation_history(self) -> None: + state = AgentState() + state.add_user_message("Hello", correlation_id="corr-700-request") + state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123") + + state.reset() + + assert len(state.conversation_history) == 0 + assert state.try_get_agent_response("corr-123") is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])