mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add Unit tests for Azurefunctions package (#1976)
* Add Unit tests for Azurefunctions * remove duplicate import
This commit is contained in:
committed by
GitHub
Unverified
parent
90742ba48e
commit
754491cdd3
@@ -197,9 +197,10 @@ class DurableAgentThread(AgentThread):
|
||||
**kwargs: Any,
|
||||
) -> "DurableAgentThread":
|
||||
"""Restores a durable thread, rehydrating the stored session identifier."""
|
||||
session_id_value = serialized_thread_state.get(cls._SERIALIZED_SESSION_ID_KEY)
|
||||
state_payload = dict(serialized_thread_state)
|
||||
session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None)
|
||||
thread = await super().deserialize(
|
||||
serialized_thread_state,
|
||||
state_payload,
|
||||
message_store=message_store,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,614 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Unit tests for AgentFunctionApp."""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, TypeVar
|
||||
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||
|
||||
import azure.durable_functions as df
|
||||
import azure.functions as func
|
||||
import pytest
|
||||
from agent_framework import AgentRunResponse, ChatMessage
|
||||
|
||||
from agent_framework_azurefunctions import AgentFunctionApp
|
||||
from agent_framework_azurefunctions._entities import AgentEntity, AgentState, create_agent_entity
|
||||
from agent_framework_azurefunctions._errors import IncomingRequestError
|
||||
|
||||
TFunc = TypeVar("TFunc", bound=Callable[..., Any])
|
||||
|
||||
|
||||
class TestAgentFunctionAppInit:
|
||||
"""Test suite for AgentFunctionApp initialization."""
|
||||
|
||||
def test_init_with_defaults(self) -> None:
|
||||
"""Test initialization with default parameters."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "TestAgent"
|
||||
|
||||
app = AgentFunctionApp(agents=[mock_agent])
|
||||
|
||||
assert len(app.agents) == 1
|
||||
assert "TestAgent" in app.agents
|
||||
assert app.enable_health_check is True
|
||||
|
||||
def test_init_with_custom_auth_level(self) -> None:
|
||||
"""Test initialization with custom auth level."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "TestAgent"
|
||||
|
||||
app = AgentFunctionApp(agents=[mock_agent], http_auth_level=func.AuthLevel.FUNCTION)
|
||||
|
||||
# App should be created successfully
|
||||
assert "TestAgent" in app.agents
|
||||
|
||||
def test_init_with_health_check_disabled(self) -> None:
|
||||
"""Test initialization with health check disabled."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "TestAgent"
|
||||
|
||||
app = AgentFunctionApp(agents=[mock_agent], enable_health_check=False)
|
||||
|
||||
assert app.enable_health_check is False
|
||||
|
||||
def test_init_with_http_endpoints_disabled(self) -> None:
|
||||
"""Test initialization with HTTP endpoints disabled."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "TestAgent"
|
||||
|
||||
app = AgentFunctionApp(agents=[mock_agent], enable_http_endpoints=False)
|
||||
|
||||
assert app.enable_http_endpoints is False
|
||||
|
||||
def test_init_stores_agent_reference(self) -> None:
|
||||
"""Test that agent reference is stored correctly."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "TestAgent"
|
||||
|
||||
app = AgentFunctionApp(agents=[mock_agent])
|
||||
|
||||
assert app.agents["TestAgent"].name == "TestAgent"
|
||||
|
||||
def test_add_agent_uses_specific_callback(self) -> None:
|
||||
"""Verify that a per-agent callback overrides the default."""
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "CallbackAgent"
|
||||
specific_callback = Mock()
|
||||
|
||||
with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock:
|
||||
app = AgentFunctionApp(default_callback=Mock())
|
||||
app.add_agent(mock_agent, callback=specific_callback)
|
||||
|
||||
setup_mock.assert_called_once()
|
||||
_, _, passed_callback, enable_http_endpoint = setup_mock.call_args[0]
|
||||
assert passed_callback is specific_callback
|
||||
assert enable_http_endpoint is True
|
||||
|
||||
def test_default_callback_applied_when_no_specific(self) -> None:
|
||||
"""Ensure the default callback is supplied when add_agent lacks override."""
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "DefaultAgent"
|
||||
default_callback = Mock()
|
||||
|
||||
with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock:
|
||||
app = AgentFunctionApp(default_callback=default_callback)
|
||||
app.add_agent(mock_agent)
|
||||
|
||||
setup_mock.assert_called_once()
|
||||
_, _, passed_callback, enable_http_endpoint = setup_mock.call_args[0]
|
||||
assert passed_callback is default_callback
|
||||
assert enable_http_endpoint is True
|
||||
|
||||
def test_init_with_agents_uses_default_callback(self) -> None:
|
||||
"""Agents provided in __init__ should receive the default callback."""
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "InitAgent"
|
||||
default_callback = Mock()
|
||||
|
||||
with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock:
|
||||
AgentFunctionApp(agents=[mock_agent], default_callback=default_callback)
|
||||
|
||||
setup_mock.assert_called_once()
|
||||
_, _, passed_callback, enable_http_endpoint = setup_mock.call_args[0]
|
||||
assert passed_callback is default_callback
|
||||
assert enable_http_endpoint is True
|
||||
|
||||
|
||||
class TestAgentFunctionAppSetup:
|
||||
"""Test suite for AgentFunctionApp setup and configuration."""
|
||||
|
||||
def test_app_is_dfapp_instance(self) -> None:
|
||||
"""Test that AgentFunctionApp is a DFApp instance."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "TestAgent"
|
||||
|
||||
app = AgentFunctionApp(agents=[mock_agent])
|
||||
|
||||
assert isinstance(app, df.DFApp)
|
||||
|
||||
def test_setup_creates_http_trigger(self) -> None:
|
||||
"""Test that setup creates an HTTP trigger."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "TestAgent"
|
||||
|
||||
def passthrough_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
|
||||
def decorator(func: TFunc) -> TFunc:
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
with (
|
||||
patch.object(AgentFunctionApp, "route", new=passthrough_decorator),
|
||||
patch.object(AgentFunctionApp, "durable_client_input", new=passthrough_decorator),
|
||||
patch.object(AgentFunctionApp, "entity_trigger", new=passthrough_decorator),
|
||||
):
|
||||
app = AgentFunctionApp(agents=[mock_agent])
|
||||
|
||||
# Verify agent is registered
|
||||
assert "TestAgent" in app.agents
|
||||
|
||||
def test_setup_skips_http_trigger_when_disabled(self) -> None:
|
||||
"""Test that HTTP trigger is not created when disabled."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "TestAgent"
|
||||
|
||||
captured_routes: list[str | None] = []
|
||||
|
||||
def capture_route(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
|
||||
def decorator(func: TFunc) -> TFunc:
|
||||
route_key = kwargs.get("route") if kwargs else None
|
||||
captured_routes.append(route_key)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def passthrough_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
|
||||
def decorator(func: TFunc) -> TFunc:
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
with (
|
||||
patch.object(AgentFunctionApp, "function_name", new=passthrough_decorator),
|
||||
patch.object(AgentFunctionApp, "route", new=capture_route),
|
||||
patch.object(AgentFunctionApp, "durable_client_input", new=passthrough_decorator),
|
||||
patch.object(AgentFunctionApp, "entity_trigger", new=passthrough_decorator),
|
||||
):
|
||||
app = AgentFunctionApp(agents=[mock_agent], enable_http_endpoints=False)
|
||||
|
||||
# Verify agent is registered
|
||||
assert "TestAgent" in app.agents
|
||||
|
||||
# Verify that no HTTP run route was created
|
||||
run_route = f"agents/{mock_agent.name}/run"
|
||||
assert run_route not in captured_routes
|
||||
|
||||
def test_agent_override_enables_http_route_when_app_disabled(self) -> None:
|
||||
"""Agent-level override should enable HTTP route even when app disables it."""
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "OverrideAgent"
|
||||
|
||||
with (
|
||||
patch.object(AgentFunctionApp, "_setup_http_run_route") as http_route_mock,
|
||||
patch.object(AgentFunctionApp, "_setup_agent_entity") as agent_entity_mock,
|
||||
):
|
||||
app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=False)
|
||||
app.add_agent(mock_agent, enable_http_endpoint=True)
|
||||
|
||||
http_route_mock.assert_called_once_with("OverrideAgent")
|
||||
agent_entity_mock.assert_called_once_with(mock_agent, "OverrideAgent", ANY)
|
||||
assert app.agent_http_endpoint_flags["OverrideAgent"] is True
|
||||
|
||||
def test_agent_override_disables_http_route_when_app_enabled(self) -> None:
|
||||
"""Agent-level override should disable HTTP route even when app enables it."""
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "DisabledOverride"
|
||||
|
||||
with (
|
||||
patch.object(AgentFunctionApp, "_setup_http_run_route") as http_route_mock,
|
||||
patch.object(AgentFunctionApp, "_setup_agent_entity") as agent_entity_mock,
|
||||
):
|
||||
app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=True)
|
||||
app.add_agent(mock_agent, enable_http_endpoint=False)
|
||||
|
||||
http_route_mock.assert_not_called()
|
||||
agent_entity_mock.assert_called_once_with(mock_agent, "DisabledOverride", ANY)
|
||||
assert app.agent_http_endpoint_flags["DisabledOverride"] is False
|
||||
|
||||
def test_multiple_apps_independent(self) -> None:
|
||||
"""Test that multiple AgentFunctionApp instances are independent."""
|
||||
agent1 = Mock()
|
||||
agent1.name = "Agent1"
|
||||
agent2 = Mock()
|
||||
agent2.name = "Agent2"
|
||||
|
||||
app1 = AgentFunctionApp(agents=[agent1])
|
||||
app2 = AgentFunctionApp(agents=[agent2])
|
||||
|
||||
assert app1.agents["Agent1"].name == "Agent1"
|
||||
assert app2.agents["Agent2"].name == "Agent2"
|
||||
assert "Agent1" in app1.agents
|
||||
assert "Agent2" in app2.agents
|
||||
|
||||
|
||||
class TestWaitForCompletionAndCorrelationId:
|
||||
"""Tests for wait_for_completion flag and correlation ID handling."""
|
||||
|
||||
def _create_app(self) -> AgentFunctionApp:
|
||||
mock_agent = Mock()
|
||||
mock_agent.__class__.__name__ = "MockAgent"
|
||||
mock_agent.name = "MockAgent"
|
||||
return AgentFunctionApp(agents=[mock_agent], enable_health_check=False)
|
||||
|
||||
def _make_request(
|
||||
self,
|
||||
headers: dict[str, str] | None = None,
|
||||
params: dict[str, str] | None = None,
|
||||
) -> Mock:
|
||||
request = Mock()
|
||||
request.headers = headers or {}
|
||||
request.params = params or {}
|
||||
return request
|
||||
|
||||
def test_wait_for_completion_header_true(self) -> None:
|
||||
"""Test that the wait-for-completion header is honored."""
|
||||
app = self._create_app()
|
||||
request = self._make_request(headers={"X-Wait-For-Completion": "true"})
|
||||
|
||||
assert app._should_wait_for_completion(request, {}) is True
|
||||
|
||||
def test_wait_for_completion_body_variants(self) -> None:
|
||||
"""Test that multiple payload spellings are accepted."""
|
||||
app = self._create_app()
|
||||
request = self._make_request()
|
||||
|
||||
assert app._should_wait_for_completion(request, {"wait_for_completion": "true"}) is True
|
||||
assert app._should_wait_for_completion(request, {"waitForCompletion": "1"}) is True
|
||||
assert app._should_wait_for_completion(request, {"WaitForCompletion": "no"}) is False
|
||||
|
||||
|
||||
class TestAgentEntityOperations:
|
||||
"""Test suite for entity operations."""
|
||||
|
||||
async def test_entity_run_agent_operation(self) -> None:
|
||||
"""Test that entity can run agent operation."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(
|
||||
return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Test response")])
|
||||
)
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context,
|
||||
{"message": "Test message", "conversation_id": "test-conv-123", "correlation_id": "corr-app-entity-1"},
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["response"] == "Test response"
|
||||
assert result["message"] == "Test message"
|
||||
assert result["conversation_id"] == "test-conv-123"
|
||||
assert entity.state.message_count == 1
|
||||
|
||||
async def test_entity_stores_conversation_history(self) -> None:
|
||||
"""Test that the entity stores conversation history."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(
|
||||
return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response 1")])
|
||||
)
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
# Send first message
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-app-entity-2"}
|
||||
)
|
||||
|
||||
history = entity.state.conversation_history
|
||||
assert len(history) == 2 # User + assistant
|
||||
|
||||
user_msg = history[0]
|
||||
user_role = getattr(user_msg.role, "value", user_msg.role)
|
||||
assert user_role == "user"
|
||||
assert user_msg.text == "Message 1"
|
||||
|
||||
assistant_msg = history[1]
|
||||
assistant_role = getattr(assistant_msg.role, "value", assistant_msg.role)
|
||||
assert assistant_role == "assistant"
|
||||
assert assistant_msg.text == "Response 1"
|
||||
|
||||
async def test_entity_increments_message_count(self) -> None:
|
||||
"""Test that the entity increments the message count."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(
|
||||
return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")])
|
||||
)
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
assert entity.state.message_count == 0
|
||||
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-app-entity-3a"}
|
||||
)
|
||||
assert entity.state.message_count == 1
|
||||
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-app-entity-3b"}
|
||||
)
|
||||
assert entity.state.message_count == 2
|
||||
|
||||
def test_entity_reset(self) -> None:
|
||||
"""Test that entity reset clears state."""
|
||||
mock_agent = Mock()
|
||||
entity = AgentEntity(mock_agent)
|
||||
|
||||
# Set some state
|
||||
entity.state.message_count = 10
|
||||
entity.state.last_response = "Some response"
|
||||
entity.state.conversation_history = [
|
||||
ChatMessage(role="user", text="test", additional_properties={"timestamp": "2024-01-01T00:00:00Z"})
|
||||
]
|
||||
|
||||
# Reset
|
||||
mock_context = Mock()
|
||||
entity.reset(mock_context)
|
||||
|
||||
assert entity.state.message_count == 0
|
||||
assert entity.state.last_response is None
|
||||
assert len(entity.state.conversation_history) == 0
|
||||
|
||||
|
||||
class TestAgentEntityFactory:
|
||||
"""Test suite for the entity factory function."""
|
||||
|
||||
def test_create_agent_entity_returns_function(self) -> None:
|
||||
"""Test that create_agent_entity returns a function."""
|
||||
mock_agent = Mock()
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
assert callable(entity_function)
|
||||
|
||||
def test_entity_function_handles_run_agent_operation(self) -> None:
|
||||
"""Test that the entity function handles the run_agent operation."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(
|
||||
return_value=AgentRunResponse(messages=[ChatMessage(role="assistant", text="Response")])
|
||||
)
|
||||
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
# Mock context
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "run_agent"
|
||||
mock_context.get_input.return_value = {
|
||||
"message": "Test message",
|
||||
"conversation_id": "conv-123",
|
||||
"correlation_id": "corr-app-factory-1",
|
||||
}
|
||||
mock_context.get_state.return_value = None
|
||||
|
||||
# Execute entity function
|
||||
entity_function(mock_context)
|
||||
|
||||
# Verify result was set
|
||||
assert mock_context.set_result.called
|
||||
assert mock_context.set_state.called
|
||||
|
||||
def test_entity_function_handles_reset_operation(self) -> None:
|
||||
"""Test that the entity function handles the reset operation."""
|
||||
mock_agent = Mock()
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
# Mock context
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "reset"
|
||||
mock_context.get_state.return_value = {
|
||||
"message_count": 5,
|
||||
"conversation_history": [{"role": "user", "content": "test"}],
|
||||
"last_response": "Test",
|
||||
}
|
||||
|
||||
# Execute entity function
|
||||
entity_function(mock_context)
|
||||
|
||||
# Verify result was set
|
||||
assert mock_context.set_result.called
|
||||
result_call = mock_context.set_result.call_args[0][0]
|
||||
assert result_call["status"] == "reset"
|
||||
|
||||
def test_entity_function_handles_unknown_operation(self) -> None:
|
||||
"""Test that the entity function handles an unknown operation."""
|
||||
mock_agent = Mock()
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
# Mock context with unknown operation
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "unknown_operation"
|
||||
mock_context.get_state.return_value = None
|
||||
|
||||
# Execute entity function
|
||||
entity_function(mock_context)
|
||||
|
||||
# Verify error result was set
|
||||
assert mock_context.set_result.called
|
||||
result_call = mock_context.set_result.call_args[0][0]
|
||||
assert "error" in result_call
|
||||
assert "unknown_operation" in result_call["error"]
|
||||
|
||||
def test_entity_function_restores_state(self) -> None:
|
||||
"""Test that the entity function restores state from the context."""
|
||||
mock_agent = Mock()
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
# Mock context with existing state
|
||||
existing_state = {
|
||||
"message_count": 3,
|
||||
"conversation_history": [{"role": "user", "content": "msg1"}, {"role": "assistant", "content": "resp1"}],
|
||||
"last_response": "resp1",
|
||||
}
|
||||
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "reset"
|
||||
mock_context.get_state.return_value = existing_state
|
||||
|
||||
with patch.object(AgentState, "restore_state") as restore_state_mock:
|
||||
entity_function(mock_context)
|
||||
|
||||
restore_state_mock.assert_called_once_with(existing_state)
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test suite for error handling."""
|
||||
|
||||
async def test_entity_handles_agent_error(self) -> None:
|
||||
"""Test that the entity handles agent execution errors."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=Exception("Agent error"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context, {"message": "Test message", "conversation_id": "conv-1", "correlation_id": "corr-app-error-1"}
|
||||
)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "error" in result
|
||||
assert "Agent error" in result["error"]
|
||||
assert result["error_type"] == "Exception"
|
||||
|
||||
def test_entity_function_handles_exception(self) -> None:
|
||||
"""Test that the entity function handles exceptions gracefully."""
|
||||
mock_agent = Mock()
|
||||
# Force an exception by making get_input fail
|
||||
mock_agent.run = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "run_agent"
|
||||
mock_context.get_input.side_effect = Exception("Input error")
|
||||
mock_context.get_state.return_value = None
|
||||
|
||||
# Execute entity function - should not raise
|
||||
entity_function(mock_context)
|
||||
|
||||
# Verify error result was set
|
||||
assert mock_context.set_result.called
|
||||
result_call = mock_context.set_result.call_args[0][0]
|
||||
assert "error" in result_call
|
||||
|
||||
|
||||
class TestIncomingRequestParsing:
|
||||
"""Tests for parsing run requests with JSON and plain text bodies."""
|
||||
|
||||
def _create_app(self) -> AgentFunctionApp:
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "ParserAgent"
|
||||
return AgentFunctionApp(agents=[mock_agent], enable_health_check=False)
|
||||
|
||||
def test_parse_plain_text_body(self) -> None:
|
||||
"""Test parsing a plain-text request body."""
|
||||
app = self._create_app()
|
||||
|
||||
request = Mock()
|
||||
request.get_json.side_effect = ValueError("Invalid JSON")
|
||||
request.get_body.return_value = b"Plain text message"
|
||||
|
||||
req_body, message = app._parse_incoming_request(request)
|
||||
|
||||
assert req_body == {}
|
||||
assert message == "Plain text message"
|
||||
|
||||
def test_parse_plain_text_requires_content(self) -> None:
|
||||
"""Test that plain-text requests require message content."""
|
||||
app = self._create_app()
|
||||
|
||||
request = Mock()
|
||||
request.get_json.side_effect = ValueError("Invalid JSON")
|
||||
request.get_body.return_value = b" "
|
||||
|
||||
with pytest.raises(IncomingRequestError) as exc_info:
|
||||
app._parse_incoming_request(request)
|
||||
|
||||
assert "Message is required" in str(exc_info.value)
|
||||
|
||||
def test_extract_session_key_from_query_params(self) -> None:
|
||||
"""Test session key extraction from query parameters."""
|
||||
app = self._create_app()
|
||||
|
||||
request = Mock()
|
||||
request.params = {"sessionId": "query-session"}
|
||||
req_body = {}
|
||||
|
||||
session_key = app._resolve_session_key(request, req_body)
|
||||
|
||||
assert session_key == "query-session"
|
||||
|
||||
|
||||
class TestHttpRunRoute:
|
||||
"""Tests for the HTTP run route behavior."""
|
||||
|
||||
async def test_http_run_accepts_plain_text(self) -> None:
|
||||
"""Test that the HTTP handler accepts plain-text requests."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "HttpAgent"
|
||||
|
||||
captured_handlers: dict[str | None, Callable[..., Awaitable[func.HttpResponse]]] = {}
|
||||
|
||||
def capture_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
|
||||
def decorator(func: TFunc) -> TFunc:
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def capture_route(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
|
||||
def decorator(func: TFunc) -> TFunc:
|
||||
route_key = kwargs.get("route") if kwargs else None
|
||||
captured_handlers[route_key] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
with (
|
||||
patch.object(AgentFunctionApp, "function_name", new=capture_decorator),
|
||||
patch.object(AgentFunctionApp, "route", new=capture_route),
|
||||
patch.object(AgentFunctionApp, "durable_client_input", new=capture_decorator),
|
||||
patch.object(AgentFunctionApp, "entity_trigger", new=capture_decorator),
|
||||
):
|
||||
AgentFunctionApp(agents=[mock_agent], enable_health_check=False)
|
||||
|
||||
run_route = f"agents/{mock_agent.name}/run"
|
||||
handler = captured_handlers[run_route]
|
||||
|
||||
request = Mock()
|
||||
request.headers = {}
|
||||
request.params = {}
|
||||
request.route_params = {}
|
||||
request.get_json.side_effect = ValueError("Invalid JSON")
|
||||
request.get_body.return_value = b"Plain text via HTTP"
|
||||
|
||||
client = AsyncMock()
|
||||
|
||||
response = await handler(request, client)
|
||||
|
||||
assert response.status_code == 202
|
||||
|
||||
signal_args = client.signal_entity.call_args[0]
|
||||
run_request = signal_args[2]
|
||||
|
||||
assert run_request["message"] == "Plain text via HTTP"
|
||||
assert run_request["role"] == "user"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@@ -0,0 +1,904 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Unit tests for AgentEntity and entity operations.
|
||||
|
||||
Run with: pytest tests/test_entities.py -v
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from datetime import datetime
|
||||
from typing import Any, TypeVar
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework_azurefunctions._entities import AgentEntity, create_agent_entity
|
||||
from agent_framework_azurefunctions._models import ChatRole, RunRequest
|
||||
from agent_framework_azurefunctions._state import AgentState
|
||||
|
||||
TFunc = TypeVar("TFunc", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _role_value(chat_message: ChatMessage) -> str:
|
||||
"""Helper to extract the string role from a ChatMessage."""
|
||||
role = getattr(chat_message, "role", None)
|
||||
role_value = getattr(role, "value", role)
|
||||
if role_value is None:
|
||||
return ""
|
||||
return str(role_value)
|
||||
|
||||
|
||||
def _agent_response(text: str | None) -> AgentRunResponse:
|
||||
"""Create an AgentRunResponse with a single assistant message."""
|
||||
message = (
|
||||
ChatMessage(role="assistant", text=text) if text is not None else ChatMessage(role="assistant", contents=[])
|
||||
)
|
||||
return AgentRunResponse(messages=[message])
|
||||
|
||||
|
||||
class RecordingCallback:
|
||||
"""Callback implementation capturing streaming and final responses for assertions."""
|
||||
|
||||
def __init__(self):
|
||||
self.stream_mock = AsyncMock()
|
||||
self.response_mock = AsyncMock()
|
||||
|
||||
async def on_streaming_response_update(
|
||||
self,
|
||||
update: AgentRunResponseUpdate,
|
||||
context: Any,
|
||||
) -> None:
|
||||
await self.stream_mock(update, context)
|
||||
|
||||
async def on_agent_response(self, response: AgentRunResponse, context: Any) -> None:
|
||||
await self.response_mock(response, context)
|
||||
|
||||
|
||||
class EntityStructuredResponse(BaseModel):
|
||||
answer: float
|
||||
|
||||
|
||||
class TestAgentEntityInit:
|
||||
"""Test suite for AgentEntity initialization."""
|
||||
|
||||
def test_init_creates_entity(self) -> None:
|
||||
"""Test that AgentEntity initializes correctly."""
|
||||
mock_agent = Mock()
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
|
||||
assert entity.agent == mock_agent
|
||||
assert entity.state.conversation_history == []
|
||||
assert entity.state.last_response is None
|
||||
assert entity.state.message_count == 0
|
||||
|
||||
def test_init_stores_agent_reference(self) -> None:
|
||||
"""Test that the agent reference is stored correctly."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "TestAgent"
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
|
||||
assert entity.agent.name == "TestAgent"
|
||||
|
||||
def test_init_with_different_agent_types(self) -> None:
|
||||
"""Test initialization with different agent types."""
|
||||
agent1 = Mock()
|
||||
agent1.__class__.__name__ = "AzureOpenAIAgent"
|
||||
|
||||
agent2 = Mock()
|
||||
agent2.__class__.__name__ = "CustomAgent"
|
||||
|
||||
entity1 = AgentEntity(agent1)
|
||||
entity2 = AgentEntity(agent2)
|
||||
|
||||
assert entity1.agent.__class__.__name__ == "AzureOpenAIAgent"
|
||||
assert entity2.agent.__class__.__name__ == "CustomAgent"
|
||||
|
||||
|
||||
class TestAgentEntityRunAgent:
|
||||
"""Test suite for the run_agent operation."""
|
||||
|
||||
async def test_run_agent_executes_agent(self) -> None:
|
||||
"""Test that run_agent executes the agent."""
|
||||
mock_agent = Mock()
|
||||
mock_response = _agent_response("Test response")
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context, {"message": "Test message", "conversation_id": "conv-123", "correlation_id": "corr-entity-1"}
|
||||
)
|
||||
|
||||
# Verify agent.run was called
|
||||
mock_agent.run.assert_called_once()
|
||||
_, kwargs = mock_agent.run.call_args
|
||||
sent_messages = kwargs.get("messages")
|
||||
assert isinstance(sent_messages, list)
|
||||
assert len(sent_messages) == 1
|
||||
sent_message = sent_messages[0]
|
||||
assert isinstance(sent_message, ChatMessage)
|
||||
assert sent_message.text == "Test message"
|
||||
assert _role_value(sent_message) == "user"
|
||||
|
||||
# Verify result
|
||||
assert result["status"] == "success"
|
||||
assert result["response"] == "Test response"
|
||||
assert result["message"] == "Test message"
|
||||
assert result["conversation_id"] == "conv-123"
|
||||
|
||||
async def test_run_agent_streaming_callbacks_invoked(self) -> None:
|
||||
"""Ensure streaming updates trigger callbacks and run() is not used."""
|
||||
|
||||
updates = [
|
||||
AgentRunResponseUpdate(text="Hello"),
|
||||
AgentRunResponseUpdate(text=" world"),
|
||||
]
|
||||
|
||||
async def update_generator() -> AsyncIterator[AgentRunResponseUpdate]:
|
||||
for update in updates:
|
||||
yield update
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "StreamingAgent"
|
||||
mock_agent.run_stream = Mock(return_value=update_generator())
|
||||
mock_agent.run = AsyncMock(side_effect=AssertionError("run() should not be called when streaming succeeds"))
|
||||
|
||||
callback = RecordingCallback()
|
||||
entity = AgentEntity(mock_agent, callback=callback)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context,
|
||||
{
|
||||
"message": "Tell me something",
|
||||
"conversation_id": "session-1",
|
||||
"correlation_id": "corr-stream-1",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert "Hello" in result.get("response", "")
|
||||
assert callback.stream_mock.await_count == len(updates)
|
||||
assert callback.response_mock.await_count == 1
|
||||
mock_agent.run.assert_not_called()
|
||||
|
||||
# Validate callback arguments
|
||||
stream_calls = callback.stream_mock.await_args_list
|
||||
for expected_update, recorded_call in zip(updates, stream_calls, strict=True):
|
||||
assert recorded_call.args[0] is expected_update
|
||||
context = recorded_call.args[1]
|
||||
assert context.agent_name == "StreamingAgent"
|
||||
assert context.correlation_id == "corr-stream-1"
|
||||
assert context.conversation_id == "session-1"
|
||||
assert context.request_message == "Tell me something"
|
||||
|
||||
final_call = callback.response_mock.await_args
|
||||
assert final_call is not None
|
||||
final_response, final_context = final_call.args
|
||||
assert final_context.agent_name == "StreamingAgent"
|
||||
assert final_context.correlation_id == "corr-stream-1"
|
||||
assert final_context.conversation_id == "session-1"
|
||||
assert final_context.request_message == "Tell me something"
|
||||
assert getattr(final_response, "text", "").strip()
|
||||
|
||||
async def test_run_agent_final_callback_without_streaming(self) -> None:
|
||||
"""Ensure the final callback fires even when streaming is unavailable."""
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "NonStreamingAgent"
|
||||
mock_agent.run_stream = None
|
||||
agent_response = _agent_response("Final response")
|
||||
mock_agent.run = AsyncMock(return_value=agent_response)
|
||||
|
||||
callback = RecordingCallback()
|
||||
entity = AgentEntity(mock_agent, callback=callback)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context,
|
||||
{
|
||||
"message": "Hi",
|
||||
"conversation_id": "session-2",
|
||||
"correlation_id": "corr-final-1",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result.get("response") == "Final response"
|
||||
assert callback.stream_mock.await_count == 0
|
||||
assert callback.response_mock.await_count == 1
|
||||
|
||||
final_call = callback.response_mock.await_args
|
||||
assert final_call is not None
|
||||
assert final_call.args[0] is agent_response
|
||||
final_context = final_call.args[1]
|
||||
assert final_context.agent_name == "NonStreamingAgent"
|
||||
assert final_context.correlation_id == "corr-final-1"
|
||||
assert final_context.conversation_id == "session-2"
|
||||
assert final_context.request_message == "Hi"
|
||||
|
||||
async def test_run_agent_updates_conversation_history(self) -> None:
|
||||
"""Test that run_agent updates the conversation history."""
|
||||
mock_agent = Mock()
|
||||
mock_response = _agent_response("Agent response")
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "User message", "conversation_id": "conv-1", "correlation_id": "corr-entity-2"}
|
||||
)
|
||||
|
||||
# Should have 2 entries: user message + assistant response
|
||||
history = entity.state.conversation_history
|
||||
|
||||
assert len(history) == 2
|
||||
|
||||
user_msg = history[0]
|
||||
assert _role_value(user_msg) == "user"
|
||||
assert user_msg.text == "User message"
|
||||
|
||||
assistant_msg = history[1]
|
||||
assert _role_value(assistant_msg) == "assistant"
|
||||
assert assistant_msg.text == "Agent response"
|
||||
|
||||
async def test_run_agent_increments_message_count(self) -> None:
|
||||
"""Test that run_agent increments the message count."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
assert entity.state.message_count == 0
|
||||
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-3a"}
|
||||
)
|
||||
assert entity.state.message_count == 1
|
||||
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-3b"}
|
||||
)
|
||||
assert entity.state.message_count == 2
|
||||
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 3", "conversation_id": "conv-1", "correlation_id": "corr-entity-3c"}
|
||||
)
|
||||
assert entity.state.message_count == 3
|
||||
|
||||
async def test_run_agent_stores_last_response(self) -> None:
|
||||
"""Test that run_agent stores the last response."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response 1"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-4a"}
|
||||
)
|
||||
assert entity.state.last_response == "Response 1"
|
||||
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response 2"))
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-4b"}
|
||||
)
|
||||
assert entity.state.last_response == "Response 2"
|
||||
|
||||
async def test_run_agent_with_none_conversation_id(self) -> None:
|
||||
"""Test run_agent with a None conversation identifier."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
with pytest.raises(ValueError, match="conversation_id"):
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message", "conversation_id": None, "correlation_id": "corr-entity-5"}
|
||||
)
|
||||
|
||||
async def test_run_agent_handles_response_without_text_attribute(self) -> None:
|
||||
"""Test that run_agent handles responses without a text attribute."""
|
||||
mock_agent = Mock()
|
||||
|
||||
class NoTextResponse(AgentRunResponse):
|
||||
@property
|
||||
def text(self) -> str: # type: ignore[override]
|
||||
raise AttributeError("text attribute missing")
|
||||
|
||||
mock_response = NoTextResponse(messages=[ChatMessage(role="assistant", text="ignored")])
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-6"}
|
||||
)
|
||||
|
||||
# Should handle gracefully
|
||||
assert result["status"] == "success"
|
||||
assert result["response"] == "Error extracting response"
|
||||
|
||||
async def test_run_agent_handles_none_response_text(self) -> None:
|
||||
"""Test that run_agent handles responses with None text."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response(None))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-7"}
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["response"] == "No response"
|
||||
|
||||
async def test_run_agent_multiple_conversations(self) -> None:
|
||||
"""Test that run_agent maintains history across multiple messages."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
# Send multiple messages
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-8a"}
|
||||
)
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-8b"}
|
||||
)
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 3", "conversation_id": "conv-1", "correlation_id": "corr-entity-8c"}
|
||||
)
|
||||
|
||||
history = entity.state.conversation_history
|
||||
assert len(history) == 6
|
||||
assert entity.state.message_count == 3
|
||||
|
||||
|
||||
class TestAgentEntityReset:
|
||||
"""Test suite for the reset operation."""
|
||||
|
||||
def test_reset_clears_conversation_history(self) -> None:
|
||||
"""Test that reset clears the conversation history."""
|
||||
mock_agent = Mock()
|
||||
entity = AgentEntity(mock_agent)
|
||||
|
||||
# Add some history
|
||||
entity.state.conversation_history = [
|
||||
ChatMessage(role="user", text="msg1"),
|
||||
ChatMessage(role="assistant", text="resp1"),
|
||||
]
|
||||
|
||||
mock_context = Mock()
|
||||
entity.reset(mock_context)
|
||||
|
||||
assert entity.state.conversation_history == []
|
||||
|
||||
def test_reset_clears_last_response(self) -> None:
|
||||
"""Test that reset clears the last response."""
|
||||
mock_agent = Mock()
|
||||
entity = AgentEntity(mock_agent)
|
||||
|
||||
entity.state.last_response = "Some response"
|
||||
|
||||
mock_context = Mock()
|
||||
entity.reset(mock_context)
|
||||
|
||||
assert entity.state.last_response is None
|
||||
|
||||
def test_reset_clears_message_count(self) -> None:
|
||||
"""Test that reset clears the message count."""
|
||||
mock_agent = Mock()
|
||||
entity = AgentEntity(mock_agent)
|
||||
|
||||
entity.state.message_count = 10
|
||||
|
||||
mock_context = Mock()
|
||||
entity.reset(mock_context)
|
||||
|
||||
assert entity.state.message_count == 0
|
||||
|
||||
async def test_reset_after_conversation(self) -> None:
|
||||
"""Test reset after a full conversation."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
# Have a conversation
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-10a"}
|
||||
)
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-10b"}
|
||||
)
|
||||
|
||||
# Verify state before reset
|
||||
assert entity.state.message_count == 2
|
||||
assert len(entity.state.conversation_history) == 4
|
||||
|
||||
# Reset
|
||||
entity.reset(mock_context)
|
||||
|
||||
# Verify state after reset
|
||||
assert entity.state.message_count == 0
|
||||
assert len(entity.state.conversation_history) == 0
|
||||
assert entity.state.last_response is None
|
||||
|
||||
|
||||
class TestCreateAgentEntity:
|
||||
"""Test suite for the create_agent_entity factory function."""
|
||||
|
||||
def test_create_agent_entity_returns_callable(self) -> None:
|
||||
"""Test that create_agent_entity returns a callable."""
|
||||
mock_agent = Mock()
|
||||
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
assert callable(entity_function)
|
||||
|
||||
def test_entity_function_handles_run_agent(self) -> None:
|
||||
"""Test that the entity function handles the run_agent operation."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
# Mock context
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "run_agent"
|
||||
mock_context.get_input.return_value = {
|
||||
"message": "Test message",
|
||||
"conversation_id": "conv-123",
|
||||
"correlation_id": "corr-entity-factory",
|
||||
}
|
||||
mock_context.get_state.return_value = None
|
||||
|
||||
# Execute
|
||||
entity_function(mock_context)
|
||||
|
||||
# Verify result and state were set
|
||||
assert mock_context.set_result.called
|
||||
assert mock_context.set_state.called
|
||||
|
||||
def test_entity_function_handles_reset(self) -> None:
|
||||
"""Test that the entity function handles the reset operation."""
|
||||
mock_agent = Mock()
|
||||
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
# Mock context with existing state
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "reset"
|
||||
mock_context.get_state.return_value = {
|
||||
"message_count": 5,
|
||||
"conversation_history": [
|
||||
ChatMessage(
|
||||
role="user", text="test", additional_properties={"timestamp": "2024-01-01T00:00:00Z"}
|
||||
).to_dict()
|
||||
],
|
||||
"last_response": "Test",
|
||||
}
|
||||
|
||||
# Execute
|
||||
entity_function(mock_context)
|
||||
|
||||
# Verify reset result
|
||||
assert mock_context.set_result.called
|
||||
result = mock_context.set_result.call_args[0][0]
|
||||
assert result["status"] == "reset"
|
||||
|
||||
# Verify state was cleared
|
||||
assert mock_context.set_state.called
|
||||
state = mock_context.set_state.call_args[0][0]
|
||||
assert state["message_count"] == 0
|
||||
assert state["conversation_history"] == []
|
||||
assert state["last_response"] is None
|
||||
|
||||
def test_entity_function_handles_unknown_operation(self) -> None:
|
||||
"""Test that the entity function handles unknown operations."""
|
||||
mock_agent = Mock()
|
||||
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "invalid_operation"
|
||||
mock_context.get_state.return_value = None
|
||||
|
||||
# Execute
|
||||
entity_function(mock_context)
|
||||
|
||||
# Verify error result
|
||||
assert mock_context.set_result.called
|
||||
result = mock_context.set_result.call_args[0][0]
|
||||
assert "error" in result
|
||||
assert "invalid_operation" in result["error"].lower()
|
||||
|
||||
def test_entity_function_creates_new_entity_on_first_call(self) -> None:
|
||||
"""Test that the entity function creates a new entity when no state exists."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.__class__.__name__ = "Agent"
|
||||
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "reset"
|
||||
mock_context.get_state.return_value = None # No existing state
|
||||
|
||||
# Execute
|
||||
entity_function(mock_context)
|
||||
|
||||
# Verify new entity state was created
|
||||
assert mock_context.set_result.called
|
||||
result = mock_context.set_result.call_args[0][0]
|
||||
assert result["status"] == "reset"
|
||||
assert mock_context.set_state.called
|
||||
state = mock_context.set_state.call_args[0][0]
|
||||
assert state["message_count"] == 0
|
||||
assert state["conversation_history"] == []
|
||||
|
||||
def test_entity_function_restores_existing_state(self) -> None:
|
||||
"""Test that the entity function restores existing state."""
|
||||
mock_agent = Mock()
|
||||
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
existing_state = {
|
||||
"message_count": 5,
|
||||
"conversation_history": [
|
||||
ChatMessage(
|
||||
role="user", text="msg1", additional_properties={"timestamp": "2024-01-01T00:00:00Z"}
|
||||
).to_dict(),
|
||||
ChatMessage(
|
||||
role="assistant", text="resp1", additional_properties={"timestamp": "2024-01-01T00:05:00Z"}
|
||||
).to_dict(),
|
||||
],
|
||||
"last_response": "resp1",
|
||||
}
|
||||
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "reset"
|
||||
mock_context.get_state.return_value = existing_state
|
||||
|
||||
with patch.object(AgentState, "restore_state") as restore_state_mock:
|
||||
entity_function(mock_context)
|
||||
|
||||
restore_state_mock.assert_called_once_with(existing_state)
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test suite for error handling in entities."""
|
||||
|
||||
async def test_run_agent_handles_agent_exception(self) -> None:
|
||||
"""Test that run_agent handles agent exceptions."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=Exception("Agent failed"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-error-1"}
|
||||
)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "error" in result
|
||||
assert "Agent failed" in result["error"]
|
||||
assert result["error_type"] == "Exception"
|
||||
|
||||
async def test_run_agent_handles_value_error(self) -> None:
|
||||
"""Test that run_agent handles ValueError instances."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=ValueError("Invalid input"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-error-2"}
|
||||
)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert result["error_type"] == "ValueError"
|
||||
assert "Invalid input" in result["error"]
|
||||
|
||||
async def test_run_agent_handles_timeout_error(self) -> None:
|
||||
"""Test that run_agent handles TimeoutError instances."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=TimeoutError("Request timeout"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-error-3"}
|
||||
)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert result["error_type"] == "TimeoutError"
|
||||
|
||||
def test_entity_function_handles_exception_in_operation(self) -> None:
|
||||
"""Test that the entity function handles exceptions gracefully."""
|
||||
mock_agent = Mock()
|
||||
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "run_agent"
|
||||
mock_context.get_input.side_effect = Exception("Input error")
|
||||
mock_context.get_state.return_value = None
|
||||
|
||||
# Execute - should not raise
|
||||
entity_function(mock_context)
|
||||
|
||||
# Verify error was set
|
||||
assert mock_context.set_result.called
|
||||
result = mock_context.set_result.call_args[0][0]
|
||||
assert "error" in result
|
||||
|
||||
async def test_run_agent_preserves_message_on_error(self) -> None:
|
||||
"""Test that run_agent preserves message information on error."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=Exception("Error"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context,
|
||||
{"message": "Test message", "conversation_id": "conv-123", "correlation_id": "corr-entity-error-4"},
|
||||
)
|
||||
|
||||
# Even on error, message info should be preserved
|
||||
assert result["message"] == "Test message"
|
||||
assert result["conversation_id"] == "conv-123"
|
||||
assert result["status"] == "error"
|
||||
|
||||
|
||||
class TestConversationHistory:
|
||||
"""Test suite for conversation history tracking."""
|
||||
|
||||
async def test_conversation_history_has_timestamps(self) -> None:
|
||||
"""Test that conversation history entries include timestamps."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
await entity.run_agent(
|
||||
mock_context, {"message": "Message", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-1"}
|
||||
)
|
||||
|
||||
# Check both user and assistant messages have timestamps
|
||||
for entry in entity.state.conversation_history:
|
||||
timestamp = entry.additional_properties.get("timestamp")
|
||||
assert timestamp is not None
|
||||
# Verify timestamp is in ISO format
|
||||
datetime.fromisoformat(timestamp)
|
||||
|
||||
async def test_conversation_history_ordering(self) -> None:
|
||||
"""Test that conversation history maintains the correct order."""
|
||||
mock_agent = Mock()
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
# Send multiple messages with different responses
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response 1"))
|
||||
await entity.run_agent(
|
||||
mock_context,
|
||||
{"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-2a"},
|
||||
)
|
||||
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response 2"))
|
||||
await entity.run_agent(
|
||||
mock_context,
|
||||
{"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-2b"},
|
||||
)
|
||||
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response 3"))
|
||||
await entity.run_agent(
|
||||
mock_context,
|
||||
{"message": "Message 3", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-2c"},
|
||||
)
|
||||
|
||||
# Verify order
|
||||
history = entity.state.conversation_history
|
||||
assert history[0].text == "Message 1"
|
||||
assert history[1].text == "Response 1"
|
||||
assert history[2].text == "Message 2"
|
||||
assert history[3].text == "Response 2"
|
||||
assert history[4].text == "Message 3"
|
||||
assert history[5].text == "Response 3"
|
||||
|
||||
async def test_conversation_history_role_alternation(self) -> None:
|
||||
"""Test that conversation history alternates between user and assistant roles."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
await entity.run_agent(
|
||||
mock_context,
|
||||
{"message": "Message 1", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-3a"},
|
||||
)
|
||||
await entity.run_agent(
|
||||
mock_context,
|
||||
{"message": "Message 2", "conversation_id": "conv-1", "correlation_id": "corr-entity-history-3b"},
|
||||
)
|
||||
|
||||
# Check role alternation
|
||||
history = entity.state.conversation_history
|
||||
assert _role_value(history[0]) == "user"
|
||||
assert _role_value(history[1]) == "assistant"
|
||||
assert _role_value(history[2]) == "user"
|
||||
assert _role_value(history[3]) == "assistant"
|
||||
|
||||
|
||||
class TestRunRequestSupport:
|
||||
"""Test suite for RunRequest support in entities."""
|
||||
|
||||
async def test_run_agent_with_run_request_object(self) -> None:
|
||||
"""Test run_agent with a RunRequest object."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
request = RunRequest(
|
||||
message="Test message",
|
||||
conversation_id="conv-123",
|
||||
role=ChatRole.USER,
|
||||
enable_tool_calls=True,
|
||||
correlation_id="corr-runreq-1",
|
||||
)
|
||||
|
||||
result = await entity.run_agent(mock_context, request)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["response"] == "Response"
|
||||
assert result["message"] == "Test message"
|
||||
assert result["conversation_id"] == "conv-123"
|
||||
|
||||
async def test_run_agent_with_dict_request(self) -> None:
|
||||
"""Test run_agent with a dictionary request."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
request_dict = {
|
||||
"message": "Test message",
|
||||
"conversation_id": "conv-456",
|
||||
"role": "system",
|
||||
"enable_tool_calls": False,
|
||||
"correlation_id": "corr-runreq-2",
|
||||
}
|
||||
|
||||
result = await entity.run_agent(mock_context, request_dict)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["message"] == "Test message"
|
||||
assert result["conversation_id"] == "conv-456"
|
||||
|
||||
async def test_run_agent_with_string_raises_without_correlation(self) -> None:
|
||||
"""Test that run_agent rejects legacy string input without correlation ID."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await entity.run_agent(mock_context, "Simple message")
|
||||
|
||||
async def test_run_agent_stores_role_in_history(self) -> None:
|
||||
"""Test that run_agent stores the role in conversation history."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
# Send as system role
|
||||
request = RunRequest(
|
||||
message="System message",
|
||||
conversation_id="conv-runreq-3",
|
||||
role=ChatRole.SYSTEM,
|
||||
correlation_id="corr-runreq-3",
|
||||
)
|
||||
|
||||
await entity.run_agent(mock_context, request)
|
||||
|
||||
# Check that system role was stored
|
||||
history = entity.state.conversation_history
|
||||
assert _role_value(history[0]) == "system"
|
||||
assert history[0].text == "System message"
|
||||
|
||||
async def test_run_agent_with_response_format(self) -> None:
|
||||
"""Test run_agent with a JSON response format."""
|
||||
mock_agent = Mock()
|
||||
# Return JSON response
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response('{"answer": 42}'))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
request = RunRequest(
|
||||
message="What is the answer?",
|
||||
conversation_id="conv-runreq-4",
|
||||
response_format=EntityStructuredResponse,
|
||||
correlation_id="corr-runreq-4",
|
||||
)
|
||||
|
||||
result = await entity.run_agent(mock_context, request)
|
||||
|
||||
assert result["status"] == "success"
|
||||
# Should have structured_response
|
||||
if "structured_response" in result:
|
||||
assert result["structured_response"]["answer"] == 42
|
||||
|
||||
async def test_run_agent_disable_tool_calls(self) -> None:
|
||||
"""Test run_agent with tool calls disabled."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
request = RunRequest(
|
||||
message="Test", conversation_id="conv-runreq-5", enable_tool_calls=False, correlation_id="corr-runreq-5"
|
||||
)
|
||||
|
||||
result = await entity.run_agent(mock_context, request)
|
||||
|
||||
assert result["status"] == "success"
|
||||
# Agent should have been called (tool disabling is framework-dependent)
|
||||
mock_agent.run.assert_called_once()
|
||||
|
||||
async def test_entity_function_with_run_request_dict(self) -> None:
|
||||
"""Test that the entity function handles the RunRequest dict format."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity_function = create_agent_entity(mock_agent)
|
||||
|
||||
mock_context = Mock()
|
||||
mock_context.operation_name = "run_agent"
|
||||
mock_context.get_input.return_value = {
|
||||
"message": "Test message",
|
||||
"conversation_id": "conv-789",
|
||||
"role": "user",
|
||||
"enable_tool_calls": True,
|
||||
"correlation_id": "corr-runreq-6",
|
||||
}
|
||||
mock_context.get_state.return_value = None
|
||||
|
||||
await asyncio.to_thread(entity_function, mock_context)
|
||||
|
||||
# Verify result was set
|
||||
assert mock_context.set_result.called
|
||||
result = mock_context.set_result.call_args[0][0]
|
||||
assert result["status"] == "success"
|
||||
assert result["message"] == "Test message"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@@ -0,0 +1,478 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Unit tests for data models (AgentSessionId, RunRequest, AgentResponse, ChatRole)."""
|
||||
|
||||
import azure.durable_functions as df
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework_azurefunctions._models import AgentResponse, AgentSessionId, ChatRole, RunRequest
|
||||
|
||||
|
||||
class ModuleStructuredResponse(BaseModel):
|
||||
value: int
|
||||
|
||||
|
||||
class TestChatRole:
|
||||
"""Test suite for ChatRole enum."""
|
||||
|
||||
def test_chat_role_values(self) -> None:
|
||||
"""Test that ChatRole has correct values."""
|
||||
assert ChatRole.USER == "user"
|
||||
assert ChatRole.SYSTEM == "system"
|
||||
assert ChatRole.ASSISTANT == "assistant"
|
||||
|
||||
def test_chat_role_is_string(self) -> None:
|
||||
"""Test that ChatRole values are strings."""
|
||||
assert isinstance(ChatRole.USER.value, str)
|
||||
assert isinstance(ChatRole.SYSTEM.value, str)
|
||||
assert isinstance(ChatRole.ASSISTANT.value, str)
|
||||
|
||||
|
||||
class TestAgentSessionId:
|
||||
"""Test suite for AgentSessionId."""
|
||||
|
||||
def test_init_creates_session_id(self) -> None:
|
||||
"""Test that AgentSessionId initializes correctly."""
|
||||
session_id = AgentSessionId(name="AgentEntity", key="test-key-123")
|
||||
|
||||
assert session_id.name == "AgentEntity"
|
||||
assert session_id.key == "test-key-123"
|
||||
|
||||
def test_with_random_key_generates_guid(self) -> None:
|
||||
"""Test that with_random_key generates a GUID."""
|
||||
session_id = AgentSessionId.with_random_key(name="AgentEntity")
|
||||
|
||||
assert session_id.name == "AgentEntity"
|
||||
assert len(session_id.key) == 32 # UUID hex is 32 chars
|
||||
# Verify it's a valid hex string
|
||||
int(session_id.key, 16)
|
||||
|
||||
def test_with_random_key_unique_keys(self) -> None:
|
||||
"""Test that with_random_key generates unique keys."""
|
||||
session_id1 = AgentSessionId.with_random_key(name="AgentEntity")
|
||||
session_id2 = AgentSessionId.with_random_key(name="AgentEntity")
|
||||
|
||||
assert session_id1.key != session_id2.key
|
||||
|
||||
def test_to_entity_id_conversion(self) -> None:
|
||||
"""Test conversion to EntityId."""
|
||||
session_id = AgentSessionId(name="AgentEntity", key="test-key")
|
||||
entity_id = session_id.to_entity_id()
|
||||
|
||||
assert isinstance(entity_id, df.EntityId)
|
||||
assert entity_id.name == "dafx-AgentEntity"
|
||||
assert entity_id.key == "test-key"
|
||||
|
||||
def test_from_entity_id_conversion(self) -> None:
|
||||
"""Test creation from EntityId."""
|
||||
entity_id = df.EntityId(name="dafx-AgentEntity", key="test-key")
|
||||
session_id = AgentSessionId.from_entity_id(entity_id)
|
||||
|
||||
assert isinstance(session_id, AgentSessionId)
|
||||
assert session_id.name == "AgentEntity"
|
||||
assert session_id.key == "test-key"
|
||||
|
||||
def test_round_trip_entity_id_conversion(self) -> None:
|
||||
"""Test round-trip conversion to and from EntityId."""
|
||||
original = AgentSessionId(name="AgentEntity", key="test-key")
|
||||
entity_id = original.to_entity_id()
|
||||
restored = AgentSessionId.from_entity_id(entity_id)
|
||||
|
||||
assert restored.name == original.name
|
||||
assert restored.key == original.key
|
||||
|
||||
def test_str_representation(self) -> None:
|
||||
"""Test string representation."""
|
||||
session_id = AgentSessionId(name="AgentEntity", key="test-key-123")
|
||||
str_repr = str(session_id)
|
||||
|
||||
assert str_repr == "@AgentEntity@test-key-123"
|
||||
|
||||
def test_repr_representation(self) -> None:
|
||||
"""Test repr representation."""
|
||||
session_id = AgentSessionId(name="AgentEntity", key="test-key")
|
||||
repr_str = repr(session_id)
|
||||
|
||||
assert "AgentSessionId" in repr_str
|
||||
assert "AgentEntity" in repr_str
|
||||
assert "test-key" in repr_str
|
||||
|
||||
def test_parse_valid_session_id(self) -> None:
|
||||
"""Test parsing valid session ID string."""
|
||||
session_id = AgentSessionId.parse("@AgentEntity@test-key-123")
|
||||
|
||||
assert session_id.name == "AgentEntity"
|
||||
assert session_id.key == "test-key-123"
|
||||
|
||||
def test_parse_invalid_format_no_prefix(self) -> None:
|
||||
"""Test parsing invalid format without @ prefix."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AgentSessionId.parse("AgentEntity@test-key")
|
||||
|
||||
assert "Invalid agent session ID format" in str(exc_info.value)
|
||||
|
||||
def test_parse_invalid_format_single_part(self) -> None:
|
||||
"""Test parsing invalid format with single part."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AgentSessionId.parse("@AgentEntity")
|
||||
|
||||
assert "Invalid agent session ID format" in str(exc_info.value)
|
||||
|
||||
def test_parse_with_multiple_at_signs_in_key(self) -> None:
|
||||
"""Test parsing with @ signs in the key."""
|
||||
session_id = AgentSessionId.parse("@AgentEntity@key-with@symbols")
|
||||
|
||||
assert session_id.name == "AgentEntity"
|
||||
assert session_id.key == "key-with@symbols"
|
||||
|
||||
def test_parse_round_trip(self) -> None:
|
||||
"""Test round-trip parse and string conversion."""
|
||||
original = AgentSessionId(name="AgentEntity", key="test-key")
|
||||
str_repr = str(original)
|
||||
parsed = AgentSessionId.parse(str_repr)
|
||||
|
||||
assert parsed.name == original.name
|
||||
assert parsed.key == original.key
|
||||
|
||||
def test_to_entity_name_adds_prefix(self) -> None:
|
||||
"""Test that to_entity_name adds the dafx- prefix."""
|
||||
entity_name = AgentSessionId.to_entity_name("TestAgent")
|
||||
assert entity_name == "dafx-TestAgent"
|
||||
|
||||
def test_from_entity_id_strips_prefix(self) -> None:
|
||||
"""Test that from_entity_id strips the dafx- prefix."""
|
||||
entity_id = df.EntityId(name="dafx-TestAgent", key="key123")
|
||||
session_id = AgentSessionId.from_entity_id(entity_id)
|
||||
|
||||
assert session_id.name == "TestAgent"
|
||||
assert session_id.key == "key123"
|
||||
|
||||
def test_from_entity_id_raises_without_prefix(self) -> None:
|
||||
"""Test that from_entity_id raises ValueError when entity name lacks the prefix."""
|
||||
entity_id = df.EntityId(name="TestAgent", key="key123")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AgentSessionId.from_entity_id(entity_id)
|
||||
|
||||
assert "not a valid agent session ID" in str(exc_info.value)
|
||||
assert "dafx-" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestRunRequest:
|
||||
"""Test suite for RunRequest."""
|
||||
|
||||
def test_init_with_defaults(self) -> None:
|
||||
"""Test RunRequest initialization with defaults."""
|
||||
request = RunRequest(message="Hello", conversation_id="conv-default")
|
||||
|
||||
assert request.message == "Hello"
|
||||
assert request.role == ChatRole.USER
|
||||
assert request.response_format is None
|
||||
assert request.enable_tool_calls is True
|
||||
assert request.conversation_id == "conv-default"
|
||||
|
||||
def test_init_with_all_fields(self) -> None:
|
||||
"""Test RunRequest initialization with all fields."""
|
||||
schema = ModuleStructuredResponse
|
||||
request = RunRequest(
|
||||
message="Hello",
|
||||
conversation_id="conv-123",
|
||||
role=ChatRole.SYSTEM,
|
||||
response_format=schema,
|
||||
enable_tool_calls=False,
|
||||
)
|
||||
|
||||
assert request.message == "Hello"
|
||||
assert request.role == ChatRole.SYSTEM
|
||||
assert request.response_format is schema
|
||||
assert request.enable_tool_calls is False
|
||||
assert request.conversation_id == "conv-123"
|
||||
|
||||
def test_to_dict_with_defaults(self) -> None:
|
||||
"""Test to_dict with default values."""
|
||||
request = RunRequest(message="Test message", conversation_id="conv-to-dict")
|
||||
data = request.to_dict()
|
||||
|
||||
assert data["message"] == "Test message"
|
||||
assert data["enable_tool_calls"] is True
|
||||
assert data["role"] == "user"
|
||||
assert "response_format" not in data or data["response_format"] is None
|
||||
assert data["conversation_id"] == "conv-to-dict"
|
||||
|
||||
def test_to_dict_with_all_fields(self) -> None:
|
||||
"""Test to_dict with all fields."""
|
||||
schema = ModuleStructuredResponse
|
||||
request = RunRequest(
|
||||
message="Hello",
|
||||
conversation_id="conv-456",
|
||||
role=ChatRole.ASSISTANT,
|
||||
response_format=schema,
|
||||
enable_tool_calls=False,
|
||||
)
|
||||
data = request.to_dict()
|
||||
|
||||
assert data["message"] == "Hello"
|
||||
assert data["role"] == "assistant"
|
||||
assert data["response_format"]["__response_schema_type__"] == "pydantic_model"
|
||||
assert data["response_format"]["module"] == schema.__module__
|
||||
assert data["response_format"]["qualname"] == schema.__qualname__
|
||||
assert data["enable_tool_calls"] is False
|
||||
assert data["conversation_id"] == "conv-456"
|
||||
|
||||
def test_from_dict_with_defaults(self) -> None:
|
||||
"""Test from_dict with minimal data."""
|
||||
data = {"message": "Hello", "conversation_id": "conv-from-dict"}
|
||||
request = RunRequest.from_dict(data)
|
||||
|
||||
assert request.message == "Hello"
|
||||
assert request.role == ChatRole.USER
|
||||
assert request.enable_tool_calls is True
|
||||
assert request.conversation_id == "conv-from-dict"
|
||||
|
||||
def test_from_dict_with_all_fields(self) -> None:
|
||||
"""Test from_dict with all fields."""
|
||||
data = {
|
||||
"message": "Test",
|
||||
"role": "system",
|
||||
"response_format": {
|
||||
"__response_schema_type__": "pydantic_model",
|
||||
"module": ModuleStructuredResponse.__module__,
|
||||
"qualname": ModuleStructuredResponse.__qualname__,
|
||||
},
|
||||
"enable_tool_calls": False,
|
||||
"conversation_id": "conv-789",
|
||||
}
|
||||
request = RunRequest.from_dict(data)
|
||||
|
||||
assert request.message == "Test"
|
||||
assert request.role == ChatRole.SYSTEM
|
||||
assert request.response_format is ModuleStructuredResponse
|
||||
assert request.enable_tool_calls is False
|
||||
assert request.conversation_id == "conv-789"
|
||||
|
||||
def test_from_dict_invalid_role_defaults_to_user(self) -> None:
|
||||
"""Test from_dict with invalid role defaults to USER."""
|
||||
data = {"message": "Test", "role": "invalid_role", "conversation_id": "conv-invalid-role"}
|
||||
request = RunRequest.from_dict(data)
|
||||
|
||||
assert request.role == ChatRole.USER
|
||||
|
||||
def test_from_dict_empty_message(self) -> None:
|
||||
"""Test from_dict with empty message."""
|
||||
data = {"conversation_id": "conv-empty"}
|
||||
request = RunRequest.from_dict(data)
|
||||
|
||||
assert request.message == ""
|
||||
assert request.role == ChatRole.USER
|
||||
assert request.conversation_id == "conv-empty"
|
||||
|
||||
def test_round_trip_dict_conversion(self) -> None:
|
||||
"""Test round-trip to_dict and from_dict."""
|
||||
original = RunRequest(
|
||||
message="Test message",
|
||||
conversation_id="conv-123",
|
||||
role=ChatRole.SYSTEM,
|
||||
response_format=ModuleStructuredResponse,
|
||||
enable_tool_calls=False,
|
||||
)
|
||||
|
||||
data = original.to_dict()
|
||||
restored = RunRequest.from_dict(data)
|
||||
|
||||
assert restored.message == original.message
|
||||
assert restored.role == original.role
|
||||
assert restored.response_format is ModuleStructuredResponse
|
||||
assert restored.enable_tool_calls == original.enable_tool_calls
|
||||
assert restored.conversation_id == original.conversation_id
|
||||
|
||||
def test_round_trip_with_pydantic_response_format(self) -> None:
|
||||
"""Ensure Pydantic response formats serialize and deserialize properly."""
|
||||
original = RunRequest(
|
||||
message="Structured",
|
||||
conversation_id="conv-pydantic",
|
||||
response_format=ModuleStructuredResponse,
|
||||
)
|
||||
|
||||
data = original.to_dict()
|
||||
|
||||
assert data["response_format"]["__response_schema_type__"] == "pydantic_model"
|
||||
assert data["response_format"]["module"] == ModuleStructuredResponse.__module__
|
||||
assert data["response_format"]["qualname"] == ModuleStructuredResponse.__qualname__
|
||||
|
||||
restored = RunRequest.from_dict(data)
|
||||
assert restored.response_format is ModuleStructuredResponse
|
||||
|
||||
def test_init_with_correlation_id(self) -> None:
|
||||
"""Test RunRequest initialization with correlation_id."""
|
||||
request = RunRequest(message="Test message", conversation_id="conv-corr-init", correlation_id="corr-123")
|
||||
|
||||
assert request.message == "Test message"
|
||||
assert request.correlation_id == "corr-123"
|
||||
|
||||
def test_to_dict_with_correlation_id(self) -> None:
|
||||
"""Test to_dict includes correlation_id."""
|
||||
request = RunRequest(message="Test", conversation_id="conv-corr-to-dict", correlation_id="corr-456")
|
||||
data = request.to_dict()
|
||||
|
||||
assert data["message"] == "Test"
|
||||
assert data["correlation_id"] == "corr-456"
|
||||
|
||||
def test_from_dict_with_correlation_id(self) -> None:
|
||||
"""Test from_dict with correlation_id."""
|
||||
data = {"message": "Test", "correlation_id": "corr-789", "conversation_id": "conv-corr-from-dict"}
|
||||
request = RunRequest.from_dict(data)
|
||||
|
||||
assert request.message == "Test"
|
||||
assert request.correlation_id == "corr-789"
|
||||
assert request.conversation_id == "conv-corr-from-dict"
|
||||
|
||||
def test_round_trip_with_correlation_id(self) -> None:
|
||||
"""Test round-trip to_dict and from_dict with correlation_id."""
|
||||
original = RunRequest(
|
||||
message="Test message",
|
||||
conversation_id="conv-123",
|
||||
role=ChatRole.SYSTEM,
|
||||
correlation_id="corr-123",
|
||||
)
|
||||
|
||||
data = original.to_dict()
|
||||
restored = RunRequest.from_dict(data)
|
||||
|
||||
assert restored.message == original.message
|
||||
assert restored.role == original.role
|
||||
assert restored.correlation_id == original.correlation_id
|
||||
assert restored.conversation_id == original.conversation_id
|
||||
|
||||
|
||||
class TestAgentResponse:
|
||||
"""Test suite for AgentResponse."""
|
||||
|
||||
def test_init_with_required_fields(self) -> None:
|
||||
"""Test AgentResponse initialization with required fields."""
|
||||
response = AgentResponse(
|
||||
response="Test response", message="Test message", conversation_id="conv-123", status="success"
|
||||
)
|
||||
|
||||
assert response.response == "Test response"
|
||||
assert response.message == "Test message"
|
||||
assert response.conversation_id == "conv-123"
|
||||
assert response.status == "success"
|
||||
assert response.message_count == 0
|
||||
assert response.error is None
|
||||
assert response.error_type is None
|
||||
assert response.structured_response is None
|
||||
|
||||
def test_init_with_all_fields(self) -> None:
|
||||
"""Test AgentResponse initialization with all fields."""
|
||||
structured = {"answer": "42"}
|
||||
response = AgentResponse(
|
||||
response=None,
|
||||
message="What is the answer?",
|
||||
conversation_id="conv-456",
|
||||
status="success",
|
||||
message_count=5,
|
||||
error=None,
|
||||
error_type=None,
|
||||
structured_response=structured,
|
||||
)
|
||||
|
||||
assert response.response is None
|
||||
assert response.structured_response == structured
|
||||
assert response.message_count == 5
|
||||
|
||||
def test_to_dict_with_text_response(self) -> None:
|
||||
"""Test to_dict with text response."""
|
||||
response = AgentResponse(
|
||||
response="Text response", message="Message", conversation_id="conv-1", status="success", message_count=3
|
||||
)
|
||||
data = response.to_dict()
|
||||
|
||||
assert data["response"] == "Text response"
|
||||
assert data["message"] == "Message"
|
||||
assert data["conversation_id"] == "conv-1"
|
||||
assert data["status"] == "success"
|
||||
assert data["message_count"] == 3
|
||||
assert "structured_response" not in data
|
||||
assert "error" not in data
|
||||
assert "error_type" not in data
|
||||
|
||||
def test_to_dict_with_structured_response(self) -> None:
|
||||
"""Test to_dict with structured response."""
|
||||
structured = {"answer": 42, "confidence": 0.95}
|
||||
response = AgentResponse(
|
||||
response=None,
|
||||
message="Question",
|
||||
conversation_id="conv-2",
|
||||
status="success",
|
||||
structured_response=structured,
|
||||
)
|
||||
data = response.to_dict()
|
||||
|
||||
assert data["structured_response"] == structured
|
||||
assert "response" not in data
|
||||
|
||||
def test_to_dict_with_error(self) -> None:
|
||||
"""Test to_dict with error."""
|
||||
response = AgentResponse(
|
||||
response=None,
|
||||
message="Failed message",
|
||||
conversation_id="conv-3",
|
||||
status="error",
|
||||
error="Something went wrong",
|
||||
error_type="ValueError",
|
||||
)
|
||||
data = response.to_dict()
|
||||
|
||||
assert data["status"] == "error"
|
||||
assert data["error"] == "Something went wrong"
|
||||
assert data["error_type"] == "ValueError"
|
||||
|
||||
def test_to_dict_prefers_structured_over_text(self) -> None:
|
||||
"""Test to_dict prefers structured_response over response."""
|
||||
structured = {"result": "structured"}
|
||||
response = AgentResponse(
|
||||
response="Text response",
|
||||
message="Message",
|
||||
conversation_id="conv-4",
|
||||
status="success",
|
||||
structured_response=structured,
|
||||
)
|
||||
data = response.to_dict()
|
||||
|
||||
assert "structured_response" in data
|
||||
assert data["structured_response"] == structured
|
||||
# Text response should not be included when structured is present
|
||||
assert "response" not in data
|
||||
|
||||
|
||||
class TestModelIntegration:
|
||||
"""Test suite for integration between models."""
|
||||
|
||||
def test_run_request_with_session_id(self) -> None:
|
||||
"""Test using RunRequest with AgentSessionId."""
|
||||
session_id = AgentSessionId.with_random_key("AgentEntity")
|
||||
request = RunRequest(message="Test message", conversation_id=str(session_id))
|
||||
|
||||
assert request.conversation_id is not None
|
||||
assert request.conversation_id == str(session_id)
|
||||
assert request.conversation_id.startswith("@AgentEntity@")
|
||||
|
||||
def test_response_from_run_request(self) -> None:
|
||||
"""Test creating AgentResponse from RunRequest."""
|
||||
request = RunRequest(message="What is 2+2?", conversation_id="conv-123", role=ChatRole.USER)
|
||||
|
||||
response = AgentResponse(
|
||||
response="4",
|
||||
message=request.message,
|
||||
conversation_id=request.conversation_id,
|
||||
status="success",
|
||||
message_count=1,
|
||||
)
|
||||
|
||||
assert response.message == request.message
|
||||
assert response.conversation_id == request.conversation_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@@ -0,0 +1,150 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Unit tests for multi-agent support in AgentFunctionApp."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework_azurefunctions import AgentFunctionApp
|
||||
|
||||
|
||||
class TestMultiAgentInit:
|
||||
"""Test suite for multi-agent initialization."""
|
||||
|
||||
def test_init_with_agents_list(self) -> None:
|
||||
"""Test initialization with list of agents."""
|
||||
agent1 = Mock()
|
||||
agent1.name = "Agent1"
|
||||
agent2 = Mock()
|
||||
agent2.name = "Agent2"
|
||||
|
||||
app = AgentFunctionApp(agents=[agent1, agent2])
|
||||
|
||||
assert len(app.agents) == 2
|
||||
assert "Agent1" in app.agents
|
||||
assert "Agent2" in app.agents
|
||||
assert app.agents["Agent1"] == agent1
|
||||
assert app.agents["Agent2"] == agent2
|
||||
|
||||
def test_init_with_empty_agents_list(self) -> None:
|
||||
"""Test initialization with empty list of agents."""
|
||||
app = AgentFunctionApp(agents=[])
|
||||
|
||||
assert len(app.agents) == 0
|
||||
|
||||
def test_init_with_no_agents(self) -> None:
|
||||
"""Test initialization without any agents."""
|
||||
app = AgentFunctionApp()
|
||||
|
||||
assert len(app.agents) == 0
|
||||
|
||||
def test_init_with_duplicate_agent_names(self) -> None:
|
||||
"""Test initialization with agents having the same name raises error."""
|
||||
agent1 = Mock()
|
||||
agent1.name = "TestAgent"
|
||||
agent2 = Mock()
|
||||
agent2.name = "TestAgent"
|
||||
|
||||
with pytest.raises(ValueError, match="already registered"):
|
||||
AgentFunctionApp(agents=[agent1, agent2])
|
||||
|
||||
def test_init_with_agent_without_name(self) -> None:
|
||||
"""Test initialization with agent missing name attribute raises error."""
|
||||
agent1 = Mock()
|
||||
agent1.name = "Agent1"
|
||||
agent2 = Mock(spec=[]) # Mock without name attribute
|
||||
|
||||
with pytest.raises(ValueError, match="does not have a 'name' attribute"):
|
||||
AgentFunctionApp(agents=[agent1, agent2])
|
||||
|
||||
|
||||
class TestAddAgentMethod:
|
||||
"""Test suite for add_agent() method."""
|
||||
|
||||
def test_add_agent_to_empty_app(self) -> None:
|
||||
"""Test adding agent to app initialized without agents."""
|
||||
app = AgentFunctionApp()
|
||||
|
||||
agent = Mock()
|
||||
agent.name = "NewAgent"
|
||||
|
||||
app.add_agent(agent)
|
||||
|
||||
assert len(app.agents) == 1
|
||||
assert "NewAgent" in app.agents
|
||||
assert app.agents["NewAgent"] == agent
|
||||
|
||||
def test_add_multiple_agents(self) -> None:
|
||||
"""Test adding multiple agents sequentially."""
|
||||
app = AgentFunctionApp()
|
||||
|
||||
agent1 = Mock()
|
||||
agent1.name = "Agent1"
|
||||
agent2 = Mock()
|
||||
agent2.name = "Agent2"
|
||||
|
||||
app.add_agent(agent1)
|
||||
app.add_agent(agent2)
|
||||
|
||||
assert len(app.agents) == 2
|
||||
assert "Agent1" in app.agents
|
||||
assert "Agent2" in app.agents
|
||||
|
||||
def test_add_agent_with_duplicate_name_raises_error(self) -> None:
|
||||
"""Test that adding agent with duplicate name raises ValueError."""
|
||||
agent1 = Mock()
|
||||
agent1.name = "MyAgent"
|
||||
agent2 = Mock()
|
||||
agent2.name = "MyAgent"
|
||||
|
||||
app = AgentFunctionApp(agents=[agent1])
|
||||
|
||||
# Try to add another agent with the same name
|
||||
with pytest.raises(ValueError, match="already registered"):
|
||||
app.add_agent(agent2)
|
||||
|
||||
def test_add_agent_to_app_with_existing_agents(self) -> None:
|
||||
"""Test adding agent to app that already has agents."""
|
||||
agent1 = Mock()
|
||||
agent1.name = "Agent1"
|
||||
agent2 = Mock()
|
||||
agent2.name = "Agent2"
|
||||
|
||||
app = AgentFunctionApp(agents=[agent1])
|
||||
app.add_agent(agent2)
|
||||
|
||||
assert len(app.agents) == 2
|
||||
assert "Agent1" in app.agents
|
||||
assert "Agent2" in app.agents
|
||||
|
||||
def test_add_agent_without_name_raises_error(self) -> None:
|
||||
"""Test that adding agent without name attribute raises error."""
|
||||
app = AgentFunctionApp()
|
||||
|
||||
agent = Mock(spec=[]) # Mock without name attribute
|
||||
|
||||
with pytest.raises(ValueError, match="does not have a 'name' attribute"):
|
||||
app.add_agent(agent)
|
||||
|
||||
|
||||
class TestHealthCheckWithMultipleAgents:
|
||||
"""Test suite for health check with multiple agents."""
|
||||
|
||||
def test_health_check_returns_all_agents(self) -> None:
|
||||
"""Test that health check returns information about all agents."""
|
||||
agent1 = Mock()
|
||||
agent1.name = "Agent1"
|
||||
agent2 = Mock()
|
||||
agent2.name = "Agent2"
|
||||
|
||||
app = AgentFunctionApp(agents=[agent1, agent2])
|
||||
|
||||
# Note: We can't easily test the actual health check endpoint without running the app
|
||||
# But we can verify the agents dictionary is properly populated
|
||||
assert len(app.agents) == 2
|
||||
assert app.enable_health_check is True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@@ -0,0 +1,425 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Unit tests for orchestration support (DurableAIAgent)."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentThread
|
||||
|
||||
from agent_framework_azurefunctions import DurableAIAgent, get_agent
|
||||
from agent_framework_azurefunctions._models import AgentSessionId, DurableAgentThread
|
||||
|
||||
|
||||
class TestDurableAIAgent:
|
||||
"""Test suite for DurableAIAgent wrapper."""
|
||||
|
||||
def test_init(self) -> None:
|
||||
"""Test DurableAIAgent initialization."""
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-instance-123"
|
||||
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
|
||||
assert agent.context == mock_context
|
||||
assert agent.agent_name == "TestAgent"
|
||||
|
||||
def test_implements_agent_protocol(self) -> None:
|
||||
"""Test that DurableAIAgent implements AgentProtocol."""
|
||||
from agent_framework import AgentProtocol
|
||||
|
||||
mock_context = Mock()
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
|
||||
# Check that agent satisfies AgentProtocol
|
||||
assert isinstance(agent, AgentProtocol)
|
||||
|
||||
def test_has_agent_protocol_properties(self) -> None:
|
||||
"""Test that DurableAIAgent has AgentProtocol properties."""
|
||||
mock_context = Mock()
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
|
||||
# AgentProtocol properties
|
||||
assert hasattr(agent, "id")
|
||||
assert hasattr(agent, "name")
|
||||
assert hasattr(agent, "description")
|
||||
assert hasattr(agent, "display_name")
|
||||
|
||||
# Verify values
|
||||
assert agent.name == "TestAgent"
|
||||
assert agent.description == "Durable agent proxy for TestAgent"
|
||||
assert agent.display_name == "TestAgent"
|
||||
assert agent.id is not None # Auto-generated UUID
|
||||
|
||||
def test_get_new_thread(self) -> None:
|
||||
"""Test creating a new agent thread."""
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-instance-456"
|
||||
mock_context.new_uuid = Mock(return_value="test-guid-456")
|
||||
|
||||
agent = DurableAIAgent(mock_context, "WriterAgent")
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
assert isinstance(thread, DurableAgentThread)
|
||||
assert thread.session_id is not None
|
||||
session_id = thread.session_id
|
||||
assert isinstance(session_id, AgentSessionId)
|
||||
assert session_id.name == "WriterAgent"
|
||||
assert session_id.key == "test-guid-456"
|
||||
mock_context.new_uuid.assert_called_once()
|
||||
|
||||
def test_get_new_thread_deterministic(self) -> None:
|
||||
"""Test that get_new_thread creates deterministic session IDs."""
|
||||
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-instance-789"
|
||||
mock_context.new_uuid = Mock(side_effect=["session-guid-1", "session-guid-2"])
|
||||
|
||||
agent = DurableAIAgent(mock_context, "EditorAgent")
|
||||
|
||||
# Create multiple threads - they should have unique session IDs
|
||||
thread1 = agent.get_new_thread()
|
||||
thread2 = agent.get_new_thread()
|
||||
|
||||
assert isinstance(thread1, DurableAgentThread)
|
||||
assert isinstance(thread2, DurableAgentThread)
|
||||
|
||||
session_id1 = thread1.session_id
|
||||
session_id2 = thread2.session_id
|
||||
assert session_id1 is not None and session_id2 is not None
|
||||
assert isinstance(session_id1, AgentSessionId)
|
||||
assert isinstance(session_id2, AgentSessionId)
|
||||
assert session_id1.name == "EditorAgent"
|
||||
assert session_id2.name == "EditorAgent"
|
||||
assert session_id1.key == "session-guid-1"
|
||||
assert session_id2.key == "session-guid-2"
|
||||
assert mock_context.new_uuid.call_count == 2
|
||||
|
||||
def test_run_creates_entity_call(self) -> None:
|
||||
"""Test that run() creates proper entity call and returns a Task."""
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-instance-001"
|
||||
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
|
||||
|
||||
# Mock call_entity to return a Task-like object
|
||||
mock_task = Mock()
|
||||
mock_task._is_scheduled = False # Task attribute that orchestration checks
|
||||
|
||||
mock_context.call_entity = Mock(return_value=mock_task)
|
||||
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
|
||||
# Create thread
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# Call run() - it should return the Task directly
|
||||
task = agent.run(messages="Test message", thread=thread, enable_tool_calls=True)
|
||||
|
||||
# Verify run() returns the Task from call_entity
|
||||
assert task == mock_task
|
||||
|
||||
# Verify call_entity was called with correct parameters
|
||||
assert mock_context.call_entity.called
|
||||
call_args = mock_context.call_entity.call_args
|
||||
entity_id, operation, request = call_args[0]
|
||||
|
||||
assert operation == "run_agent"
|
||||
assert request["message"] == "Test message"
|
||||
assert request["enable_tool_calls"] is True
|
||||
assert "correlation_id" in request
|
||||
assert request["correlation_id"] == "correlation-guid"
|
||||
assert "conversation_id" in request
|
||||
assert request["conversation_id"] == "thread-guid"
|
||||
|
||||
def test_run_without_thread(self) -> None:
|
||||
"""Test that run() works without explicit thread (creates unique session key)."""
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-instance-002"
|
||||
# Two calls to new_uuid: one for session_key, one for correlation_id
|
||||
mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"])
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task._is_scheduled = False
|
||||
mock_context.call_entity = Mock(return_value=mock_task)
|
||||
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
|
||||
# Call without thread
|
||||
task = agent.run(messages="Test message")
|
||||
|
||||
assert task == mock_task
|
||||
|
||||
# Verify the entity ID uses the auto-generated GUID with dafx- prefix
|
||||
call_args = mock_context.call_entity.call_args
|
||||
entity_id = call_args[0][0]
|
||||
assert entity_id.name == "dafx-TestAgent"
|
||||
assert entity_id.key == "auto-generated-guid"
|
||||
# Should be called twice: once for session_key, once for correlation_id
|
||||
assert mock_context.new_uuid.call_count == 2
|
||||
|
||||
def test_run_with_response_format(self) -> None:
|
||||
"""Test that run() passes response format correctly."""
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-instance-003"
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task._is_scheduled = False
|
||||
mock_context.call_entity = Mock(return_value=mock_task)
|
||||
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
class SampleSchema(BaseModel):
|
||||
key: str
|
||||
|
||||
# Create thread and call
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
task = agent.run(messages="Test message", thread=thread, response_format=SampleSchema)
|
||||
|
||||
assert task == mock_task
|
||||
|
||||
# Verify schema was passed in the call_entity arguments
|
||||
call_args = mock_context.call_entity.call_args
|
||||
input_data = call_args[0][2] # Third argument is input_data
|
||||
assert "response_format" in input_data
|
||||
assert input_data["response_format"]["__response_schema_type__"] == "pydantic_model"
|
||||
assert input_data["response_format"]["module"] == SampleSchema.__module__
|
||||
assert input_data["response_format"]["qualname"] == SampleSchema.__qualname__
|
||||
|
||||
def test_messages_to_string(self) -> None:
|
||||
"""Test converting ChatMessage list to string."""
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
mock_context = Mock()
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
|
||||
messages = [
|
||||
ChatMessage(role="user", text="Hello"),
|
||||
ChatMessage(role="assistant", text="Hi there"),
|
||||
ChatMessage(role="user", text="How are you?"),
|
||||
]
|
||||
|
||||
result = agent._messages_to_string(messages)
|
||||
|
||||
assert result == "Hello\nHi there\nHow are you?"
|
||||
|
||||
def test_run_with_chat_message(self) -> None:
|
||||
"""Test that run() handles ChatMessage input."""
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
mock_context = Mock()
|
||||
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
|
||||
mock_task = Mock()
|
||||
mock_context.call_entity = Mock(return_value=mock_task)
|
||||
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# Call with ChatMessage
|
||||
msg = ChatMessage(role="user", text="Hello")
|
||||
task = agent.run(messages=msg, thread=thread)
|
||||
|
||||
assert task == mock_task
|
||||
|
||||
# Verify message was converted to string
|
||||
call_args = mock_context.call_entity.call_args
|
||||
request = call_args[0][2]
|
||||
assert request["message"] == "Hello"
|
||||
|
||||
def test_run_stream_raises_not_implemented(self) -> None:
|
||||
"""Test that run_stream() method raises NotImplementedError."""
|
||||
mock_context = Mock()
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
|
||||
with pytest.raises(NotImplementedError) as exc_info:
|
||||
agent.run_stream("Test message")
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Streaming is not supported" in error_msg
|
||||
|
||||
def test_entity_id_format(self) -> None:
|
||||
"""Test that EntityId is created with correct format (name, key)."""
|
||||
from azure.durable_functions import EntityId
|
||||
|
||||
mock_context = Mock()
|
||||
mock_context.new_uuid = Mock(return_value="test-guid-789")
|
||||
mock_context.call_entity = Mock(return_value=Mock())
|
||||
|
||||
agent = DurableAIAgent(mock_context, "WriterAgent")
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# Call run() to trigger entity ID creation
|
||||
agent.run("Test", thread=thread)
|
||||
|
||||
# Verify call_entity was called with correct EntityId
|
||||
call_args = mock_context.call_entity.call_args
|
||||
entity_id = call_args[0][0]
|
||||
|
||||
# EntityId should be EntityId(name="dafx-WriterAgent", key="test-guid-789")
|
||||
# Which formats as "@dafx-writeragent@test-guid-789"
|
||||
assert isinstance(entity_id, EntityId)
|
||||
assert entity_id.name == "dafx-WriterAgent"
|
||||
assert entity_id.key == "test-guid-789"
|
||||
assert str(entity_id) == "@dafx-writeragent@test-guid-789"
|
||||
|
||||
|
||||
class TestGetAgentHelper:
|
||||
"""Test suite for the get_agent helper function."""
|
||||
|
||||
def test_get_agent_function(self) -> None:
|
||||
"""Test get_agent function creates DurableAIAgent."""
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-instance-100"
|
||||
|
||||
agent = get_agent(mock_context, "MyAgent")
|
||||
|
||||
assert isinstance(agent, DurableAIAgent)
|
||||
assert agent.agent_name == "MyAgent"
|
||||
assert agent.context == mock_context
|
||||
|
||||
|
||||
class TestOrchestrationIntegration:
|
||||
"""Integration tests for orchestration scenarios."""
|
||||
|
||||
def test_sequential_agent_calls_simulation(self) -> None:
|
||||
"""Simulate sequential agent calls in an orchestration."""
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-orchestration-001"
|
||||
# new_uuid will be called 3 times:
|
||||
# 1. thread creation
|
||||
# 2. correlation_id for first call
|
||||
# 3. correlation_id for second call
|
||||
mock_context.new_uuid = Mock(side_effect=["deterministic-guid-001", "corr-1", "corr-2"])
|
||||
|
||||
# Track entity calls
|
||||
entity_calls: list[dict[str, Any]] = []
|
||||
|
||||
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock:
|
||||
entity_calls.append({"entity_id": str(entity_id), "operation": operation, "input": input_data})
|
||||
|
||||
# Return a mock Task
|
||||
mock_task = Mock()
|
||||
mock_task._is_scheduled = False
|
||||
return mock_task
|
||||
|
||||
mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
|
||||
|
||||
# Create agent
|
||||
agent = get_agent(mock_context, "WriterAgent")
|
||||
|
||||
# Create thread
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# First call - returns Task
|
||||
task1 = agent.run("Write something", thread=thread)
|
||||
assert hasattr(task1, "_is_scheduled")
|
||||
|
||||
# Second call - returns Task
|
||||
task2 = agent.run("Improve: something", thread=thread)
|
||||
assert hasattr(task2, "_is_scheduled")
|
||||
|
||||
# Verify both calls used the same entity (same session key)
|
||||
assert len(entity_calls) == 2
|
||||
assert entity_calls[0]["entity_id"] == entity_calls[1]["entity_id"]
|
||||
# EntityId format is @dafx-writeragent@deterministic-guid-001
|
||||
assert entity_calls[0]["entity_id"] == "@dafx-writeragent@deterministic-guid-001"
|
||||
# new_uuid called 3 times: thread + 2 correlation IDs
|
||||
assert mock_context.new_uuid.call_count == 3
|
||||
|
||||
def test_multiple_agents_in_orchestration(self) -> None:
|
||||
"""Test using multiple different agents in one orchestration."""
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-orchestration-002"
|
||||
# Mock new_uuid to return different GUIDs for each call
|
||||
# Order: writer thread, editor thread, writer correlation, editor correlation
|
||||
mock_context.new_uuid = Mock(side_effect=["writer-guid-001", "editor-guid-002", "writer-corr", "editor-corr"])
|
||||
|
||||
entity_calls: list[str] = []
|
||||
|
||||
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock:
|
||||
entity_calls.append(str(entity_id))
|
||||
mock_task = Mock()
|
||||
mock_task._is_scheduled = False
|
||||
return mock_task
|
||||
|
||||
mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
|
||||
|
||||
# Create multiple agents
|
||||
writer = get_agent(mock_context, "WriterAgent")
|
||||
editor = get_agent(mock_context, "EditorAgent")
|
||||
|
||||
writer_thread = writer.get_new_thread()
|
||||
editor_thread = editor.get_new_thread()
|
||||
|
||||
# Call both agents - returns Tasks
|
||||
writer_task = writer.run("Write", thread=writer_thread)
|
||||
editor_task = editor.run("Edit", thread=editor_thread)
|
||||
|
||||
assert hasattr(writer_task, "_is_scheduled")
|
||||
assert hasattr(editor_task, "_is_scheduled")
|
||||
|
||||
# Verify different entity IDs were used
|
||||
assert len(entity_calls) == 2
|
||||
# EntityId format is @dafx-agentname@guid (lowercased agent name with dafx- prefix)
|
||||
assert entity_calls[0] == "@dafx-writeragent@writer-guid-001"
|
||||
assert entity_calls[1] == "@dafx-editoragent@editor-guid-002"
|
||||
|
||||
|
||||
class TestAgentThreadSerialization:
|
||||
"""Test that AgentThread can be serialized for orchestration state."""
|
||||
|
||||
async def test_agent_thread_serialize(self) -> None:
|
||||
"""Test that AgentThread can be serialized."""
|
||||
thread = AgentThread()
|
||||
|
||||
# Serialize
|
||||
serialized = await thread.serialize()
|
||||
|
||||
assert isinstance(serialized, dict)
|
||||
assert "service_thread_id" in serialized
|
||||
|
||||
async def test_agent_thread_deserialize(self) -> None:
|
||||
"""Test that AgentThread can be deserialized."""
|
||||
thread = AgentThread()
|
||||
serialized = await thread.serialize()
|
||||
|
||||
# Deserialize
|
||||
restored = await AgentThread.deserialize(serialized)
|
||||
|
||||
assert isinstance(restored, AgentThread)
|
||||
assert restored.service_thread_id == thread.service_thread_id
|
||||
|
||||
async def test_durable_agent_thread_serialization(self) -> None:
|
||||
"""Test that DurableAgentThread persists session metadata during serialization."""
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-instance-999"
|
||||
mock_context.new_uuid = Mock(return_value="test-guid-999")
|
||||
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
assert isinstance(thread, DurableAgentThread)
|
||||
# Verify custom attribute and property exist
|
||||
assert thread.session_id is not None
|
||||
session_id = thread.session_id
|
||||
assert isinstance(session_id, AgentSessionId)
|
||||
assert session_id.name == "TestAgent"
|
||||
assert session_id.key == "test-guid-999"
|
||||
|
||||
# Standard serialization should still work
|
||||
serialized = await thread.serialize()
|
||||
assert isinstance(serialized, dict)
|
||||
assert serialized.get("durable_session_id") == str(session_id)
|
||||
|
||||
# After deserialization, we'd need to restore the custom attribute
|
||||
# This would be handled by the orchestration framework
|
||||
restored = await DurableAgentThread.deserialize(serialized)
|
||||
assert isinstance(restored, DurableAgentThread)
|
||||
assert restored.session_id == session_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@@ -0,0 +1,110 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Unit tests for AgentState correlation ID tracking."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentRunResponse
|
||||
|
||||
from agent_framework_azurefunctions._state import AgentState
|
||||
|
||||
|
||||
class TestAgentStateCorrelationId:
|
||||
"""Test suite for AgentState correlation ID tracking."""
|
||||
|
||||
def _create_mock_response(self, text: str = "Response") -> Mock:
|
||||
"""Create a mock AgentRunResponse with the provided text."""
|
||||
mock_response = Mock(spec=AgentRunResponse)
|
||||
mock_response.to_dict.return_value = {"text": text, "messages": []}
|
||||
return mock_response
|
||||
|
||||
def test_add_assistant_message_with_correlation_id(self) -> None:
|
||||
state = AgentState()
|
||||
state.add_user_message("Hello", correlation_id="corr-123-request")
|
||||
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123")
|
||||
message_metadata = state.conversation_history[-1].additional_properties or {}
|
||||
assert message_metadata.get("correlation_id") == "corr-123"
|
||||
|
||||
response_data = state.try_get_agent_response("corr-123")
|
||||
assert response_data is not None
|
||||
assert response_data["content"] == "Response"
|
||||
assert response_data["agent_response"] == {"text": "Response", "messages": []}
|
||||
|
||||
def test_try_get_agent_response_returns_response(self) -> None:
|
||||
state = AgentState()
|
||||
state.add_user_message("Hello", correlation_id="corr-200-request")
|
||||
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-456")
|
||||
|
||||
response_data = state.try_get_agent_response("corr-456")
|
||||
|
||||
assert response_data is not None
|
||||
assert response_data["content"] == "Response"
|
||||
|
||||
def test_try_get_agent_response_returns_none_for_missing_id(self) -> None:
|
||||
state = AgentState()
|
||||
state.add_user_message("Hello", correlation_id="corr-300-request")
|
||||
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123")
|
||||
|
||||
assert state.try_get_agent_response("non-existent") is None
|
||||
|
||||
def test_multiple_responses_tracked_separately(self) -> None:
|
||||
state = AgentState()
|
||||
|
||||
for index in range(3):
|
||||
state.add_user_message(f"Message {index}", correlation_id=f"corr-{index}-request")
|
||||
state.add_assistant_message(
|
||||
f"Response {index}",
|
||||
self._create_mock_response(text=f"Response {index}"),
|
||||
correlation_id=f"corr-{index}",
|
||||
)
|
||||
|
||||
for index in range(3):
|
||||
payload = state.try_get_agent_response(f"corr-{index}")
|
||||
assert payload is not None
|
||||
assert payload["content"] == f"Response {index}"
|
||||
|
||||
def test_add_assistant_message_without_correlation_id(self) -> None:
|
||||
state = AgentState()
|
||||
state.add_user_message("Hello", correlation_id="corr-400-request")
|
||||
state.add_assistant_message("Response", self._create_mock_response())
|
||||
|
||||
assert state.try_get_agent_response("missing") is None
|
||||
assert state.last_response == "Response"
|
||||
|
||||
def test_to_dict_does_not_duplicate_agent_responses(self) -> None:
|
||||
state = AgentState()
|
||||
state.add_user_message("Hello", correlation_id="corr-500-request")
|
||||
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123")
|
||||
|
||||
state_snapshot = state.to_dict()
|
||||
|
||||
assert "agent_responses" not in state_snapshot
|
||||
metadata = state_snapshot["conversation_history"][-1]["additional_properties"]
|
||||
assert metadata["correlation_id"] == "corr-123"
|
||||
|
||||
def test_restore_state_preserves_agent_response_lookup(self) -> None:
|
||||
state = AgentState()
|
||||
state.add_user_message("Hello", correlation_id="corr-600-request")
|
||||
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123")
|
||||
|
||||
restored_state = AgentState()
|
||||
restored_state.restore_state(state.to_dict())
|
||||
|
||||
payload = restored_state.try_get_agent_response("corr-123")
|
||||
assert payload is not None
|
||||
assert payload["content"] == "Response"
|
||||
|
||||
def test_reset_clears_conversation_history(self) -> None:
|
||||
state = AgentState()
|
||||
state.add_user_message("Hello", correlation_id="corr-700-request")
|
||||
state.add_assistant_message("Response", self._create_mock_response(), correlation_id="corr-123")
|
||||
|
||||
state.reset()
|
||||
|
||||
assert len(state.conversation_history) == 0
|
||||
assert state.try_get_agent_response("corr-123") is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
Reference in New Issue
Block a user