# Mirror of https://github.com/microsoft/agent-framework.git
# Synced 2026-06-16 21:04:09 +08:00 (617 lines, 24 KiB, Python)
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Unit tests for orchestration support (DurableAIAgent)."""
|
|
|
|
from typing import Any
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
from agent_framework import AgentResponse, AgentThread, ChatMessage
|
|
from azure.durable_functions.models.Task import TaskBase, TaskState
|
|
|
|
from agent_framework_azurefunctions import AgentFunctionApp, DurableAIAgent
|
|
from agent_framework_azurefunctions._models import AgentSessionId, DurableAgentThread
|
|
from agent_framework_azurefunctions._orchestration import AgentTask
|
|
|
|
|
|
def _app_with_registered_agents(*agent_names: str) -> AgentFunctionApp:
    """Build an AgentFunctionApp with one mock agent registered per given name.

    The app is created with its health-check and HTTP endpoints disabled.
    """
    function_app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=False)
    for agent_name in agent_names:
        mock_agent = Mock()
        # Assigned after construction: Mock(name=...) would set the mock's repr
        # name rather than a ``name`` attribute.
        mock_agent.name = agent_name
        function_app.add_agent(mock_agent)
    return function_app
|
|
|
|
|
|
class _FakeTask(TaskBase):
    """Minimal concrete ``TaskBase`` used to drive ``AgentTask`` in tests.

    A new instance is unscheduled, has no recorded actions, and is in the
    ``RUNNING`` state; tests assign ``state`` and ``result`` directly to
    simulate the entity call completing.
    """

    def __init__(self, task_id: int = 1) -> None:
        super().__init__(task_id, [])
        self._set_is_scheduled(False)
        self.action_repr = []
        self.state = TaskState.RUNNING
|
|
|
|
|
|
def _create_entity_task(task_id: int = 1) -> TaskBase:
    """Return a fresh ``_FakeTask`` (typed as ``TaskBase``) for AgentTask tests."""
    fake_task: TaskBase = _FakeTask(task_id=task_id)
    return fake_task
|
|
|
|
|
|
class TestAgentResponseHelpers:
    """Tests for the AgentTask helpers that load and post-process AgentResponse values."""

    @staticmethod
    def _create_agent_task() -> AgentTask:
        """Return an AgentTask (no response format) wrapping a fresh fake entity task."""
        return AgentTask(_create_entity_task(), None, "correlation-id")

    @staticmethod
    def _create_task_pair(response_format: Any = None) -> tuple[TaskBase, AgentTask]:
        """Return ``(entity_task, agent_task)`` where ``agent_task`` wraps ``entity_task``."""
        child_task = _create_entity_task()
        return child_task, AgentTask(child_task, response_format, "correlation-id")

    def test_load_agent_response_from_instance(self) -> None:
        agent_task = self._create_agent_task()
        original = AgentResponse(messages=[ChatMessage(role="assistant", text='{"foo": "bar"}')])

        result = agent_task._load_agent_response(original)

        # An existing instance comes back as the same object, with no parsed value.
        assert result is original
        assert result.value is None

    def test_load_agent_response_from_serialized(self) -> None:
        agent_task = self._create_agent_task()
        payload = AgentResponse(messages=[ChatMessage(role="assistant", text="structured")]).to_dict()
        payload["value"] = {"answer": 42}

        result = agent_task._load_agent_response(payload)

        assert result is not None
        assert result.value == {"answer": 42}
        assert result.to_dict()["type"] == "agent_response"

    def test_load_agent_response_rejects_none(self) -> None:
        agent_task = self._create_agent_task()

        with pytest.raises(ValueError):
            agent_task._load_agent_response(None)

    def test_load_agent_response_rejects_unsupported_type(self) -> None:
        agent_task = self._create_agent_task()
        not_a_response = ["invalid", "list"]

        with pytest.raises(TypeError, match="Unsupported type"):
            agent_task._load_agent_response(not_a_response)  # type: ignore[arg-type]

    def test_try_set_value_success(self) -> None:
        """try_set_value turns a successful serialized entity result into an AgentResponse."""
        entity_task, agent_task = self._create_task_pair()
        entity_task.state = TaskState.SUCCEEDED
        entity_task.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Test response")]).to_dict()
        # Simulate the parent having already processed its child task.
        agent_task.pending_tasks.clear()

        agent_task.try_set_value(entity_task)

        assert agent_task.state == TaskState.SUCCEEDED
        assert isinstance(agent_task.result, AgentResponse)
        assert agent_task.result.text == "Test response"

    def test_try_set_value_failure(self) -> None:
        """try_set_value propagates the entity task's exception on failure."""
        entity_task, agent_task = self._create_task_pair()
        entity_task.state = TaskState.FAILED
        entity_task.result = Exception("Entity call failed")

        agent_task.try_set_value(entity_task)

        assert agent_task.state == TaskState.FAILED
        error = agent_task.result
        assert isinstance(error, Exception)
        assert str(error) == "Entity call failed"

    def test_try_set_value_with_response_format(self) -> None:
        """try_set_value parses structured output into the requested response_format."""
        from pydantic import BaseModel

        class AnswerSchema(BaseModel):
            answer: str

        entity_task, agent_task = self._create_task_pair(AnswerSchema)
        entity_task.state = TaskState.SUCCEEDED
        entity_task.result = AgentResponse(messages=[ChatMessage(role="assistant", text='{"answer": "42"}')]).to_dict()
        # Simulate the parent having already processed its child task.
        agent_task.pending_tasks.clear()

        agent_task.try_set_value(entity_task)

        assert agent_task.state == TaskState.SUCCEEDED
        response = agent_task.result
        assert isinstance(response, AgentResponse)
        assert isinstance(response.value, AnswerSchema)
        assert response.value.answer == "42"

    def test_ensure_response_format_parses_value(self) -> None:
        """_ensure_response_format fills in ``value`` from the response text."""
        from pydantic import BaseModel

        class NamedSchema(BaseModel):
            name: str

        agent_task = self._create_agent_task()
        response = AgentResponse(messages=[ChatMessage(role="assistant", text='{"name": "test"}')])
        assert response.value is None  # nothing parsed yet

        agent_task._ensure_response_format(NamedSchema, "test-correlation", response)

        assert isinstance(response.value, NamedSchema)
        assert response.value.name == "test"

    def test_ensure_response_format_skips_if_already_parsed(self) -> None:
        """_ensure_response_format leaves a value that already matches the format untouched."""
        from pydantic import BaseModel

        class NamedSchema(BaseModel):
            name: str

        agent_task = self._create_agent_task()
        preexisting = NamedSchema(name="existing")
        response = AgentResponse(
            messages=[ChatMessage(role="assistant", text='{"name": "new"}')],
            value=preexisting,
        )

        agent_task._ensure_response_format(NamedSchema, "test-correlation", response)

        # The original object survives; the message text is not re-parsed.
        assert response.value is preexisting
        assert response.value.name == "existing"
