mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
e3eff65a6b
* Add worker and clients * Clean code and refactor common code * Implement sample * Add sample * Update readmes * Fix tests * Fix tests * Update requirements * Fix typo * Address comments * use response.text
273 lines
10 KiB
Python
273 lines
10 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Unit tests for AgentSessionId and DurableAgentThread."""
|
|
|
|
import pytest
|
|
from agent_framework import AgentThread
|
|
|
|
from agent_framework_durabletask._models import AgentSessionId, DurableAgentThread
|
|
|
|
|
|
class TestAgentSessionId:
    """Checks for AgentSessionId construction, string formatting, and parsing."""

    def test_init_creates_session_id(self) -> None:
        """The constructor keeps the given name and key as-is."""
        sid = AgentSessionId(name="AgentEntity", key="test-key-123")

        assert sid.name == "AgentEntity"
        assert sid.key == "test-key-123"

    def test_with_random_key_generates_guid(self) -> None:
        """with_random_key keeps the name and produces a 32-character hex key."""
        sid = AgentSessionId.with_random_key(name="AgentEntity")

        assert sid.name == "AgentEntity"
        assert len(sid.key) == 32  # UUID hex is 32 chars
        # int() raises ValueError if the key is not parseable as base-16.
        int(sid.key, 16)

    def test_with_random_key_unique_keys(self) -> None:
        """Two consecutive with_random_key calls must not share a key."""
        first = AgentSessionId.with_random_key(name="AgentEntity")
        second = AgentSessionId.with_random_key(name="AgentEntity")

        assert first.key != second.key

    def test_str_representation(self) -> None:
        """str() renders the id as '@<name>@<key>'."""
        rendered = str(AgentSessionId(name="AgentEntity", key="test-key-123"))

        assert rendered == "@AgentEntity@test-key-123"

    def test_repr_representation(self) -> None:
        """repr() mentions the class name, the agent name, and the key."""
        rendered = repr(AgentSessionId(name="AgentEntity", key="test-key"))

        for fragment in ("AgentSessionId", "AgentEntity", "test-key"):
            assert fragment in rendered

    def test_parse_valid_session_id(self) -> None:
        """parse() splits a well-formed '@name@key' string into its parts."""
        parsed = AgentSessionId.parse("@AgentEntity@test-key-123")

        assert parsed.name == "AgentEntity"
        assert parsed.key == "test-key-123"

    def test_parse_invalid_format_no_prefix(self) -> None:
        """parse() rejects input that lacks the leading '@'."""
        with pytest.raises(ValueError, match="Invalid agent session ID format"):
            AgentSessionId.parse("AgentEntity@test-key")

    def test_parse_invalid_format_single_part(self) -> None:
        """parse() rejects input that has a name but no key segment."""
        with pytest.raises(ValueError, match="Invalid agent session ID format"):
            AgentSessionId.parse("@AgentEntity")

    def test_parse_with_multiple_at_signs_in_key(self) -> None:
        """Extra '@' characters after the name stay part of the key."""
        parsed = AgentSessionId.parse("@AgentEntity@key-with@symbols")

        assert parsed.name == "AgentEntity"
        assert parsed.key == "key-with@symbols"

    def test_parse_round_trip(self) -> None:
        """Parsing the str() of an id yields the same name and key."""
        source = AgentSessionId(name="AgentEntity", key="test-key")
        reparsed = AgentSessionId.parse(str(source))

        assert reparsed.name == source.name
        assert reparsed.key == source.key

    def test_to_entity_name_adds_prefix(self) -> None:
        """to_entity_name prepends 'dafx-' to the agent name."""
        assert AgentSessionId.to_entity_name("TestAgent") == "dafx-TestAgent"
|
|
|
|
|
|
class TestDurableAgentThread:
    """Checks for DurableAgentThread session-id handling and (de)serialization."""

    def test_init_with_session_id(self) -> None:
        """A session id passed to the constructor is exposed on the thread."""
        expected = AgentSessionId(name="TestAgent", key="test-key")
        durable = DurableAgentThread(session_id=expected)

        assert durable.session_id is not None
        assert durable.session_id == expected

    def test_init_without_session_id(self) -> None:
        """A thread built with no arguments has no session id."""
        assert DurableAgentThread().session_id is None

    def test_session_id_setter(self) -> None:
        """session_id can be assigned after the thread has been created."""
        durable = DurableAgentThread()
        assert durable.session_id is None

        expected = AgentSessionId(name="TestAgent", key="test-key")
        durable.session_id = expected

        assert durable.session_id is not None
        assert durable.session_id == expected
        assert durable.session_id.name == "TestAgent"

    def test_from_session_id(self) -> None:
        """from_session_id returns a DurableAgentThread carrying the given id."""
        expected = AgentSessionId(name="TestAgent", key="test-key")
        durable = DurableAgentThread.from_session_id(expected)

        assert isinstance(durable, DurableAgentThread)
        assert durable.session_id is not None
        assert durable.session_id == expected
        assert durable.session_id.name == "TestAgent"
        assert durable.session_id.key == "test-key"

    def test_from_session_id_with_service_thread_id(self) -> None:
        """from_session_id also records an explicit service_thread_id."""
        expected = AgentSessionId(name="TestAgent", key="test-key")
        durable = DurableAgentThread.from_session_id(expected, service_thread_id="service-123")

        assert durable.session_id is not None
        assert durable.session_id == expected
        assert durable.service_thread_id == "service-123"

    async def test_serialize_with_session_id(self) -> None:
        """serialize() emits the session id string under 'durable_session_id'."""
        durable = DurableAgentThread(session_id=AgentSessionId(name="TestAgent", key="test-key"))

        payload = await durable.serialize()

        assert isinstance(payload, dict)
        assert "durable_session_id" in payload
        assert payload["durable_session_id"] == "@TestAgent@test-key"

    async def test_serialize_without_session_id(self) -> None:
        """serialize() omits 'durable_session_id' when no session id is set."""
        payload = await DurableAgentThread().serialize()

        assert isinstance(payload, dict)
        assert "durable_session_id" not in payload

    async def test_deserialize_with_session_id(self) -> None:
        """deserialize() rebuilds the session id and the service thread id."""
        payload = {
            "service_thread_id": "thread-123",
            "durable_session_id": "@TestAgent@test-key",
        }

        durable = await DurableAgentThread.deserialize(payload)

        assert isinstance(durable, DurableAgentThread)
        assert durable.session_id is not None
        assert durable.session_id.name == "TestAgent"
        assert durable.session_id.key == "test-key"
        assert durable.service_thread_id == "thread-123"

    async def test_deserialize_without_session_id(self) -> None:
        """deserialize() leaves session_id as None when the key is absent."""
        payload = {
            "service_thread_id": "thread-456",
        }

        durable = await DurableAgentThread.deserialize(payload)

        assert isinstance(durable, DurableAgentThread)
        assert durable.session_id is None
        assert durable.service_thread_id == "thread-456"

    async def test_round_trip_serialization(self) -> None:
        """serialize() followed by deserialize() keeps the session id name and key."""
        expected = AgentSessionId(name="TestAgent", key="test-key-789")
        source_thread = DurableAgentThread(session_id=expected)

        payload = await source_thread.serialize()
        rebuilt = await DurableAgentThread.deserialize(payload)

        assert isinstance(rebuilt, DurableAgentThread)
        assert rebuilt.session_id is not None
        assert rebuilt.session_id.name == expected.name
        assert rebuilt.session_id.key == expected.key

    async def test_deserialize_invalid_session_id_type(self) -> None:
        """deserialize() raises ValueError for a non-string durable_session_id."""
        payload = {
            "service_thread_id": "thread-123",
            "durable_session_id": 12345,  # Invalid type
        }

        with pytest.raises(ValueError, match="durable_session_id must be a string"):
            await DurableAgentThread.deserialize(payload)
|
|
|
|
|
|
class TestAgentThreadCompatibility:
    """Checks that DurableAgentThread stays compatible with the base AgentThread."""

    async def test_agent_thread_serialize(self) -> None:
        """A plain AgentThread serializes to a dict with 'service_thread_id'."""
        payload = await AgentThread().serialize()

        assert isinstance(payload, dict)
        assert "service_thread_id" in payload

    async def test_agent_thread_deserialize(self) -> None:
        """A serialized AgentThread deserializes with the same service_thread_id."""
        source_thread = AgentThread()
        payload = await source_thread.serialize()

        rebuilt = await AgentThread.deserialize(payload)

        assert isinstance(rebuilt, AgentThread)
        assert rebuilt.service_thread_id == source_thread.service_thread_id

    async def test_durable_thread_is_agent_thread(self) -> None:
        """A DurableAgentThread instance is also an AgentThread instance."""
        durable = DurableAgentThread()

        for expected_type in (AgentThread, DurableAgentThread):
            assert isinstance(durable, expected_type)
|
|
|
|
|
|
class TestModelIntegration:
    """Checks that AgentSessionId and DurableAgentThread work together."""

    def test_session_id_string_format(self) -> None:
        """A randomly keyed id still renders with the '@<name>@' prefix."""
        rendered = str(AgentSessionId.with_random_key("AgentEntity"))
        expected_prefix = "@AgentEntity@"

        assert rendered.startswith(expected_prefix)

    async def test_thread_with_session_preserves_on_serialization(self) -> None:
        """A thread built via from_session_id keeps its id across a round trip."""
        durable = DurableAgentThread.from_session_id(
            AgentSessionId(name="TestAgent", key="preserved-key")
        )

        # Push the thread through serialize() and back through deserialize().
        payload = await durable.serialize()
        rebuilt = await DurableAgentThread.deserialize(payload)

        # The rebuilt thread must carry the original name and key.
        assert rebuilt.session_id is not None
        assert rebuilt.session_id.name == "TestAgent"
        assert rebuilt.session_id.key == "preserved-key"
|
|
|
|
|
|
if __name__ == "__main__":
    # Running the file directly hands it to pytest with verbose output
    # and short tracebacks.
    pytest.main(args=[__file__, "-v", "--tb=short"])
|