Python: Add Unit tests for Azurefunctions package (#1976)

* Add Unit tests for Azurefunctions

* remove duplicate import
This commit is contained in:
Laveesh Rohra
2025-11-07 08:29:43 -08:00
committed by GitHub
Unverified
parent 90742ba48e
commit 754491cdd3
7 changed files with 2684 additions and 2 deletions
@@ -197,9 +197,10 @@ class DurableAgentThread(AgentThread):
**kwargs: Any,
) -> "DurableAgentThread":
"""Restores a durable thread, rehydrating the stored session identifier."""
session_id_value = serialized_thread_state.get(cls._SERIALIZED_SESSION_ID_KEY)
state_payload = dict(serialized_thread_state)
session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None)
thread = await super().deserialize(
serialized_thread_state,
state_payload,
message_store=message_store,
**kwargs,
)
@@ -0,0 +1,614 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for AgentFunctionApp."""
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from unittest.mock import ANY, AsyncMock, Mock, patch
import azure.durable_functions as df
import azure.functions as func
import pytest
from agent_framework import AgentRunResponse, ChatMessage
from agent_framework_azurefunctions import AgentFunctionApp
from agent_framework_azurefunctions._entities import AgentEntity, AgentState, create_agent_entity
from agent_framework_azurefunctions._errors import IncomingRequestError
TFunc = TypeVar("TFunc", bound=Callable[..., Any])
class TestAgentFunctionAppInit:
"""Test suite for AgentFunctionApp initialization."""
def test_init_with_defaults(self) -> None:
"""Test initialization with default parameters."""
mock_agent = Mock()
mock_agent.name = "TestAgent"
app = AgentFunctionApp(agents=[mock_agent])
assert len(app.agents) == 1
assert "TestAgent" in app.agents
assert app.enable_health_check is True
def test_init_with_custom_auth_level(self) -> None:
"""Test initialization with custom auth level."""
mock_agent = Mock()
mock_agent.name = "TestAgent"
app = AgentFunctionApp(agents=[mock_agent], http_auth_level=func.AuthLevel.FUNCTION)
# App should be created successfully
assert "TestAgent" in app.agents
def test_init_with_health_check_disabled(self) -> None:
"""Test initialization with health check disabled."""
mock_agent = Mock()
mock_agent.name = "TestAgent"
app = AgentFunctionApp(agents=[mock_agent], enable_health_check=False)
assert app.enable_health_check is False
def test_init_with_http_endpoints_disabled(self) -> None:
"""Test initialization with HTTP endpoints disabled."""
mock_agent = Mock()
mock_agent.name = "TestAgent"
app = AgentFunctionApp(agents=[mock_agent], enable_http_endpoints=False)
assert app.enable_http_endpoints is False
def test_init_stores_agent_reference(self) -> None:
"""Test that agent reference is stored correctly."""
mock_agent = Mock()
mock_agent.name = "TestAgent"
app = AgentFunctionApp(agents=[mock_agent])
assert app.agents["TestAgent"].name == "TestAgent"
def test_add_agent_uses_specific_callback(self) -> None:
"""Verify that a per-agent callback overrides the default."""
mock_agent = Mock()
mock_agent.name = "CallbackAgent"
specific_callback = Mock()
with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock:
app = AgentFunctionApp(default_callback=Mock())
app.add_agent(mock_agent, callback=specific_callback)
setup_mock.assert_called_once()
_, _, passed_callback, enable_http_endpoint = setup_mock.call_args[0]
assert passed_callback is specific_callback
assert enable_http_endpoint is True
def test_default_callback_applied_when_no_specific(self) -> None:
"""Ensure the default callback is supplied when add_agent lacks override."""
mock_agent = Mock()
mock_agent.name = "DefaultAgent"
default_callback = Mock()
with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock:
app = AgentFunctionApp(default_callback=default_callback)
app.add_agent(mock_agent)
setup_mock.assert_called_once()
_, _, passed_callback, enable_http_endpoint = setup_mock.call_args[0]
assert passed_callback is default_callback
assert enable_http_endpoint is True
def test_init_with_agents_uses_default_callback(self) -> None:
"""Agents provided in __init__ should receive the default callback."""
mock_agent = Mock()
mock_agent.name = "InitAgent"
default_callback = Mock()
with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock:
AgentFunctionApp(agents=[mock_agent], default_callback=default_callback)
setup_mock.assert_called_once()
_, _, passed_callback, enable_http_endpoint = setup_mock.call_args[0]
assert passed_callback is default_callback
assert enable_http_endpoint is True
class TestAgentFunctionAppSetup:
"""Test suite for AgentFunctionApp setup and configuration."""
def test_app_is_dfapp_instance(self) -> None:
"""Test that AgentFunctionApp is a DFApp instance."""
mock_agent = Mock()
mock_agent.name = "TestAgent"
app = AgentFunctionApp(agents=[mock_agent])
assert isinstance(app, df.DFApp)
def test_setup_creates_http_trigger(self) -> None:
"""Test that setup creates an HTTP trigger."""
mock_agent = Mock()
mock_agent.name = "TestAgent"
def passthrough_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
def decorator(func: TFunc) -> TFunc:
return func
return decorator
with (
patch.object(AgentFunctionApp, "route", new=passthrough_decorator),
patch.object(AgentFunctionApp, "durable_client_input", new=passthrough_decorator),
patch.object(AgentFunctionApp, "entity_trigger", new=passthrough_decorator),
):
app = AgentFunctionApp(agents=[mock_agent])
# Verify agent is registered
assert "TestAgent" in app.agents
def test_setup_skips_http_trigger_when_disabled(self) -> None:
"""Test that HTTP trigger is not created when disabled."""
mock_agent = Mock()
mock_agent.name = "TestAgent"
captured_routes: list[str | None] = []
def capture_route(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
def decorator(func: TFunc) -> TFunc:
route_key = kwargs.get("route") if kwargs else None
captured_routes.append(route_key)
return func
return decorator
def passthrough_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
def decorator(func: TFunc) -> TFunc:
return func
return decorator
with (
patch.object(AgentFunctionApp, "function_name", new=passthrough_decorator),
patch.object(AgentFunctionApp, "route", new=capture_route),
patch.object(AgentFunctionApp, "durable_client_input", new=passthrough_decorator),
patch.object(AgentFunctionApp, "entity_trigger", new=passthrough_decorator),
):
app = AgentFunctionApp(agents=[mock_agent], enable_http_endpoints=False)
# Verify agent is registered
assert "TestAgent" in app.agents
# Verify that no HTTP run route was created
run_route = f"agents/{mock_agent.name}/run"
assert run_route not in captured_routes
def test_agent_override_enables_http_route_when_app_disabled(self) -> None:
"""Agent-level override should enable HTTP route even when app disables it."""
mock_agent = Mock()
mock_agent.name = "OverrideAgent"
with (
patch.object(AgentFunctionApp, "_setup_http_run_route") as http_route_mock,
patch.object(AgentFunctionApp, "_setup_agent_entity") as agent_entity_mock,
):
app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=False)
app.add_agent(mock_agent, enable_http_endpoint=True)
http_route_mock.assert_called_once_with("OverrideAgent")
agent_entity_mock.assert_called_once_with(mock_agent, "OverrideAgent", ANY)
assert app.agent_http_endpoint_flags["OverrideAgent"] is True
def test_agent_override_disables_http_route_when_app_enabled(self) -> None:
"""Agent-level override should disable HTTP route even when app enables it."""
mock_agent = Mock()
mock_agent.name = "DisabledOverride"
with (
patch.object(AgentFunctionApp, "_setup_http_run_route") as http_route_mock,
patch.object(AgentFunctionApp, "_setup_agent_entity") as agent_entity_mock,
):
app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=True)
app.add_agent(mock_agent, enable_http_endpoint=False)
http_route_mock.assert_not_called()
agent_entity_mock.assert_called_once_with(mock_agent, "DisabledOverride", ANY)
assert app.agent_http_endpoint_flags["DisabledOverride"] is False
def test_multiple_apps_independent(self) -> None:
"""Test that multiple AgentFunctionApp instances are independent."""
agent1 = Mock()
agent1.name = "Agent1"
agent2 = Mock()
agent2.name = "Agent2"
app1 = AgentFunctionApp(agents=[agent1])
app2 = AgentFunctionApp(agents=[agent2])
assert app1.agents["Agent1"].name == "Agent1"
assert app2.agents["Agent2"].name == "Agent2"
assert "Agent1" in app1.agents
assert "Agent2" in app2.agents
class TestWaitForCompletionAndCorrelationId:
"""Tests for wait_for_completion flag and correlation ID handling."""
def _create_app(self) -> AgentFunctionApp:
mock_agent = Mock()
mock_agent.__class__.__name__ = "MockAgent"
mock_agent.name = "MockAgent"
return AgentFunctionApp(agents=[mock_agent], enable_health_check=False)
def _make_request(
self,
headers: dict[str, str] | None = None,
params: dict[str, str] | None = None,
) -> Mock:
request = Mock()
request.headers = headers or {}
request.params = params or {}
return request
def test_wait_for_completion_header_true(self) -> None:
"""Test that the wait-for-completion header is honored."""
app = self._create_app()
request = self._make_request(headers={"X-Wait-For-Completion": "true"})
assert app._should_wait_for_completion(request, {}) is True
def test_wait_for_completion_body_variants(self) -> None:
"""Test that multiple payload spellings are accepted."""
app = self._create_app()
request = self._make_request()
assert app._should_wait_for_completion(request, {"wait_for_completion": "true"}) is True
assert app._should_wait_for_completion(request, {"waitForCompletion": "1"}) is True
assert app._should_wait_for_completion(request, {"WaitForCompletion": "no"}) is False
class TestAgentEntityOperations:
"""Test suite for entity operations."""
async def test_entity_run_agent_operation(self) -> None:
"""Test that entity can run agent operation."""
mock_agent = Mock()
mock_agent.run = AsyncMock(
return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Test response")])
)
entity = AgentEntity(mock_agent)
mock_context = Mock()
result = await entity.run_agent(
mock_context,
{"message": "Test message", "conversation_id": "test-conv-123", "correlation_id": "corr-app-entity-1"},
)
assert result["status"] == "success"
assert result["response"] == "Test response"
assert result["message"] == "Test message"
assert result["conversation_id"] == "test-conv-123"
assert entity.state.message_count == 1
async def test_entity_stores_conversation_history(self) -> None:
"""Test that the entity stores conversation history."""
mock_agent = Mock()
mock_agent.run = AsyncMock(
return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response 1")])
)
entity = AgentEntity(mock_agent)
mock_context = Mock()
# Send first message
await entity.run_agent(
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-app-entity-2"}
)
history = entity.state.conversation_history
assert len(history) == 2 # User + assistant
user_msg = history[0]
user_role = getattr(user_msg.role, "value", user_msg.role)
assert user_role == "user"
assert user_msg.text == "Message 1"
assistant_msg = history[1]
assistant_role = getattr(assistant_msg.role, "value", assistant_msg.role)
assert assistant_role == "assistant"
assert assistant_msg.text == "Response 1"
async def test_entity_increments_message_count(self) -> None:
"""Test that the entity increments the message count."""
mock_agent = Mock()
mock_agent.run = AsyncMock(
return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")])
)
entity = AgentEntity(mock_agent)
mock_context = Mock()
assert entity.state.message_count == 0
await entity.run_agent(
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-app-entity-3a"}
)
assert entity.state.message_count == 1
await entity.run_agent(
mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-app-entity-3b"}
)
assert entity.state.message_count == 2
def test_entity_reset(self) -> None:
"""Test that entity reset clears state."""
mock_agent = Mock()
entity = AgentEntity(mock_agent)
# Set some state
entity.state.message_count = 10
entity.state.last_response = "Some response"
entity.state.conversation_history = [
ChatMessage(role="user", text="test", additional_properties={"timestamp": "2024-01-01T00:00:00Z"})
]
# Reset
mock_context = Mock()
entity.reset(mock_context)
assert entity.state.message_count == 0
assert entity.state.last_response is None
assert len(entity.state.conversation_history) == 0
class TestAgentEntityFactory:
"""Test suite for the entity factory function."""
def test_create_agent_entity_returns_function(self) -> None:
"""Test that create_agent_entity returns a function."""
mock_agent = Mock()
entity_function = create_agent_entity(mock_agent)
assert callable(entity_function)
def test_entity_function_handles_run_agent_operation(self) -> None:
"""Test that the entity function handles the run_agent operation."""
mock_agent = Mock()
mock_agent.run = AsyncMock(
return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")])
)
entity_function = create_agent_entity(mock_agent)
# Mock context
mock_context = Mock()
mock_context.operation_name = "run_agent"
mock_context.get_input.return_value = {
"message": "Test message",
"conversation_id": "conv-123",
"correlation_id": "corr-app-factory-1",
}
mock_context.get_state.return_value = None
# Execute entity function
entity_function(mock_context)
# Verify result was set
assert mock_context.set_result.called
assert mock_context.set_state.called
def test_entity_function_handles_reset_operation(self) -> None:
"""Test that the entity function handles the reset operation."""
mock_agent = Mock()
entity_function = create_agent_entity(mock_agent)
# Mock context
mock_context = Mock()
mock_context.operation_name = "reset"
mock_context.get_state.return_value = {
"message_count": 5,
"conversation_history": [{"role": "user", "content": "test"}],
"last_response": "Test",
}
# Execute entity function
entity_function(mock_context)
# Verify result was set
assert mock_context.set_result.called
result_call = mock_context.set_result.call_args[0][0]
assert result_call["status"] == "reset"
def test_entity_function_handles_unknown_operation(self) -> None:
"""Test that the entity function handles an unknown operation."""
mock_agent = Mock()
entity_function = create_agent_entity(mock_agent)
# Mock context with unknown operation
mock_context = Mock()
mock_context.operation_name = "unknown_operation"
mock_context.get_state.return_value = None
# Execute entity function
entity_function(mock_context)
# Verify error result was set
assert mock_context.set_result.called
result_call = mock_context.set_result.call_args[0][0]
assert "error" in result_call
assert "unknown_operation" in result_call["error"]
def test_entity_function_restores_state(self) -> None:
"""Test that the entity function restores state from the context."""
mock_agent = Mock()
entity_function = create_agent_entity(mock_agent)
# Mock context with existing state
existing_state = {
"message_count": 3,
"conversation_history": [{"role": "user", "content": "msg1"}, {"role": "assistant", "content": "resp1"}],
"last_response": "resp1",
}
mock_context = Mock()
mock_context.operation_name = "reset"
mock_context.get_state.return_value = existing_state
with patch.object(AgentState, "restore_state") as restore_state_mock:
entity_function(mock_context)
restore_state_mock.assert_called_once_with(existing_state)
class TestErrorHandling:
"""Test suite for error handling."""
async def test_entity_handles_agent_error(self) -> None:
"""Test that the entity handles agent execution errors."""
mock_agent = Mock()
mock_agent.run = AsyncMock(side_effect=Exception("Agent error"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
result = await entity.run_agent(
mock_context, {"message": "Test message", "conversation_id": "conv-1", "correlation_id": "corr-app-error-1"}
)
assert result["status"] == "error"
assert "error" in result
assert "Agent error" in result["error"]
assert result["error_type"] == "Exception"
def test_entity_function_handles_exception(self) -> None:
"""Test that the entity function handles exceptions gracefully."""
mock_agent = Mock()
# Force an exception by making get_input fail
mock_agent.run = AsyncMock(side_effect=Exception("Test error"))
entity_function = create_agent_entity(mock_agent)
mock_context = Mock()
mock_context.operation_name = "run_agent"
mock_context.get_input.side_effect = Exception("Input error")
mock_context.get_state.return_value = None
# Execute entity function - should not raise
entity_function(mock_context)
# Verify error result was set
assert mock_context.set_result.called
result_call = mock_context.set_result.call_args[0][0]
assert "error" in result_call
class TestIncomingRequestParsing:
"""Tests for parsing run requests with JSON and plain text bodies."""
def _create_app(self) -> AgentFunctionApp:
mock_agent = Mock()
mock_agent.name = "ParserAgent"
return AgentFunctionApp(agents=[mock_agent], enable_health_check=False)
def test_parse_plain_text_body(self) -> None:
"""Test parsing a plain-text request body."""
app = self._create_app()
request = Mock()
request.get_json.side_effect = ValueError("Invalid JSON")
request.get_body.return_value = b"Plain text message"
req_body, message = app._parse_incoming_request(request)
assert req_body == {}
assert message == "Plain text message"
def test_parse_plain_text_requires_content(self) -> None:
"""Test that plain-text requests require message content."""
app = self._create_app()
request = Mock()
request.get_json.side_effect = ValueError("Invalid JSON")
request.get_body.return_value = b" "
with pytest.raises(IncomingRequestError) as exc_info:
app._parse_incoming_request(request)
assert "Message is required" in str(exc_info.value)
def test_extract_session_key_from_query_params(self) -> None:
"""Test session key extraction from query parameters."""
app = self._create_app()
request = Mock()
request.params = {"sessionId": "query-session"}
req_body = {}
session_key = app._resolve_session_key(request, req_body)
assert session_key == "query-session"
class TestHttpRunRoute:
"""Tests for the HTTP run route behavior."""
async def test_http_run_accepts_plain_text(self) -> None:
"""Test that the HTTP handler accepts plain-text requests."""
mock_agent = Mock()
mock_agent.name = "HttpAgent"
captured_handlers: dict[str | None, Callable[..., Awaitable[func.HttpResponse]]] = {}
def capture_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
def decorator(func: TFunc) -> TFunc:
return func
return decorator
def capture_route(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
def decorator(func: TFunc) -> TFunc:
route_key = kwargs.get("route") if kwargs else None
captured_handlers[route_key] = func
return func
return decorator
with (
patch.object(AgentFunctionApp, "function_name", new=capture_decorator),
patch.object(AgentFunctionApp, "route", new=capture_route),
patch.object(AgentFunctionApp, "durable_client_input", new=capture_decorator),
patch.object(AgentFunctionApp, "entity_trigger", new=capture_decorator),
):
AgentFunctionApp(agents=[mock_agent], enable_health_check=False)
run_route = f"agents/{mock_agent.name}/run"
handler = captured_handlers[run_route]
request = Mock()
request.headers = {}
request.params = {}
request.route_params = {}
request.get_json.side_effect = ValueError("Invalid JSON")
request.get_body.return_value = b"Plain text via HTTP"
client = AsyncMock()
response = await handler(request, client)
assert response.status_code == 202
signal_args = client.signal_entity.call_args[0]
run_request = signal_args[2]
assert run_request["message"] == "Plain text via HTTP"
assert run_request["role"] == "user"
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -0,0 +1,904 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for AgentEntity and entity operations.
Run with: pytest tests/test_entities.py -v
"""
import asyncio
from collections.abc import AsyncIterator, Callable
from datetime import datetime
from typing import Any, TypeVar
from unittest.mock import AsyncMock, Mock, patch
import pytest
from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
from pydantic import BaseModel
from agent_framework_azurefunctions._entities import AgentEntity, create_agent_entity
from agent_framework_azurefunctions._models import ChatRole, RunRequest
from agent_framework_azurefunctions._state import AgentState
TFunc = TypeVar("TFunc", bound=Callable[..., Any])
def _role_value(chat_message: ChatMessage) -> str:
"""Helper to extract the string role from a ChatMessage."""
role = getattr(chat_message, "role", None)
role_value = getattr(role, "value", role)
if role_value is None:
return ""
return str(role_value)
def _agent_response(text: str | None) -> AgentRunResponse:
"""Create an AgentRunResponse with a single assistant message."""
message = (
ChatMessage(role="assistant", text=text) if text is not None else ChatMessage(role="assistant", contents=[])
)
return AgentRunResponse(messages=[message])
class RecordingCallback:
"""Callback implementation capturing streaming and final responses for assertions."""
def __init__(self):
self.stream_mock = AsyncMock()
self.response_mock = AsyncMock()
async def on_streaming_response_update(
self,
update: AgentRunResponseUpdate,
context: Any,
) -> None:
await self.stream_mock(update, context)
async def on_agent_response(self, response: AgentRunResponse, context: Any) -> None:
await self.response_mock(response, context)
class EntityStructuredResponse(BaseModel):
answer: float
class TestAgentEntityInit:
"""Test suite for AgentEntity initialization."""
def test_init_creates_entity(self) -> None:
"""Test that AgentEntity initializes correctly."""
mock_agent = Mock()
entity = AgentEntity(mock_agent)
assert entity.agent == mock_agent
assert entity.state.conversation_history == []
assert entity.state.last_response is None
assert entity.state.message_count == 0
def test_init_stores_agent_reference(self) -> None:
"""Test that the agent reference is stored correctly."""
mock_agent = Mock()
mock_agent.name = "TestAgent"
entity = AgentEntity(mock_agent)
assert entity.agent.name == "TestAgent"
def test_init_with_different_agent_types(self) -> None:
"""Test initialization with different agent types."""
agent1 = Mock()
agent1.__class__.__name__ = "AzureOpenAIAgent"
agent2 = Mock()
agent2.__class__.__name__ = "CustomAgent"
entity1 = AgentEntity(agent1)
entity2 = AgentEntity(agent2)
assert entity1.agent.__class__.__name__ == "AzureOpenAIAgent"
assert entity2.agent.__class__.__name__ == "CustomAgent"
class TestAgentEntityRunAgent:
"""Test suite for the run_agent operation."""
async def test_run_agent_executes_agent(self) -> None:
"""Test that run_agent executes the agent."""
mock_agent = Mock()
mock_response = _agent_response("Test response")
mock_agent.run = AsyncMock(return_value=mock_response)
entity = AgentEntity(mock_agent)
mock_context = Mock()
result = await entity.run_agent(
mock_context, {"message": "Test message", "conversation_id": "conv-123", "correlation_id": "corr-entity-1"}
)
# Verify agent.run was called
mock_agent.run.assert_called_once()
_, kwargs = mock_agent.run.call_args
sent_messages = kwargs.get("messages")
assert isinstance(sent_messages, list)
assert len(sent_messages) == 1
sent_message = sent_messages[0]
assert isinstance(sent_message, ChatMessage)
assert sent_message.text == "Test message"
assert _role_value(sent_message) == "user"
# Verify result
assert result["status"] == "success"
assert result["response"] == "Test response"
assert result["message"] == "Test message"
assert result["conversation_id"] == "conv-123"
async def test_run_agent_streaming_callbacks_invoked(self) -> None:
"""Ensure streaming updates trigger callbacks and run() is not used."""
updates = [
AgentRunResponseUpdate(text="Hello"),
AgentRunResponseUpdate(text=" world"),
]
async def update_generator() -> AsyncIterator[AgentRunResponseUpdate]:
for update in updates:
yield update
mock_agent = Mock()
mock_agent.name = "StreamingAgent"
mock_agent.run_stream = Mock(return_value=update_generator())
mock_agent.run = AsyncMock(side_effect=AssertionError("run() should not be called when streaming succeeds"))
callback = RecordingCallback()
entity = AgentEntity(mock_agent, callback=callback)
mock_context = Mock()
result = await entity.run_agent(
mock_context,
{
"message": "Tell me something",
"conversation_id": "session-1",
"correlation_id": "corr-stream-1",
},
)
assert result["status"] == "success"
assert "Hello" in result.get("response", "")
assert callback.stream_mock.await_count == len(updates)
assert callback.response_mock.await_count == 1
mock_agent.run.assert_not_called()
# Validate callback arguments
stream_calls = callback.stream_mock.await_args_list
for expected_update, recorded_call in zip(updates, stream_calls, strict=True):
assert recorded_call.args[0] is expected_update
context = recorded_call.args[1]
assert context.agent_name == "StreamingAgent"
assert context.correlation_id == "corr-stream-1"
assert context.conversation_id == "session-1"
assert context.request_message == "Tell me something"
final_call = callback.response_mock.await_args
assert final_call is not None
final_response, final_context = final_call.args
assert final_context.agent_name == "StreamingAgent"
assert final_context.correlation_id == "corr-stream-1"
assert final_context.conversation_id == "session-1"
assert final_context.request_message == "Tell me something"
assert getattr(final_response, "text", "").strip()
async def test_run_agent_final_callback_without_streaming(self) -> None:
"""Ensure the final callback fires even when streaming is unavailable."""
mock_agent = Mock()
mock_agent.name = "NonStreamingAgent"
mock_agent.run_stream = None
agent_response = _agent_response("Final response")
mock_agent.run = AsyncMock(return_value=agent_response)
callback = RecordingCallback()
entity = AgentEntity(mock_agent, callback=callback)
mock_context = Mock()
result = await entity.run_agent(
mock_context,
{
"message": "Hi",
"conversation_id": "session-2",
"correlation_id": "corr-final-1",
},
)
assert result["status"] == "success"
assert result.get("response") == "Final response"
assert callback.stream_mock.await_count == 0
assert callback.response_mock.await_count == 1
final_call = callback.response_mock.await_args
assert final_call is not None
assert final_call.args[0] is agent_response
final_context = final_call.args[1]
assert final_context.agent_name == "NonStreamingAgent"
assert final_context.correlation_id == "corr-final-1"
assert final_context.conversation_id == "session-2"
assert final_context.request_message == "Hi"
async def test_run_agent_updates_conversation_history(self) -> None:
"""Test that run_agent updates the conversation history."""
mock_agent = Mock()
mock_response = _agent_response("Agent response")
mock_agent.run = AsyncMock(return_value=mock_response)
entity = AgentEntity(mock_agent)
mock_context = Mock()
await entity.run_agent(
mock_context, {"message": "User message", "conversation_id": "conv-1", "correlation_id": "corr-entity-2"}
)
# Should have 2 entries: user message + assistant response
history = entity.state.conversation_history
assert len(history) == 2
user_msg = history[0]
assert _role_value(user_msg) == "user"
assert user_msg.text == "User message"
assistant_msg = history[1]
assert _role_value(assistant_msg) == "assistant"
assert assistant_msg.text == "Agent response"
async def test_run_agent_increments_message_count(self) -> None:
"""Test that run_agent increments the message count."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
assert entity.state.message_count == 0
await entity.run_agent(
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-3a"}
)
assert entity.state.message_count == 1
await entity.run_agent(
mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-3b"}
)
assert entity.state.message_count == 2
await entity.run_agent(
mock_context, {"message": "Message 3", "conversation_id": "conv-1", "correlation_id": "corr-entity-3c"}
)
assert entity.state.message_count == 3
async def test_run_agent_stores_last_response(self) -> None:
"""Test that run_agent stores the last response."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response 1"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
await entity.run_agent(
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-4a"}
)
assert entity.state.last_response == "Response 1"
mock_agent.run = AsyncMock(return_value=_agent_response("Response 2"))
await entity.run_agent(
mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-4b"}
)
assert entity.state.last_response == "Response 2"
async def test_run_agent_with_none_conversation_id(self) -> None:
"""Test run_agent with a None conversation identifier."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
with pytest.raises(ValueError, match="conversation_id"):
await entity.run_agent(
mock_context, {"message": "Message", "conversation_id": None, "correlation_id": "corr-entity-5"}
)
async def test_run_agent_handles_response_without_text_attribute(self) -> None:
"""Test that run_agent handles responses without a text attribute."""
mock_agent = Mock()
class NoTextResponse(AgentRunResponse):
@property
def text(self) -> str: # type: ignore[override]
raise AttributeError("text attribute missing")
mock_response = NoTextResponse(messages=[ChatMessage(role="assistant", text="ignored")])
mock_agent.run = AsyncMock(return_value=mock_response)
entity = AgentEntity(mock_agent)
mock_context = Mock()
result = await entity.run_agent(
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-6"}
)
# Should handle gracefully
assert result["status"] == "success"
assert result["response"] == "Error extracting response"
async def test_run_agent_handles_none_response_text(self) -> None:
"""Test that run_agent handles responses with None text."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response(None))
entity = AgentEntity(mock_agent)
mock_context = Mock()
result = await entity.run_agent(
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-7"}
)
assert result["status"] == "success"
assert result["response"] == "No response"
async def test_run_agent_multiple_conversations(self) -> None:
"""Test that run_agent maintains history across multiple messages."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
# Send multiple messages
await entity.run_agent(
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-8a"}
)
await entity.run_agent(
mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-8b"}
)
await entity.run_agent(
mock_context, {"message": "Message 3", "conversation_id": "conv-1", "correlation_id": "corr-entity-8c"}
)
history = entity.state.conversation_history
assert len(history) == 6
assert entity.state.message_count == 3
class TestAgentEntityReset:
"""Test suite for the reset operation."""
def test_reset_clears_conversation_history(self) -> None:
"""Test that reset clears the conversation history."""
mock_agent = Mock()
entity = AgentEntity(mock_agent)
# Add some history
entity.state.conversation_history = [
ChatMessage(role="user", text="msg1"),
ChatMessage(role="assistant", text="resp1"),
]
mock_context = Mock()
entity.reset(mock_context)
assert entity.state.conversation_history == []
def test_reset_clears_last_response(self) -> None:
"""Test that reset clears the last response."""
mock_agent = Mock()
entity = AgentEntity(mock_agent)
entity.state.last_response = "Some response"
mock_context = Mock()
entity.reset(mock_context)
assert entity.state.last_response is None
def test_reset_clears_message_count(self) -> None:
"""Test that reset clears the message count."""
mock_agent = Mock()
entity = AgentEntity(mock_agent)
entity.state.message_count = 10
mock_context = Mock()
entity.reset(mock_context)
assert entity.state.message_count == 0
async def test_reset_after_conversation(self) -> None:
"""Test reset after a full conversation."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
# Have a conversation
await entity.run_agent(
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-10a"}
)
await entity.run_agent(
mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-10b"}
)
# Verify state before reset
assert entity.state.message_count == 2
assert len(entity.state.conversation_history) == 4
# Reset
entity.reset(mock_context)
# Verify state after reset
assert entity.state.message_count == 0
assert len(entity.state.conversation_history) == 0
assert entity.state.last_response is None
class TestCreateAgentEntity:
"""Test suite for the create_agent_entity factory function."""
def test_create_agent_entity_returns_callable(self) -> None:
"""Test that create_agent_entity returns a callable."""
mock_agent = Mock()
entity_function = create_agent_entity(mock_agent)
assert callable(entity_function)
def test_entity_function_handles_run_agent(self) -> None:
"""Test that the entity function handles the run_agent operation."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity_function = create_agent_entity(mock_agent)
# Mock context
mock_context = Mock()
mock_context.operation_name = "run_agent"
mock_context.get_input.return_value = {
"message": "Test message",
"conversation_id": "conv-123",
"correlation_id": "corr-entity-factory",
}
mock_context.get_state.return_value = None
# Execute
entity_function(mock_context)
# Verify result and state were set
assert mock_context.set_result.called
assert mock_context.set_state.called
def test_entity_function_handles_reset(self) -> None:
"""Test that the entity function handles the reset operation."""
mock_agent = Mock()
entity_function = create_agent_entity(mock_agent)
# Mock context with existing state
mock_context = Mock()
mock_context.operation_name = "reset"
mock_context.get_state.return_value = {
"message_count": 5,
"conversation_history": [
ChatMessage(
role="user", text="test", additional_properties={"timestamp": "2024-01-01T00:00:00Z"}
).to_dict()
],
"last_response": "Test",
}
# Execute
entity_function(mock_context)
# Verify reset result
assert mock_context.set_result.called
result = mock_context.set_result.call_args[0][0]
assert result["status"] == "reset"
# Verify state was cleared
assert mock_context.set_state.called
state = mock_context.set_state.call_args[0][0]
assert state["message_count"] == 0
assert state["conversation_history"] == []
assert state["last_response"] is None
def test_entity_function_handles_unknown_operation(self) -> None:
"""Test that the entity function handles unknown operations."""
mock_agent = Mock()
entity_function = create_agent_entity(mock_agent)
mock_context = Mock()
mock_context.operation_name = "invalid_operation"
mock_context.get_state.return_value = None
# Execute
entity_function(mock_context)
# Verify error result
assert mock_context.set_result.called
result = mock_context.set_result.call_args[0][0]
assert "error" in result
assert "invalid_operation" in result["error"].lower()
def test_entity_function_creates_new_entity_on_first_call(self) -> None:
"""Test that the entity function creates a new entity when no state exists."""
mock_agent = Mock()
mock_agent.__class__.__name__ = "Agent"
entity_function = create_agent_entity(mock_agent)
mock_context = Mock()
mock_context.operation_name = "reset"
mock_context.get_state.return_value = None # No existing state
# Execute
entity_function(mock_context)
# Verify new entity state was created
assert mock_context.set_result.called
result = mock_context.set_result.call_args[0][0]
assert result["status"] == "reset"
assert mock_context.set_state.called
state = mock_context.set_state.call_args[0][0]
assert state["message_count"] == 0
assert state["conversation_history"] == []
def test_entity_function_restores_existing_state(self) -> None:
"""Test that the entity function restores existing state."""
mock_agent = Mock()
entity_function = create_agent_entity(mock_agent)
existing_state = {
"message_count": 5,
"conversation_history": [
ChatMessage(
role="user", text="msg1", additional_properties={"timestamp": "2024-01-01T00:00:00Z"}
).to_dict(),
ChatMessage(
role="assistant", text="resp1", additional_properties={"timestamp": "2024-01-01T00:05:00Z"}
).to_dict(),
],
"last_response": "resp1",
}
mock_context = Mock()
mock_context.operation_name = "reset"
mock_context.get_state.return_value = existing_state
with patch.object(AgentState, "restore_state") as restore_state_mock:
entity_function(mock_context)
restore_state_mock.assert_called_once_with(existing_state)
class TestErrorHandling:
"""Test suite for error handling in entities."""
async def test_run_agent_handles_agent_exception(self) -> None:
"""Test that run_agent handles agent exceptions."""
mock_agent = Mock()
mock_agent.run = AsyncMock(side_effect=Exception("Agent failed"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
result = await entity.run_agent(
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-error-1"}
)
assert result["status"] == "error"
assert "error" in result
assert "Agent failed" in result["error"]
assert result["error_type"] == "Exception"
async def test_run_agent_handles_value_error(self) -> None:
"""Test that run_agent handles ValueError instances."""
mock_agent = Mock()
mock_agent.run = AsyncMock(side_effect=ValueError("Invalid input"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
result = await entity.run_agent(
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-error-2"}
)
assert result["status"] == "error"
assert result["error_type"] == "ValueError"
assert "Invalid input" in result["error"]
async def test_run_agent_handles_timeout_error(self) -> None:
"""Test that run_agent handles TimeoutError instances."""
mock_agent = Mock()
mock_agent.run = AsyncMock(side_effect=TimeoutError("Request timeout"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
result = await entity.run_agent(
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-error-3"}
)
assert result["status"] == "error"
assert result["error_type"] == "TimeoutError"
def test_entity_function_handles_exception_in_operation(self) -> None:
"""Test that the entity function handles exceptions gracefully."""
mock_agent = Mock()
entity_function = create_agent_entity(mock_agent)
mock_context = Mock()
mock_context.operation_name = "run_agent"
mock_context.get_input.side_effect = Exception("Input error")
mock_context.get_state.return_value = None
# Execute - should not raise
entity_function(mock_context)
# Verify error was set
assert mock_context.set_result.called
result = mock_context.set_result.call_args[0][0]
assert "error" in result
async def test_run_agent_preserves_message_on_error(self) -> None:
"""Test that run_agent preserves message information on error."""
mock_agent = Mock()
mock_agent.run = AsyncMock(side_effect=Exception("Error"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
result = await entity.run_agent(
mock_context,
{"message": "Test message", "conversation_id": "conv-123", "correlation_id": "corr-entity-error-4"},
)
# Even on error, message info should be preserved
assert result["message"] == "Test message"
assert result["conversation_id"] == "conv-123"
assert result["status"] == "error"
class TestConversationHistory:
"""Test suite for conversation history tracking."""
async def test_conversation_history_has_timestamps(self) -> None:
"""Test that conversation history entries include timestamps."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
await entity.run_agent(
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-1"}
)
# Check both user and assistant messages have timestamps
for entry in entity.state.conversation_history:
timestamp = entry.additional_properties.get("timestamp")
assert timestamp is not None
# Verify timestamp is in ISO format
datetime.fromisoformat(timestamp)
async def test_conversation_history_ordering(self) -> None:
"""Test that conversation history maintains the correct order."""
mock_agent = Mock()
entity = AgentEntity(mock_agent)
mock_context = Mock()
# Send multiple messages with different responses
mock_agent.run = AsyncMock(return_value=_agent_response("Response 1"))
await entity.run_agent(
mock_context,
{"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-2a"},
)
mock_agent.run = AsyncMock(return_value=_agent_response("Response 2"))
await entity.run_agent(
mock_context,
{"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-2b"},
)
mock_agent.run = AsyncMock(return_value=_agent_response("Response 3"))
await entity.run_agent(
mock_context,
{"message": "Message 3", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-2c"},
)
# Verify order
history = entity.state.conversation_history
assert history[0].text == "Message 1"
assert history[1].text == "Response 1"
assert history[2].text == "Message 2"
assert history[3].text == "Response 2"
assert history[4].text == "Message 3"
assert history[5].text == "Response 3"
async def test_conversation_history_role_alternation(self) -> None:
"""Test that conversation history alternates between user and assistant roles."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
await entity.run_agent(
mock_context,
{"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-3a"},
)
await entity.run_agent(
mock_context,
{"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-3b"},
)
# Check role alternation
history = entity.state.conversation_history
assert _role_value(history[0]) == "user"
assert _role_value(history[1]) == "assistant"
assert _role_value(history[2]) == "user"
assert _role_value(history[3]) == "assistant"
class TestRunRequestSupport:
"""Test suite for RunRequest support in entities."""
async def test_run_agent_with_run_request_object(self) -> None:
"""Test run_agent with a RunRequest object."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
request = RunRequest(
message="Test message",
conversation_id="conv-123",
role=ChatRole.USER,
enable_tool_calls=True,
correlation_id="corr-runreq-1",
)
result = await entity.run_agent(mock_context, request)
assert result["status"] == "success"
assert result["response"] == "Response"
assert result["message"] == "Test message"
assert result["conversation_id"] == "conv-123"
async def test_run_agent_with_dict_request(self) -> None:
"""Test run_agent with a dictionary request."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
request_dict = {
"message": "Test message",
"conversation_id": "conv-456",
"role": "system",
"enable_tool_calls": False,
"correlation_id": "corr-runreq-2",
}
result = await entity.run_agent(mock_context, request_dict)
assert result["status"] == "success"
assert result["message"] == "Test message"
assert result["conversation_id"] == "conv-456"
async def test_run_agent_with_string_raises_without_correlation(self) -> None:
"""Test that run_agent rejects legacy string input without correlation ID."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
with pytest.raises(ValueError):
await entity.run_agent(mock_context, "Simple message")
async def test_run_agent_stores_role_in_history(self) -> None:
"""Test that run_agent stores the role in conversation history."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
# Send as system role
request = RunRequest(
message="System message",
conversation_id="conv-runreq-3",
role=ChatRole.SYSTEM,
correlation_id="corr-runreq-3",
)
await entity.run_agent(mock_context, request)
# Check that system role was stored
history = entity.state.conversation_history
assert _role_value(history[0]) == "system"
assert history[0].text == "System message"
async def test_run_agent_with_response_format(self) -> None:
"""Test run_agent with a JSON response format."""
mock_agent = Mock()
# Return JSON response
mock_agent.run = AsyncMock(return_value=_agent_response('{"answer": 42}'))
entity = AgentEntity(mock_agent)
mock_context = Mock()
request = RunRequest(
message="What is the answer?",
conversation_id="conv-runreq-4",
response_format=EntityStructuredResponse,
correlation_id="corr-runreq-4",
)
result = await entity.run_agent(mock_context, request)
assert result["status"] == "success"
# Should have structured_response
if "structured_response" in result:
assert result["structured_response"]["answer"] == 42
async def test_run_agent_disable_tool_calls(self) -> None:
"""Test run_agent with tool calls disabled."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity = AgentEntity(mock_agent)
mock_context = Mock()
request = RunRequest(
message="Test", conversation_id="conv-runreq-5", enable_tool_calls=False, correlation_id="corr-runreq-5"
)
result = await entity.run_agent(mock_context, request)
assert result["status"] == "success"
# Agent should have been called (tool disabling is framework-dependent)
mock_agent.run.assert_called_once()
async def test_entity_function_with_run_request_dict(self) -> None:
"""Test that the entity function handles the RunRequest dict format."""
mock_agent = Mock()
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
entity_function = create_agent_entity(mock_agent)
mock_context = Mock()
mock_context.operation_name = "run_agent"
mock_context.get_input.return_value = {
"message": "Test message",
"conversation_id": "conv-789",
"role": "user",
"enable_tool_calls": True,
"correlation_id": "corr-runreq-6",
}
mock_context.get_state.return_value = None
await asyncio.to_thread(entity_function, mock_context)
# Verify result was set
assert mock_context.set_result.called
result = mock_context.set_result.call_args[0][0]
assert result["status"] == "success"
assert result["message"] == "Test message"
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -0,0 +1,478 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for data models (AgentSessionId, RunRequest, AgentResponse, ChatRole)."""
import azure.durable_functions as df
import pytest
from pydantic import BaseModel
from agent_framework_azurefunctions._models import AgentResponse, AgentSessionId, ChatRole, RunRequest
class ModuleStructuredResponse(BaseModel):
value: int
class TestChatRole:
"""Test suite for ChatRole enum."""
def test_chat_role_values(self) -> None:
"""Test that ChatRole has correct values."""
assert ChatRole.USER == "user"
assert ChatRole.SYSTEM == "system"
assert ChatRole.ASSISTANT == "assistant"
def test_chat_role_is_string(self) -> None:
"""Test that ChatRole values are strings."""
assert isinstance(ChatRole.USER.value, str)
assert isinstance(ChatRole.SYSTEM.value, str)
assert isinstance(ChatRole.ASSISTANT.value, str)
class TestAgentSessionId:
"""Test suite for AgentSessionId."""
def test_init_creates_session_id(self) -> None:
"""Test that AgentSessionId initializes correctly."""
session_id = AgentSessionId(name="AgentEntity", key="test-key-123")
assert session_id.name == "AgentEntity"
assert session_id.key == "test-key-123"
def test_with_random_key_generates_guid(self) -> None:
"""Test that with_random_key generates a GUID."""
session_id = AgentSessionId.with_random_key(name="AgentEntity")
assert session_id.name == "AgentEntity"
assert len(session_id.key) == 32 # UUID hex is 32 chars
# Verify it's a valid hex string
int(session_id.key, 16)
def test_with_random_key_unique_keys(self) -> None:
"""Test that with_random_key generates unique keys."""
session_id1 = AgentSessionId.with_random_key(name="AgentEntity")
session_id2 = AgentSessionId.with_random_key(name="AgentEntity")
assert session_id1.key != session_id2.key
def test_to_entity_id_conversion(self) -> None:
"""Test conversion to EntityId."""
session_id = AgentSessionId(name="AgentEntity", key="test-key")
entity_id = session_id.to_entity_id()
assert isinstance(entity_id, df.EntityId)
assert entity_id.name == "dafx-AgentEntity"
assert entity_id.key == "test-key"
def test_from_entity_id_conversion(self) -> None:
"""Test creation from EntityId."""
entity_id = df.EntityId(name="dafx-AgentEntity", key="test-key")
session_id = AgentSessionId.from_entity_id(entity_id)
assert isinstance(session_id, AgentSessionId)
assert session_id.name == "AgentEntity"
assert session_id.key == "test-key"
def test_round_trip_entity_id_conversion(self) -> None:
"""Test round-trip conversion to and from EntityId."""
original = AgentSessionId(name="AgentEntity", key="test-key")
entity_id = original.to_entity_id()
restored = AgentSessionId.from_entity_id(entity_id)
assert restored.name == original.name
assert restored.key == original.key
def test_str_representation(self) -> None:
"""Test string representation."""
session_id = AgentSessionId(name="AgentEntity", key="test-key-123")
str_repr = str(session_id)
assert str_repr == "@AgentEntity@test-key-123"
def test_repr_representation(self) -> None:
"""Test repr representation."""
session_id = AgentSessionId(name="AgentEntity", key="test-key")
repr_str = repr(session_id)
assert "AgentSessionId" in repr_str
assert "AgentEntity" in repr_str
assert "test-key" in repr_str
def test_parse_valid_session_id(self) -> None:
"""Test parsing valid session ID string."""
session_id = AgentSessionId.parse("@AgentEntity@test-key-123")
assert session_id.name == "AgentEntity"
assert session_id.key == "test-key-123"
def test_parse_invalid_format_no_prefix(self) -> None:
"""Test parsing invalid format without @ prefix."""
with pytest.raises(ValueError) as exc_info:
AgentSessionId.parse("AgentEntity@test-key")
assert "Invalid agent session ID format" in str(exc_info.value)
def test_parse_invalid_format_single_part(self) -> None:
"""Test parsing invalid format with single part."""
with pytest.raises(ValueError) as exc_info:
AgentSessionId.parse("@AgentEntity")
assert "Invalid agent session ID format" in str(exc_info.value)
def test_parse_with_multiple_at_signs_in_key(self) -> None:
"""Test parsing with @ signs in the key."""
session_id = AgentSessionId.parse("@AgentEntity@key-with@symbols")
assert session_id.name == "AgentEntity"
assert session_id.key == "key-with@symbols"
def test_parse_round_trip(self) -> None:
"""Test round-trip parse and string conversion."""
original = AgentSessionId(name="AgentEntity", key="test-key")
str_repr = str(original)
parsed = AgentSessionId.parse(str_repr)
assert parsed.name == original.name
assert parsed.key == original.key
def test_to_entity_name_adds_prefix(self) -> None:
"""Test that to_entity_name adds the dafx- prefix."""
entity_name = AgentSessionId.to_entity_name("TestAgent")
assert entity_name == "dafx-TestAgent"
def test_from_entity_id_strips_prefix(self) -> None:
"""Test that from_entity_id strips the dafx- prefix."""
entity_id = df.EntityId(name="dafx-TestAgent", key="key123")
session_id = AgentSessionId.from_entity_id(entity_id)
assert session_id.name == "TestAgent"
assert session_id.key == "key123"
def test_from_entity_id_raises_without_prefix(self) -> None:
"""Test that from_entity_id raises ValueError when entity name lacks the prefix."""
entity_id = df.EntityId(name="TestAgent", key="key123")
with pytest.raises(ValueError) as exc_info:
AgentSessionId.from_entity_id(entity_id)
assert "not a valid agent session ID" in str(exc_info.value)
assert "dafx-" in str(exc_info.value)
class TestRunRequest:
"""Test suite for RunRequest."""
def test_init_with_defaults(self) -> None:
"""Test RunRequest initialization with defaults."""
request = RunRequest(message="Hello", conversation_id="conv-default")
assert request.message == "Hello"
assert request.role == ChatRole.USER
assert request.response_format is None
assert request.enable_tool_calls is True
assert request.conversation_id == "conv-default"
def test_init_with_all_fields(self) -> None:
"""Test RunRequest initialization with all fields."""
schema = ModuleStructuredResponse
request = RunRequest(
message="Hello",
conversation_id="conv-123",
role=ChatRole.SYSTEM,
response_format=schema,
enable_tool_calls=False,
)
assert request.message == "Hello"
assert request.role == ChatRole.SYSTEM
assert request.response_format is schema
assert request.enable_tool_calls is False
assert request.conversation_id == "conv-123"
def test_to_dict_with_defaults(self) -> None:
"""Test to_dict with default values."""
request = RunRequest(message="Test message", conversation_id="conv-to-dict")
data = request.to_dict()
assert data["message"] == "Test message"
assert data["enable_tool_calls"] is True
assert data["role"] == "user"
assert "response_format" not in data or data["response_format"] is None
assert data["conversation_id"] == "conv-to-dict"
def test_to_dict_with_all_fields(self) -> None:
"""Test to_dict with all fields."""
schema = ModuleStructuredResponse
request = RunRequest(
message="Hello",
conversation_id="conv-456",
role=ChatRole.ASSISTANT,
response_format=schema,
enable_tool_calls=False,
)
data = request.to_dict()
assert data["message"] == "Hello"
assert data["role"] == "assistant"
assert data["response_format"]["__response_schema_type__"] == "pydantic_model"
assert data["response_format"]["module"] == schema.__module__
assert data["response_format"]["qualname"] == schema.__qualname__
assert data["enable_tool_calls"] is False
assert data["conversation_id"] == "conv-456"
def test_from_dict_with_defaults(self) -> None:
"""Test from_dict with minimal data."""
data = {"message": "Hello", "conversation_id": "conv-from-dict"}
request = RunRequest.from_dict(data)
assert request.message == "Hello"
assert request.role == ChatRole.USER
assert request.enable_tool_calls is True
assert request.conversation_id == "conv-from-dict"
def test_from_dict_with_all_fields(self) -> None:
"""Test from_dict with all fields."""
data = {
"message": "Test",
"role": "system",
"response_format": {
"__response_schema_type__": "pydantic_model",
"module": ModuleStructuredResponse.__module__,
"qualname": ModuleStructuredResponse.__qualname__,
},
"enable_tool_calls": False,
"conversation_id": "conv-789",
}
request = RunRequest.from_dict(data)
assert request.message == "Test"
assert request.role == ChatRole.SYSTEM
assert request.response_format is ModuleStructuredResponse
assert request.enable_tool_calls is False
assert request.conversation_id == "conv-789"
def test_from_dict_invalid_role_defaults_to_user(self) -> None:
"""Test from_dict with invalid role defaults to USER."""
data = {"message": "Test", "role": "invalid_role", "conversation_id": "conv-invalid-role"}
request = RunRequest.from_dict(data)
assert request.role == ChatRole.USER
def test_from_dict_empty_message(self) -> None:
"""Test from_dict with empty message."""
data = {"conversation_id": "conv-empty"}
request = RunRequest.from_dict(data)
assert request.message == ""
assert request.role == ChatRole.USER
assert request.conversation_id == "conv-empty"
def test_round_trip_dict_conversion(self) -> None:
"""Test round-trip to_dict and from_dict."""
original = RunRequest(
message="Test message",
conversation_id="conv-123",
role=ChatRole.SYSTEM,
response_format=ModuleStructuredResponse,
enable_tool_calls=False,
)
data = original.to_dict()
restored = RunRequest.from_dict(data)
assert restored.message == original.message
assert restored.role == original.role
assert restored.response_format is ModuleStructuredResponse
assert restored.enable_tool_calls == original.enable_tool_calls
assert restored.conversation_id == original.conversation_id
def test_round_trip_with_pydantic_response_format(self) -> None:
"""Ensure Pydantic response formats serialize and deserialize properly."""
original = RunRequest(
message="Structured",
conversation_id="conv-pydantic",
response_format=ModuleStructuredResponse,
)
data = original.to_dict()
assert data["response_format"]["__response_schema_type__"] == "pydantic_model"
assert data["response_format"]["module"] == ModuleStructuredResponse.__module__
assert data["response_format"]["qualname"] == ModuleStructuredResponse.__qualname__
restored = RunRequest.from_dict(data)
assert restored.response_format is ModuleStructuredResponse
def test_init_with_correlation_id(self) -> None:
"""Test RunRequest initialization with correlation_id."""
request = RunRequest(message="Test message", conversation_id="conv-corr-init", correlation_id="corr-123")
assert request.message == "Test message"
assert request.correlation_id == "corr-123"
def test_to_dict_with_correlation_id(self) -> None:
"""Test to_dict includes correlation_id."""
request = RunRequest(message="Test", conversation_id="conv-corr-to-dict", correlation_id="corr-456")
data = request.to_dict()
assert data["message"] == "Test"
assert data["correlation_id"] == "corr-456"
def test_from_dict_with_correlation_id(self) -> None:
"""Test from_dict with correlation_id."""
data = {"message": "Test", "correlation_id": "corr-789", "conversation_id": "conv-corr-from-dict"}
request = RunRequest.from_dict(data)
assert request.message == "Test"
assert request.correlation_id == "corr-789"
assert request.conversation_id == "conv-corr-from-dict"
def test_round_trip_with_correlation_id(self) -> None:
"""Test round-trip to_dict and from_dict with correlation_id."""
original = RunRequest(
message="Test message",
conversation_id="conv-123",
role=ChatRole.SYSTEM,
correlation_id="corr-123",
)
data = original.to_dict()
restored = RunRequest.from_dict(data)
assert restored.message == original.message
assert restored.role == original.role
assert restored.correlation_id == original.correlation_id
assert restored.conversation_id == original.conversation_id
class TestAgentResponse:
"""Test suite for AgentResponse."""
def test_init_with_required_fields(self) -> None:
"""Test AgentResponse initialization with required fields."""
response = AgentResponse(
response="Test response", message="Test message", conversation_id="conv-123", status="success"
)
assert response.response == "Test response"
assert response.message == "Test message"
assert response.conversation_id == "conv-123"
assert response.status == "success"
assert response.message_count == 0
assert response.error is None
assert response.error_type is None
assert response.structured_response is None
def test_init_with_all_fields(self) -> None:
"""Test AgentResponse initialization with all fields."""
structured = {"answer": "42"}
response = AgentResponse(
response=None,
message="What is the answer?",
conversation_id="conv-456",
status="success",
message_count=5,
error=None,
error_type=None,
structured_response=structured,
)
assert response.response is None
assert response.structured_response == structured
assert response.message_count == 5
def test_to_dict_with_text_response(self) -> None:
"""Test to_dict with text response."""
response = AgentResponse(
response="Text response", message="Message", conversation_id="conv-1", status="success", message_count=3
)
data = response.to_dict()
assert data["response"] == "Text response"
assert data["message"] == "Message"
assert data["conversation_id"] == "conv-1"
assert data["status"] == "success"
assert data["message_count"] == 3
assert "structured_response" not in data
assert "error" not in data
assert "error_type" not in data
def test_to_dict_with_structured_response(self) -> None:
"""Test to_dict with structured response."""
structured = {"answer": 42, "confidence": 0.95}
response = AgentResponse(
response=None,
message="Question",
conversation_id="conv-2",
status="success",
structured_response=structured,
)
data = response.to_dict()
assert data["structured_response"] == structured
assert "response" not in data
def test_to_dict_with_error(self) -> None:
"""Test to_dict with error."""
response = AgentResponse(
response=None,
message="Failed message",
conversation_id="conv-3",
status="error",
error="Something went wrong",
error_type="ValueError",
)
data = response.to_dict()
assert data["status"] == "error"
assert data["error"] == "Something went wrong"
assert data["error_type"] == "ValueError"
def test_to_dict_prefers_structured_over_text(self) -> None:
"""Test to_dict prefers structured_response over response."""
structured = {"result": "structured"}
response = AgentResponse(
response="Text response",
message="Message",
conversation_id="conv-4",
status="success",
structured_response=structured,
)
data = response.to_dict()
assert "structured_response" in data
assert data["structured_response"] == structured
# Text response should not be included when structured is present
assert "response" not in data
class TestModelIntegration:
"""Test suite for integration between models."""
def test_run_request_with_session_id(self) -> None:
"""Test using RunRequest with AgentSessionId."""
session_id = AgentSessionId.with_random_key("AgentEntity")
request = RunRequest(message="Test message", conversation_id=str(session_id))
assert request.conversation_id is not None
assert request.conversation_id == str(session_id)
assert request.conversation_id.startswith("@AgentEntity@")
def test_response_from_run_request(self) -> None:
"""Test creating AgentResponse from RunRequest."""
request = RunRequest(message="What is 2+2?", conversation_id="conv-123", role=ChatRole.USER)
response = AgentResponse(
response="4",
message=request.message,
conversation_id=request.conversation_id,
status="success",
message_count=1,
)
assert response.message == request.message
assert response.conversation_id == request.conversation_id
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -0,0 +1,150 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for multi-agent support in AgentFunctionApp."""
from unittest.mock import Mock
import pytest
from agent_framework_azurefunctions import AgentFunctionApp
class TestMultiAgentInit:
"""Test suite for multi-agent initialization."""
def test_init_with_agents_list(self) -> None:
"""Test initialization with list of agents."""
agent1 = Mock()
agent1.name = "Agent1"
agent2 = Mock()
agent2.name = "Agent2"
app = AgentFunctionApp(agents=[agent1, agent2])
assert len(app.agents) == 2
assert "Agent1" in app.agents
assert "Agent2" in app.agents
assert app.agents["Agent1"] == agent1
assert app.agents["Agent2"] == agent2
def test_init_with_empty_agents_list(self) -> None:
"""Test initialization with empty list of agents."""
app = AgentFunctionApp(agents=[])
assert len(app.agents) == 0
def test_init_with_no_agents(self) -> None:
"""Test initialization without any agents."""
app = AgentFunctionApp()
assert len(app.agents) == 0
def test_init_with_duplicate_agent_names(self) -> None:
"""Test initialization with agents having the same name raises error."""
agent1 = Mock()
agent1.name = "TestAgent"
agent2 = Mock()
agent2.name = "TestAgent"
with pytest.raises(ValueError, match="already registered"):
AgentFunctionApp(agents=[agent1, agent2])
def test_init_with_agent_without_name(self) -> None:
"""Test initialization with agent missing name attribute raises error."""
agent1 = Mock()
agent1.name = "Agent1"
agent2 = Mock(spec=[]) # Mock without name attribute
with pytest.raises(ValueError, match="does not have a 'name' attribute"):
AgentFunctionApp(agents=[agent1, agent2])
class TestAddAgentMethod:
"""Test suite for add_agent() method."""
def test_add_agent_to_empty_app(self) -> None:
"""Test adding agent to app initialized without agents."""
app = AgentFunctionApp()
agent = Mock()
agent.name = "NewAgent"
app.add_agent(agent)
assert len(app.agents) == 1
assert "NewAgent" in app.agents
assert app.agents["NewAgent"] == agent
def test_add_multiple_agents(self) -> None:
"""Test adding multiple agents sequentially."""
app = AgentFunctionApp()
agent1 = Mock()
agent1.name = "Agent1"
agent2 = Mock()
agent2.name = "Agent2"
app.add_agent(agent1)
app.add_agent(agent2)
assert len(app.agents) == 2
assert "Agent1" in app.agents
assert "Agent2" in app.agents
def test_add_agent_with_duplicate_name_raises_error(self) -> None:
"""Test that adding agent with duplicate name raises ValueError."""
agent1 = Mock()
agent1.name = "MyAgent"
agent2 = Mock()
agent2.name = "MyAgent"
app = AgentFunctionApp(agents=[agent1])
# Try to add another agent with the same name
with pytest.raises(ValueError, match="already registered"):
app.add_agent(agent2)
def test_add_agent_to_app_with_existing_agents(self) -> None:
"""Test adding agent to app that already has agents."""
agent1 = Mock()
agent1.name = "Agent1"
agent2 = Mock()
agent2.name = "Agent2"
app = AgentFunctionApp(agents=[agent1])
app.add_agent(agent2)
assert len(app.agents) == 2
assert "Agent1" in app.agents
assert "Agent2" in app.agents
def test_add_agent_without_name_raises_error(self) -> None:
"""Test that adding agent without name attribute raises error."""
app = AgentFunctionApp()
agent = Mock(spec=[]) # Mock without name attribute
with pytest.raises(ValueError, match="does not have a 'name' attribute"):
app.add_agent(agent)
class TestHealthCheckWithMultipleAgents:
"""Test suite for health check with multiple agents."""
def test_health_check_returns_all_agents(self) -> None:
"""Test that health check returns information about all agents."""
agent1 = Mock()
agent1.name = "Agent1"
agent2 = Mock()
agent2.name = "Agent2"
app = AgentFunctionApp(agents=[agent1, agent2])
# Note: We can't easily test the actual health check endpoint without running the app
# But we can verify the agents dictionary is properly populated
assert len(app.agents) == 2
assert app.enable_health_check is True
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -0,0 +1,425 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for orchestration support (DurableAIAgent)."""
from typing import Any
from unittest.mock import Mock
import pytest
from agent_framework import AgentThread
from agent_framework_azurefunctions import DurableAIAgent, get_agent
from agent_framework_azurefunctions._models import AgentSessionId, DurableAgentThread
class TestDurableAIAgent:
"""Test suite for DurableAIAgent wrapper."""
def test_init(self) -> None:
"""Test DurableAIAgent initialization."""
mock_context = Mock()
mock_context.instance_id = "test-instance-123"
agent = DurableAIAgent(mock_context, "TestAgent")
assert agent.context == mock_context
assert agent.agent_name == "TestAgent"
def test_implements_agent_protocol(self) -> None:
"""Test that DurableAIAgent implements AgentProtocol."""
from agent_framework import AgentProtocol
mock_context = Mock()
agent = DurableAIAgent(mock_context, "TestAgent")
# Check that agent satisfies AgentProtocol
assert isinstance(agent, AgentProtocol)
def test_has_agent_protocol_properties(self) -> None:
"""Test that DurableAIAgent has AgentProtocol properties."""
mock_context = Mock()
agent = DurableAIAgent(mock_context, "TestAgent")
# AgentProtocol properties
assert hasattr(agent, "id")
assert hasattr(agent, "name")
assert hasattr(agent, "description")
assert hasattr(agent, "display_name")
# Verify values
assert agent.name == "TestAgent"
assert agent.description == "Durable agent proxy for TestAgent"
assert agent.display_name == "TestAgent"
assert agent.id is not None # Auto-generated UUID
def test_get_new_thread(self) -> None:
"""Test creating a new agent thread."""
mock_context = Mock()
mock_context.instance_id = "test-instance-456"
mock_context.new_uuid = Mock(return_value="test-guid-456")
agent = DurableAIAgent(mock_context, "WriterAgent")
thread = agent.get_new_thread()
assert isinstance(thread, DurableAgentThread)
assert thread.session_id is not None
session_id = thread.session_id
assert isinstance(session_id, AgentSessionId)
assert session_id.name == "WriterAgent"
assert session_id.key == "test-guid-456"
mock_context.new_uuid.assert_called_once()
def test_get_new_thread_deterministic(self) -> None:
"""Test that get_new_thread creates deterministic session IDs."""
mock_context = Mock()
mock_context.instance_id = "test-instance-789"
mock_context.new_uuid = Mock(side_effect=["session-guid-1", "session-guid-2"])
agent = DurableAIAgent(mock_context, "EditorAgent")
# Create multiple threads - they should have unique session IDs
thread1 = agent.get_new_thread()
thread2 = agent.get_new_thread()
assert isinstance(thread1, DurableAgentThread)
assert isinstance(thread2, DurableAgentThread)
session_id1 = thread1.session_id
session_id2 = thread2.session_id
assert session_id1 is not None and session_id2 is not None
assert isinstance(session_id1, AgentSessionId)
assert isinstance(session_id2, AgentSessionId)
assert session_id1.name == "EditorAgent"
assert session_id2.name == "EditorAgent"
assert session_id1.key == "session-guid-1"
assert session_id2.key == "session-guid-2"
assert mock_context.new_uuid.call_count == 2
def test_run_creates_entity_call(self) -> None:
"""Test that run() creates proper entity call and returns a Task."""
mock_context = Mock()
mock_context.instance_id = "test-instance-001"
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
# Mock call_entity to return a Task-like object
mock_task = Mock()
mock_task._is_scheduled = False # Task attribute that orchestration checks
mock_context.call_entity = Mock(return_value=mock_task)
agent = DurableAIAgent(mock_context, "TestAgent")
# Create thread
thread = agent.get_new_thread()
# Call run() - it should return the Task directly
task = agent.run(messages="Test message", thread=thread, enable_tool_calls=True)
# Verify run() returns the Task from call_entity
assert task == mock_task
# Verify call_entity was called with correct parameters
assert mock_context.call_entity.called
call_args = mock_context.call_entity.call_args
entity_id, operation, request = call_args[0]
assert operation == "run_agent"
assert request["message"] == "Test message"
assert request["enable_tool_calls"] is True
assert "correlation_id" in request
assert request["correlation_id"] == "correlation-guid"
assert "conversation_id" in request
assert request["conversation_id"] == "thread-guid"
def test_run_without_thread(self) -> None:
"""Test that run() works without explicit thread (creates unique session key)."""
mock_context = Mock()
mock_context.instance_id = "test-instance-002"
# Two calls to new_uuid: one for session_key, one for correlation_id
mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"])
mock_task = Mock()
mock_task._is_scheduled = False
mock_context.call_entity = Mock(return_value=mock_task)
agent = DurableAIAgent(mock_context, "TestAgent")
# Call without thread
task = agent.run(messages="Test message")
assert task == mock_task
# Verify the entity ID uses the auto-generated GUID with dafx- prefix
call_args = mock_context.call_entity.call_args
entity_id = call_args[0][0]
assert entity_id.name == "dafx-TestAgent"
assert entity_id.key == "auto-generated-guid"
# Should be called twice: once for session_key, once for correlation_id
assert mock_context.new_uuid.call_count == 2
def test_run_with_response_format(self) -> None:
"""Test that run() passes response format correctly."""
mock_context = Mock()
mock_context.instance_id = "test-instance-003"
mock_task = Mock()
mock_task._is_scheduled = False
mock_context.call_entity = Mock(return_value=mock_task)
agent = DurableAIAgent(mock_context, "TestAgent")
from pydantic import BaseModel
class SampleSchema(BaseModel):
key: str
# Create thread and call
thread = agent.get_new_thread()
task = agent.run(messages="Test message", thread=thread, response_format=SampleSchema)
assert task == mock_task
# Verify schema was passed in the call_entity arguments
call_args = mock_context.call_entity.call_args
input_data = call_args[0][2] # Third argument is input_data
assert "response_format" in input_data
assert input_data["response_format"]["__response_schema_type__"] == "pydantic_model"
assert input_data["response_format"]["module"] == SampleSchema.__module__
assert input_data["response_format"]["qualname"] == SampleSchema.__qualname__
def test_messages_to_string(self) -> None:
"""Test converting ChatMessage list to string."""
from agent_framework import ChatMessage
mock_context = Mock()
agent = DurableAIAgent(mock_context, "TestAgent")
messages = [
ChatMessage(role="user", text="Hello"),
ChatMessage(role="assistant", text="Hi there"),
ChatMessage(role="user", text="How are you?"),
]
result = agent._messages_to_string(messages)
assert result == "Hello\nHi there\nHow are you?"
def test_run_with_chat_message(self) -> None:
"""Test that run() handles ChatMessage input."""
from agent_framework import ChatMessage
mock_context = Mock()
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
mock_task = Mock()
mock_context.call_entity = Mock(return_value=mock_task)
agent = DurableAIAgent(mock_context, "TestAgent")
thread = agent.get_new_thread()
# Call with ChatMessage
msg = ChatMessage(role="user", text="Hello")
task = agent.run(messages=msg, thread=thread)
assert task == mock_task
# Verify message was converted to string
call_args = mock_context.call_entity.call_args
request = call_args[0][2]
assert request["message"] == "Hello"
def test_run_stream_raises_not_implemented(self) -> None:
"""Test that run_stream() method raises NotImplementedError."""
mock_context = Mock()
agent = DurableAIAgent(mock_context, "TestAgent")
with pytest.raises(NotImplementedError) as exc_info:
agent.run_stream("Test message")
error_msg = str(exc_info.value)
assert "Streaming is not supported" in error_msg
def test_entity_id_format(self) -> None:
"""Test that EntityId is created with correct format (name, key)."""
from azure.durable_functions import EntityId
mock_context = Mock()
mock_context.new_uuid = Mock(return_value="test-guid-789")
mock_context.call_entity = Mock(return_value=Mock())
agent = DurableAIAgent(mock_context, "WriterAgent")
thread = agent.get_new_thread()
# Call run() to trigger entity ID creation
agent.run("Test", thread=thread)
# Verify call_entity was called with correct EntityId
call_args = mock_context.call_entity.call_args
entity_id = call_args[0][0]
# EntityId should be EntityId(name="dafx-WriterAgent", key="test-guid-789")
# Which formats as "@dafx-writeragent@test-guid-789"
assert isinstance(entity_id, EntityId)
assert entity_id.name == "dafx-WriterAgent"
assert entity_id.key == "test-guid-789"
assert str(entity_id) == "@dafx-writeragent@test-guid-789"
class TestGetAgentHelper:
"""Test suite for the get_agent helper function."""
def test_get_agent_function(self) -> None:
"""Test get_agent function creates DurableAIAgent."""
mock_context = Mock()
mock_context.instance_id = "test-instance-100"
agent = get_agent(mock_context, "MyAgent")
assert isinstance(agent, DurableAIAgent)
assert agent.agent_name == "MyAgent"
assert agent.context == mock_context
class TestOrchestrationIntegration:
"""Integration tests for orchestration scenarios."""
def test_sequential_agent_calls_simulation(self) -> None:
"""Simulate sequential agent calls in an orchestration."""
mock_context = Mock()
mock_context.instance_id = "test-orchestration-001"
# new_uuid will be called 3 times:
# 1. thread creation
# 2. correlation_id for first call
# 3. correlation_id for second call
mock_context.new_uuid = Mock(side_effect=["deterministic-guid-001", "corr-1", "corr-2"])
# Track entity calls
entity_calls: list[dict[str, Any]] = []
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock:
entity_calls.append({"entity_id": str(entity_id), "operation": operation, "input": input_data})
# Return a mock Task
mock_task = Mock()
mock_task._is_scheduled = False
return mock_task
mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
# Create agent
agent = get_agent(mock_context, "WriterAgent")
# Create thread
thread = agent.get_new_thread()
# First call - returns Task
task1 = agent.run("Write something", thread=thread)
assert hasattr(task1, "_is_scheduled")
# Second call - returns Task
task2 = agent.run("Improve: something", thread=thread)
assert hasattr(task2, "_is_scheduled")
# Verify both calls used the same entity (same session key)
assert len(entity_calls) == 2
assert entity_calls[0]["entity_id"] == entity_calls[1]["entity_id"]
# EntityId format is @dafx-writeragent@deterministic-guid-001
assert entity_calls[0]["entity_id"] == "@dafx-writeragent@deterministic-guid-001"
# new_uuid called 3 times: thread + 2 correlation IDs
assert mock_context.new_uuid.call_count == 3
def test_multiple_agents_in_orchestration(self) -> None:
"""Test using multiple different agents in one orchestration."""
mock_context = Mock()
mock_context.instance_id = "test-orchestration-002"
# Mock new_uuid to return different GUIDs for each call
# Order: writer thread, editor thread, writer correlation, editor correlation
mock_context.new_uuid = Mock(side_effect=["writer-guid-001", "editor-guid-002", "writer-corr", "editor-corr"])
entity_calls: list[str] = []
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock:
entity_calls.append(str(entity_id))
mock_task = Mock()
mock_task._is_scheduled = False
return mock_task
mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
# Create multiple agents
writer = get_agent(mock_context, "WriterAgent")
editor = get_agent(mock_context, "EditorAgent")
writer_thread = writer.get_new_thread()
editor_thread = editor.get_new_thread()
# Call both agents - returns Tasks
writer_task = writer.run("Write", thread=writer_thread)
editor_task = editor.run("Edit", thread=editor_thread)
assert hasattr(writer_task, "_is_scheduled")
assert hasattr(editor_task, "_is_scheduled")
# Verify different entity IDs were used
assert len(entity_calls) == 2
# EntityId format is @dafx-agentname@guid (lowercased agent name with dafx- prefix)
assert entity_calls[0] == "@dafx-writeragent@writer-guid-001"
assert entity_calls[1] == "@dafx-editoragent@editor-guid-002"
class TestAgentThreadSerialization:
"""Test that AgentThread can be serialized for orchestration state."""
async def test_agent_thread_serialize(self) -> None:
"""Test that AgentThread can be serialized."""
thread = AgentThread()
# Serialize
serialized = await thread.serialize()
assert isinstance(serialized, dict)
assert "service_thread_id" in serialized
async def test_agent_thread_deserialize(self) -> None:
"""Test that AgentThread can be deserialized."""
thread = AgentThread()
serialized = await thread.serialize()
# Deserialize
restored = await AgentThread.deserialize(serialized)
assert isinstance(restored, AgentThread)
assert restored.service_thread_id == thread.service_thread_id
async def test_durable_agent_thread_serialization(self) -> None:
"""Test that DurableAgentThread persists session metadata during serialization."""
mock_context = Mock()
mock_context.instance_id = "test-instance-999"
mock_context.new_uuid = Mock(return_value="test-guid-999")
agent = DurableAIAgent(mock_context, "TestAgent")
thread = agent.get_new_thread()
assert isinstance(thread, DurableAgentThread)
# Verify custom attribute and property exist
assert thread.session_id is not None
session_id = thread.session_id
assert isinstance(session_id, AgentSessionId)
assert session_id.name == "TestAgent"
assert session_id.key == "test-guid-999"
# Standard serialization should still work
serialized = await thread.serialize()
assert isinstance(serialized, dict)
assert serialized.get("durable_session_id") == str(session_id)
# After deserialization, we'd need to restore the custom attribute
# This would be handled by the orchestration framework
restored = await DurableAgentThread.deserialize(serialized)
assert isinstance(restored, DurableAgentThread)
assert restored.session_id == session_id
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -0,0 +1,110 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for AgentState correlation ID tracking."""
from unittest.mock import Mock
import pytest
from agent_framework import AgentRunResponse
from agent_framework_azurefunctions._state import AgentState
class TestAgentStateCorrelationId:
"""Test suite for AgentState correlation ID tracking."""
def _create_mock_response(self, text: str = "Response") -> Mock:
"""Create a mock AgentRunResponse with the provided text."""
mock_response = Mock(spec=AgentRunResponse)
mock_response.to_dict.return_value = {"text": text, "messages": []}
return mock_response
def test_add_assistant_message_with_correlation_id(self) -> None:
state = AgentState()
state.add_user_message("Hello", correlation_id="corr-123-request")
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123")
message_metadata = state.conversation_history[-1].additional_properties or {}
assert message_metadata.get("correlation_id") == "corr-123"
response_data = state.try_get_agent_response("corr-123")
assert response_data is not None
assert response_data["content"] == "Response"
assert response_data["agent_response"] == {"text": "Response", "messages": []}
def test_try_get_agent_response_returns_response(self) -> None:
state = AgentState()
state.add_user_message("Hello", correlation_id="corr-200-request")
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-456")
response_data = state.try_get_agent_response("corr-456")
assert response_data is not None
assert response_data["content"] == "Response"
def test_try_get_agent_response_returns_none_for_missing_id(self) -> None:
state = AgentState()
state.add_user_message("Hello", correlation_id="corr-300-request")
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123")
assert state.try_get_agent_response("non-existent") is None
def test_multiple_responses_tracked_separately(self) -> None:
state = AgentState()
for index in range(3):
state.add_user_message(f"Message {index}", correlation_id=f"corr-{index}-request")
state.add_assistant_message(
f"Response {index}",
self._create_mock_response(text=f"Response {index}"),
correlation_id=f"corr-{index}",
)
for index in range(3):
payload = state.try_get_agent_response(f"corr-{index}")
assert payload is not None
assert payload["content"] == f"Response {index}"
def test_add_assistant_message_without_correlation_id(self) -> None:
state = AgentState()
state.add_user_message("Hello", correlation_id="corr-400-request")
state.add_assistant_message("Response", self._create_mock_response())
assert state.try_get_agent_response("missing") is None
assert state.last_response == "Response"
def test_to_dict_does_not_duplicate_agent_responses(self) -> None:
state = AgentState()
state.add_user_message("Hello", correlation_id="corr-500-request")
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123")
state_snapshot = state.to_dict()
assert "agent_responses" not in state_snapshot
metadata = state_snapshot["conversation_history"][-1]["additional_properties"]
assert metadata["correlation_id"] == "corr-123"
def test_restore_state_preserves_agent_response_lookup(self) -> None:
state = AgentState()
state.add_user_message("Hello", correlation_id="corr-600-request")
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123")
restored_state = AgentState()
restored_state.restore_state(state.to_dict())
payload = restored_state.try_get_agent_response("corr-123")
assert payload is not None
assert payload["content"] == "Response"
def test_reset_clears_conversation_history(self) -> None:
state = AgentState()
state.add_user_message("Hello", correlation_id="corr-700-request")
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123")
state.reset()
assert len(state.conversation_history) == 0
assert state.try_get_agent_response("corr-123") is None
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])