|
|
|
|
|
|
class TestDurableAIAgent:
    """Test suite for the DurableAIAgent wrapper used inside orchestrations."""

    def test_init(self) -> None:
        """Test DurableAIAgent initialization."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-123"

        agent = DurableAIAgent(mock_context, "TestAgent")

        assert agent.context == mock_context
        assert agent.agent_name == "TestAgent"

    def test_implements_agent_protocol(self) -> None:
        """Test that DurableAIAgent implements AgentProtocol."""
        from agent_framework import AgentProtocol

        mock_context = Mock()
        agent = DurableAIAgent(mock_context, "TestAgent")

        # Check that agent satisfies AgentProtocol
        assert isinstance(agent, AgentProtocol)

    def test_has_agent_protocol_properties(self) -> None:
        """Test that DurableAIAgent has AgentProtocol properties."""
        mock_context = Mock()
        agent = DurableAIAgent(mock_context, "TestAgent")

        # AgentProtocol properties
        assert hasattr(agent, "id")
        assert hasattr(agent, "name")
        assert hasattr(agent, "description")

        # Verify values
        assert agent.name == "TestAgent"
        assert agent.description == "Durable agent proxy for TestAgent"
        assert agent.id is not None  # Auto-generated UUID

    def test_get_new_thread(self) -> None:
        """Test creating a new agent thread keyed by the context's deterministic uuid."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-456"
        mock_context.new_uuid = Mock(return_value="test-guid-456")

        agent = DurableAIAgent(mock_context, "WriterAgent")
        thread = agent.get_new_thread()

        assert isinstance(thread, DurableAgentThread)
        assert thread.session_id is not None
        session_id = thread.session_id
        assert isinstance(session_id, AgentSessionId)
        assert session_id.name == "WriterAgent"
        assert session_id.key == "test-guid-456"
        mock_context.new_uuid.assert_called_once()

    def test_get_new_thread_deterministic(self) -> None:
        """Test that get_new_thread derives each session key from context.new_uuid."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-789"
        mock_context.new_uuid = Mock(side_effect=["session-guid-1", "session-guid-2"])

        agent = DurableAIAgent(mock_context, "EditorAgent")

        # Create multiple threads - they should have unique session IDs
        thread1 = agent.get_new_thread()
        thread2 = agent.get_new_thread()

        assert isinstance(thread1, DurableAgentThread)
        assert isinstance(thread2, DurableAgentThread)

        session_id1 = thread1.session_id
        session_id2 = thread2.session_id
        assert session_id1 is not None and session_id2 is not None
        assert isinstance(session_id1, AgentSessionId)
        assert isinstance(session_id2, AgentSessionId)
        assert session_id1.name == "EditorAgent"
        assert session_id2.name == "EditorAgent"
        assert session_id1.key == "session-guid-1"
        assert session_id2.key == "session-guid-2"
        assert mock_context.new_uuid.call_count == 2

    def test_run_creates_entity_call(self) -> None:
        """Test that run() creates proper entity call and returns a Task."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-001"
        mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])

        entity_task = _create_entity_task()
        mock_context.call_entity = Mock(return_value=entity_task)

        agent = DurableAIAgent(mock_context, "TestAgent")

        # Create thread
        thread = agent.get_new_thread()

        # Call run() - returns AgentTask directly
        task = agent.run(messages="Test message", thread=thread, enable_tool_calls=True)

        assert isinstance(task, AgentTask)
        assert task.children[0] == entity_task

        # A single run() must issue exactly one entity call.
        mock_context.call_entity.assert_called_once()
        entity_id, operation, request = mock_context.call_entity.call_args[0]

        # The entity is addressed by the dafx- prefixed agent name and the thread's session key.
        assert entity_id.name == "dafx-TestAgent"
        assert entity_id.key == "thread-guid"

        assert operation == "run"
        assert request["message"] == "Test message"
        assert request["enable_tool_calls"] is True
        assert "correlationId" in request
        assert request["correlationId"] == "correlation-guid"
        assert "thread_id" in request
        assert request["thread_id"] == "thread-guid"
        # Verify orchestration ID is set from context.instance_id
        assert "orchestrationId" in request
        assert request["orchestrationId"] == "test-instance-001"

    def test_run_sets_orchestration_id(self) -> None:
        """Test that run() sets the orchestration_id from context.instance_id."""
        mock_context = Mock()
        mock_context.instance_id = "my-orchestration-123"
        mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])

        entity_task = _create_entity_task()
        mock_context.call_entity = Mock(return_value=entity_task)

        agent = DurableAIAgent(mock_context, "TestAgent")
        thread = agent.get_new_thread()

        agent.run(messages="Test", thread=thread)

        request = mock_context.call_entity.call_args[0][2]
        assert request["orchestrationId"] == "my-orchestration-123"

    def test_run_without_thread(self) -> None:
        """Test that run() works without explicit thread (creates unique session key)."""
        mock_context = Mock()
        mock_context.instance_id = "test-instance-002"
        mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"])

        entity_task = _create_entity_task()
        mock_context.call_entity = Mock(return_value=entity_task)

        agent = DurableAIAgent(mock_context, "TestAgent")

        # Call without thread
        task = agent.run(messages="Test message")

        assert isinstance(task, AgentTask)
        assert task.children[0] == entity_task

        # Verify the entity ID uses the auto-generated GUID with dafx- prefix
        entity_id = mock_context.call_entity.call_args[0][0]
        assert entity_id.name == "dafx-TestAgent"
        assert entity_id.key == "auto-generated-guid"
        # Should be called twice: once for session_key, once for correlationId
        assert mock_context.new_uuid.call_count == 2

    def test_run_with_response_format(self) -> None:
        """Test that run() passes response format correctly."""
        from pydantic import BaseModel

        mock_context = Mock()
        mock_context.instance_id = "test-instance-003"

        entity_task = _create_entity_task()
        mock_context.call_entity = Mock(return_value=entity_task)

        agent = DurableAIAgent(mock_context, "TestAgent")

        class SampleSchema(BaseModel):
            key: str

        # Create thread and call
        thread = agent.get_new_thread()

        task = agent.run(messages="Test message", thread=thread, options={"response_format": SampleSchema})

        assert isinstance(task, AgentTask)
        assert task.children[0] == entity_task

        # Verify schema was passed in the call_entity arguments
        input_data = mock_context.call_entity.call_args[0][2]  # Third argument is input_data
        assert "response_format" in input_data
        assert input_data["response_format"]["__response_schema_type__"] == "pydantic_model"
        assert input_data["response_format"]["module"] == SampleSchema.__module__
        assert input_data["response_format"]["qualname"] == SampleSchema.__qualname__

    def test_messages_to_string(self) -> None:
        """Test converting a ChatMessage list to a newline-joined string."""
        # ChatMessage comes from the module-level import; no local re-import needed.
        mock_context = Mock()
        agent = DurableAIAgent(mock_context, "TestAgent")

        messages = [
            ChatMessage(role="user", text="Hello"),
            ChatMessage(role="assistant", text="Hi there"),
            ChatMessage(role="user", text="How are you?"),
        ]

        result = agent._messages_to_string(messages)

        assert result == "Hello\nHi there\nHow are you?"

    def test_run_with_chat_message(self) -> None:
        """Test that run() handles ChatMessage input."""
        # ChatMessage comes from the module-level import; no local re-import needed.
        mock_context = Mock()
        mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
        entity_task = _create_entity_task()
        mock_context.call_entity = Mock(return_value=entity_task)

        agent = DurableAIAgent(mock_context, "TestAgent")
        thread = agent.get_new_thread()

        # Call with ChatMessage
        msg = ChatMessage(role="user", text="Hello")
        task = agent.run(messages=msg, thread=thread)

        assert isinstance(task, AgentTask)
        assert task.children[0] == entity_task

        # Verify message was converted to string
        request = mock_context.call_entity.call_args[0][2]
        assert request["message"] == "Hello"

    def test_run_stream_raises_not_implemented(self) -> None:
        """Test that run_stream() method raises NotImplementedError."""
        mock_context = Mock()
        agent = DurableAIAgent(mock_context, "TestAgent")

        with pytest.raises(NotImplementedError) as exc_info:
            agent.run_stream("Test message")

        error_msg = str(exc_info.value)
        assert "Streaming is not supported" in error_msg

    def test_entity_id_format(self) -> None:
        """Test that EntityId is created with correct format (name, key)."""
        from azure.durable_functions import EntityId

        mock_context = Mock()
        mock_context.new_uuid = Mock(return_value="test-guid-789")
        mock_context.call_entity = Mock(return_value=_create_entity_task())

        agent = DurableAIAgent(mock_context, "WriterAgent")
        thread = agent.get_new_thread()

        # Call run() to trigger entity ID creation
        agent.run("Test", thread=thread)

        # Verify call_entity was called with correct EntityId
        entity_id = mock_context.call_entity.call_args[0][0]

        # EntityId should be EntityId(name="dafx-WriterAgent", key="test-guid-789")
        # Which formats as "@dafx-writeragent@test-guid-789"
        assert isinstance(entity_id, EntityId)
        assert entity_id.name == "dafx-WriterAgent"
        assert entity_id.key == "test-guid-789"
        assert str(entity_id) == "@dafx-writeragent@test-guid-789"
|
|
|
|
|
|
class TestAgentFunctionAppGetAgent:
    """Tests for ``AgentFunctionApp.get_agent``."""

    def test_get_agent_method(self) -> None:
        """get_agent wraps a registered agent in a DurableAIAgent bound to the given context."""
        app = _app_with_registered_agents("MyAgent")
        orchestration_context = Mock()
        orchestration_context.instance_id = "test-instance-100"

        proxy = app.get_agent(orchestration_context, "MyAgent")

        assert isinstance(proxy, DurableAIAgent)
        assert proxy.agent_name == "MyAgent"
        assert proxy.context == orchestration_context

    def test_get_agent_raises_for_unregistered_agent(self) -> None:
        """get_agent raises ValueError for a name that was never registered."""
        app = _app_with_registered_agents("KnownAgent")
        expected_message = r"Agent 'MissingAgent' is not registered with this app\."

        with pytest.raises(ValueError, match=expected_message):
            app.get_agent(Mock(), "MissingAgent")
|
|
|
|
|
|
class TestOrchestrationIntegration:
    """Orchestration-level scenarios that drive DurableAIAgent through a mocked context."""

    def test_sequential_agent_calls_simulation(self) -> None:
        """Two runs on one thread must target the same agent entity."""
        ctx = Mock()
        ctx.instance_id = "test-orchestration-001"
        # new_uuid is consumed in order: the thread's session key, then one
        # correlationId for each of the two runs.
        ctx.new_uuid = Mock(side_effect=["deterministic-guid-001", "corr-1", "corr-2"])

        recorded: list[dict[str, Any]] = []

        def record_call(entity_id: Any, operation: str, input_data: dict[str, Any]) -> TaskBase:
            recorded.append({"entity_id": str(entity_id), "operation": operation, "input": input_data})
            return _create_entity_task()

        ctx.call_entity = Mock(side_effect=record_call)

        app = _app_with_registered_agents("WriterAgent")
        writer = app.get_agent(ctx, "WriterAgent")
        shared_thread = writer.get_new_thread()

        first = writer.run("Write something", thread=shared_thread)
        assert isinstance(first, AgentTask)
        second = writer.run("Improve: something", thread=shared_thread)
        assert isinstance(second, AgentTask)

        # Both calls hit the same entity, formatted as @dafx-<lowercased name>@<session key>.
        targeted = [call["entity_id"] for call in recorded]
        assert targeted == ["@dafx-writeragent@deterministic-guid-001"] * 2
        # One uuid for the thread plus one correlation id per run.
        assert ctx.new_uuid.call_count == 3

    def test_multiple_agents_in_orchestration(self) -> None:
        """Different agents in one orchestration are routed to different entities."""
        ctx = Mock()
        ctx.instance_id = "test-orchestration-002"
        # Consumption order: writer thread, editor thread, writer correlation, editor correlation.
        ctx.new_uuid = Mock(side_effect=["writer-guid-001", "editor-guid-002", "writer-corr", "editor-corr"])

        targeted: list[str] = []

        def record_call(entity_id: Any, operation: str, input_data: dict[str, Any]) -> TaskBase:
            targeted.append(str(entity_id))
            return _create_entity_task()

        ctx.call_entity = Mock(side_effect=record_call)

        app = _app_with_registered_agents("WriterAgent", "EditorAgent")
        writer = app.get_agent(ctx, "WriterAgent")
        editor = app.get_agent(ctx, "EditorAgent")

        writer_thread = writer.get_new_thread()
        editor_thread = editor.get_new_thread()

        tasks = (
            writer.run("Write", thread=writer_thread),
            editor.run("Edit", thread=editor_thread),
        )

        assert all(isinstance(task, AgentTask) for task in tasks)
        # Each agent gets its own @dafx-<lowercased name>@<session key> entity.
        assert targeted == [
            "@dafx-writeragent@writer-guid-001",
            "@dafx-editoragent@editor-guid-002",
        ]
|
|
|
|
|
|
class TestAgentThreadSerialization:
    """Serialization round-trips for AgentThread and DurableAgentThread (orchestration state)."""

    # NOTE(review): these are bare ``async def`` tests with no explicit marker;
    # they rely on the project's pytest async-plugin configuration to be awaited
    # rather than skipped — confirm against the repo's pytest settings.

    async def test_agent_thread_serialize(self) -> None:
        """serialize() on a plain AgentThread yields a dict containing service_thread_id."""
        thread = AgentThread()

        state = await thread.serialize()

        assert isinstance(state, dict)
        assert "service_thread_id" in state

    async def test_agent_thread_deserialize(self) -> None:
        """deserialize() restores an AgentThread with the same service_thread_id."""
        original = AgentThread()
        state = await original.serialize()

        roundtripped = await AgentThread.deserialize(state)

        assert isinstance(roundtripped, AgentThread)
        assert roundtripped.service_thread_id == original.service_thread_id

    async def test_durable_agent_thread_serialization(self) -> None:
        """DurableAgentThread keeps its durable session id across serialize/deserialize."""
        ctx = Mock()
        ctx.instance_id = "test-instance-999"
        ctx.new_uuid = Mock(return_value="test-guid-999")

        durable_agent = DurableAIAgent(ctx, "TestAgent")
        thread = durable_agent.get_new_thread()
        assert isinstance(thread, DurableAgentThread)

        assert thread.session_id is not None
        session_id = thread.session_id
        assert isinstance(session_id, AgentSessionId)
        assert (session_id.name, session_id.key) == ("TestAgent", "test-guid-999")

        state = await thread.serialize()
        assert isinstance(state, dict)
        # The session id is persisted in its string form under "durable_session_id".
        assert state.get("durable_session_id") == str(session_id)

        restored = await DurableAgentThread.deserialize(state)
        assert isinstance(restored, DurableAgentThread)
        assert restored.session_id == session_id
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the exit code instead of raising; propagate it so
    # running this file directly exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
|