mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
705ed47a0b
* Fix Content serialization in DurableAgentStateUnknownContent (#4719) DurableAgentStateUnknownContent.from_unknown_content() stored raw Content objects without converting them to dicts, causing json.dumps to fail in Azure Durable Functions' entity state serialization. This affected content types not explicitly handled (e.g., mcp_server_tool_call/result). The fix converts Content objects to dicts via to_dict() when storing in DurableAgentStateUnknownContent, and restores them via Content.from_dict() in to_ai_content(). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add to_json and from_json methods to Content class (#4719) Add to_json() and from_json() methods to the Content class to match the serialization interface provided by SerializationMixin on other model classes. Also fix pre-existing pyright type errors in durabletask's DurableAgentStateUnknownContent.to_ai_content(). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address PR review: add type guard, remove to_json, add fallback, and tests - Remove Content.to_json() per reviewer request (comment 3) - Add type guard in Content.from_json() for non-dict JSON (comments 1, 4) - Wrap json.JSONDecodeError as ValueError for consistent exception contract - Add try/except fallback in to_ai_content() for invalid Content dicts (comment 5) - Add test_content_to_dict_exclude_none and test_content_to_dict_exclude_fields (comment 2) - Add test_unknown_content_to_ai_content_fallback_on_invalid_type_dict (comment 5) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes * Address review feedback for #4719: review comment fixes * Remove Content.from_json, move logic to consuming code (#4719) Remove the from_json convenience method from Content class per review feedback. This is the same trivial json.loads + from_dict wrapper as to_json which was already removed. Consumers should call json.loads and Content.from_dict directly. Update tests to use Content.from_dict(json.loads(...)) pattern and remove from_json-specific error handling tests (those errors are already covered by json.loads and Content.from_dict). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
493 lines
18 KiB
Python
493 lines
18 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Unit tests for DurableAgentState and related classes."""
|
|
|
|
import json
|
|
from datetime import datetime
|
|
|
|
import pytest
|
|
from agent_framework import Content, Message, UsageDetails
|
|
|
|
from agent_framework_durabletask._durable_agent_state import (
|
|
DurableAgentState,
|
|
DurableAgentStateContent,
|
|
DurableAgentStateMessage,
|
|
DurableAgentStateRequest,
|
|
DurableAgentStateTextContent,
|
|
DurableAgentStateUnknownContent,
|
|
DurableAgentStateUsage,
|
|
)
|
|
from agent_framework_durabletask._models import RunRequest
|
|
|
|
|
|
class TestDurableAgentStateRequestOrchestrationId:
|
|
"""Test suite for DurableAgentStateRequest orchestration_id field."""
|
|
|
|
def test_request_with_orchestration_id(self) -> None:
|
|
"""Test creating a request with an orchestration_id."""
|
|
request = DurableAgentStateRequest(
|
|
correlation_id="corr-123",
|
|
created_at=datetime.now(),
|
|
messages=[
|
|
DurableAgentStateMessage(
|
|
role="user",
|
|
contents=[DurableAgentStateTextContent(text="test")],
|
|
)
|
|
],
|
|
orchestration_id="orch-456",
|
|
)
|
|
|
|
assert request.orchestration_id == "orch-456"
|
|
|
|
def test_request_to_dict_includes_orchestration_id(self) -> None:
|
|
"""Test that to_dict includes orchestrationId when set."""
|
|
request = DurableAgentStateRequest(
|
|
correlation_id="corr-123",
|
|
created_at=datetime.now(),
|
|
messages=[
|
|
DurableAgentStateMessage(
|
|
role="user",
|
|
contents=[DurableAgentStateTextContent(text="test")],
|
|
)
|
|
],
|
|
orchestration_id="orch-789",
|
|
)
|
|
|
|
data = request.to_dict()
|
|
|
|
assert "orchestrationId" in data
|
|
assert data["orchestrationId"] == "orch-789"
|
|
|
|
def test_request_to_dict_excludes_orchestration_id_when_none(self) -> None:
|
|
"""Test that to_dict excludes orchestrationId when not set."""
|
|
request = DurableAgentStateRequest(
|
|
correlation_id="corr-123",
|
|
created_at=datetime.now(),
|
|
messages=[
|
|
DurableAgentStateMessage(
|
|
role="user",
|
|
contents=[DurableAgentStateTextContent(text="test")],
|
|
)
|
|
],
|
|
)
|
|
|
|
data = request.to_dict()
|
|
|
|
assert "orchestrationId" not in data
|
|
|
|
def test_request_from_dict_with_orchestration_id(self) -> None:
|
|
"""Test from_dict correctly parses orchestrationId."""
|
|
data = {
|
|
"$type": "request",
|
|
"correlationId": "corr-123",
|
|
"createdAt": "2024-01-01T00:00:00Z",
|
|
"messages": [{"role": "user", "contents": [{"$type": "text", "text": "test"}]}],
|
|
"orchestrationId": "orch-from-dict",
|
|
}
|
|
|
|
request = DurableAgentStateRequest.from_dict(data)
|
|
|
|
assert request.orchestration_id == "orch-from-dict"
|
|
|
|
def test_request_from_run_request_with_orchestration_id(self) -> None:
|
|
"""Test from_run_request correctly transfers orchestration_id."""
|
|
run_request = RunRequest(
|
|
message="test message",
|
|
correlation_id="corr-run",
|
|
orchestration_id="orch-from-run-request",
|
|
)
|
|
|
|
durable_request = DurableAgentStateRequest.from_run_request(run_request)
|
|
|
|
assert durable_request.orchestration_id == "orch-from-run-request"
|
|
|
|
def test_request_from_run_request_without_orchestration_id(self) -> None:
|
|
"""Test from_run_request correctly handles missing orchestration_id."""
|
|
run_request = RunRequest(
|
|
message="test message",
|
|
correlation_id="corr-run",
|
|
)
|
|
|
|
durable_request = DurableAgentStateRequest.from_run_request(run_request)
|
|
|
|
assert durable_request.orchestration_id is None
|
|
|
|
|
|
class TestDurableAgentStateMessageCreatedAt:
|
|
"""Test suite for DurableAgentStateMessage created_at field handling."""
|
|
|
|
def test_message_from_run_request_without_created_at_preserves_none(self) -> None:
|
|
"""Test from_run_request handles auto-populated created_at from RunRequest.
|
|
|
|
When a RunRequest is created with None for created_at, RunRequest defaults it to
|
|
current UTC time. The resulting DurableAgentStateMessage should have this timestamp.
|
|
"""
|
|
run_request = RunRequest(
|
|
message="test message",
|
|
correlation_id="corr-run",
|
|
created_at=None, # RunRequest will default this to current time
|
|
)
|
|
|
|
durable_message = DurableAgentStateMessage.from_run_request(run_request)
|
|
|
|
# RunRequest auto-populates created_at, so it should not be None
|
|
assert durable_message.created_at is not None
|
|
|
|
def test_message_from_run_request_with_created_at_parses_correctly(self) -> None:
|
|
"""Test from_run_request correctly parses a valid created_at timestamp."""
|
|
run_request = RunRequest(
|
|
message="test message",
|
|
correlation_id="corr-run",
|
|
created_at=datetime(2024, 1, 15, 10, 30, 0),
|
|
)
|
|
|
|
durable_message = DurableAgentStateMessage.from_run_request(run_request)
|
|
|
|
assert durable_message.created_at is not None
|
|
assert durable_message.created_at.year == 2024
|
|
assert durable_message.created_at.month == 1
|
|
assert durable_message.created_at.day == 15
|
|
|
|
|
|
class TestDurableAgentState:
|
|
"""Test suite for DurableAgentState."""
|
|
|
|
def test_schema_version(self) -> None:
|
|
"""Test that schema version is set correctly."""
|
|
state = DurableAgentState()
|
|
assert state.schema_version == "1.1.0"
|
|
|
|
def test_to_dict_serialization(self) -> None:
|
|
"""Test that to_dict produces correct structure."""
|
|
state = DurableAgentState()
|
|
data = state.to_dict()
|
|
|
|
assert "schemaVersion" in data
|
|
assert "data" in data
|
|
assert data["schemaVersion"] == "1.1.0"
|
|
assert "conversationHistory" in data["data"]
|
|
|
|
def test_from_dict_deserialization(self) -> None:
|
|
"""Test that from_dict restores state correctly."""
|
|
original_data = {
|
|
"schemaVersion": "1.1.0",
|
|
"data": {
|
|
"conversationHistory": [
|
|
{
|
|
"$type": "request",
|
|
"correlationId": "test-123",
|
|
"createdAt": "2024-01-01T00:00:00Z",
|
|
"messages": [
|
|
{
|
|
"role": "user",
|
|
"contents": [{"$type": "text", "text": "Hello"}],
|
|
}
|
|
],
|
|
}
|
|
]
|
|
},
|
|
}
|
|
|
|
state = DurableAgentState.from_dict(original_data)
|
|
|
|
assert state.schema_version == "1.1.0"
|
|
assert len(state.data.conversation_history) == 1
|
|
assert isinstance(state.data.conversation_history[0], DurableAgentStateRequest)
|
|
|
|
def test_round_trip_serialization(self) -> None:
|
|
"""Test that round-trip serialization preserves data."""
|
|
state = DurableAgentState()
|
|
state.data.conversation_history.append(
|
|
DurableAgentStateRequest(
|
|
correlation_id="test-456",
|
|
created_at=datetime.now(),
|
|
messages=[
|
|
DurableAgentStateMessage(
|
|
role="user",
|
|
contents=[DurableAgentStateTextContent(text="Test message")],
|
|
)
|
|
],
|
|
)
|
|
)
|
|
|
|
data = state.to_dict()
|
|
restored = DurableAgentState.from_dict(data)
|
|
|
|
assert restored.schema_version == state.schema_version
|
|
assert len(restored.data.conversation_history) == len(state.data.conversation_history)
|
|
assert restored.data.conversation_history[0].correlation_id == "test-456"
|
|
|
|
|
|
class TestDurableAgentStateUsage:
|
|
"""Test suite for DurableAgentStateUsage."""
|
|
|
|
def test_usage_init_with_defaults(self) -> None:
|
|
"""Test creating usage with default values."""
|
|
usage = DurableAgentStateUsage()
|
|
|
|
assert usage.input_token_count is None
|
|
assert usage.output_token_count is None
|
|
assert usage.total_token_count is None
|
|
assert usage.extensionData is None
|
|
|
|
def test_usage_init_with_values(self) -> None:
|
|
"""Test creating usage with specific values."""
|
|
usage = DurableAgentStateUsage(
|
|
input_token_count=100,
|
|
output_token_count=200,
|
|
total_token_count=300,
|
|
extensionData={"custom_field": "value"},
|
|
)
|
|
|
|
assert usage.input_token_count == 100
|
|
assert usage.output_token_count == 200
|
|
assert usage.total_token_count == 300
|
|
assert usage.extensionData == {"custom_field": "value"}
|
|
|
|
def test_usage_to_dict(self) -> None:
|
|
"""Test that to_dict produces correct structure."""
|
|
usage = DurableAgentStateUsage(
|
|
input_token_count=50,
|
|
output_token_count=75,
|
|
total_token_count=125,
|
|
)
|
|
|
|
data = usage.to_dict()
|
|
|
|
assert data["inputTokenCount"] == 50
|
|
assert data["outputTokenCount"] == 75
|
|
assert data["totalTokenCount"] == 125
|
|
|
|
def test_usage_to_dict_with_extension_data(self) -> None:
|
|
"""Test that to_dict includes extensionData when present."""
|
|
usage = DurableAgentStateUsage(
|
|
input_token_count=10,
|
|
output_token_count=20,
|
|
total_token_count=30,
|
|
extensionData={"provider_specific": 123},
|
|
)
|
|
|
|
data = usage.to_dict()
|
|
|
|
assert "extensionData" in data
|
|
assert data["extensionData"] == {"provider_specific": 123}
|
|
|
|
def test_usage_from_dict(self) -> None:
|
|
"""Test that from_dict restores usage correctly."""
|
|
data = {
|
|
"inputTokenCount": 100,
|
|
"outputTokenCount": 200,
|
|
"totalTokenCount": 300,
|
|
"extensionData": {"extra": "data"},
|
|
}
|
|
|
|
usage = DurableAgentStateUsage.from_dict(data)
|
|
|
|
assert usage.input_token_count == 100
|
|
assert usage.output_token_count == 200
|
|
assert usage.total_token_count == 300
|
|
assert usage.extensionData == {"extra": "data"}
|
|
|
|
def test_usage_from_usage_details(self) -> None:
|
|
"""Test creating DurableAgentStateUsage from UsageDetails."""
|
|
usage_details: UsageDetails = {
|
|
"input_token_count": 150,
|
|
"output_token_count": 250,
|
|
"total_token_count": 400,
|
|
}
|
|
|
|
usage = DurableAgentStateUsage.from_usage(usage_details)
|
|
|
|
assert usage is not None
|
|
assert usage.input_token_count == 150
|
|
assert usage.output_token_count == 250
|
|
assert usage.total_token_count == 400
|
|
|
|
def test_usage_from_usage_details_with_extension_fields(self) -> None:
|
|
"""Test that non-standard fields are captured in extensionData."""
|
|
usage_details: UsageDetails = {
|
|
"input_token_count": 100,
|
|
"output_token_count": 200,
|
|
"total_token_count": 300,
|
|
}
|
|
# Add provider-specific fields (UsageDetails is a TypedDict but allows extra keys)
|
|
usage_details["prompt_tokens"] = 100 # type: ignore[typeddict-unknown-key]
|
|
usage_details["completion_tokens"] = 200 # type: ignore[typeddict-unknown-key]
|
|
|
|
usage = DurableAgentStateUsage.from_usage(usage_details)
|
|
|
|
assert usage is not None
|
|
assert usage.extensionData is not None
|
|
assert usage.extensionData["prompt_tokens"] == 100
|
|
assert usage.extensionData["completion_tokens"] == 200
|
|
|
|
def test_usage_from_usage_none(self) -> None:
|
|
"""Test that from_usage returns None for None input."""
|
|
usage = DurableAgentStateUsage.from_usage(None)
|
|
|
|
assert usage is None
|
|
|
|
def test_usage_to_usage_details(self) -> None:
|
|
"""Test converting back to UsageDetails."""
|
|
usage = DurableAgentStateUsage(
|
|
input_token_count=100,
|
|
output_token_count=200,
|
|
total_token_count=300,
|
|
)
|
|
|
|
details = usage.to_usage_details()
|
|
|
|
assert details.get("input_token_count") == 100
|
|
assert details.get("output_token_count") == 200
|
|
assert details.get("total_token_count") == 300
|
|
|
|
def test_usage_to_usage_details_with_extension_data(self) -> None:
|
|
"""Test that extensionData is merged into UsageDetails."""
|
|
usage = DurableAgentStateUsage(
|
|
input_token_count=50,
|
|
output_token_count=75,
|
|
total_token_count=125,
|
|
extensionData={"prompt_tokens": 50, "completion_tokens": 75},
|
|
)
|
|
|
|
details = usage.to_usage_details()
|
|
|
|
assert details.get("input_token_count") == 50
|
|
assert details.get("output_token_count") == 75
|
|
assert details.get("total_token_count") == 125
|
|
# Extension data should be merged into the result
|
|
assert details.get("prompt_tokens") == 50
|
|
assert details.get("completion_tokens") == 75
|
|
|
|
def test_usage_round_trip(self) -> None:
|
|
"""Test round-trip conversion from UsageDetails to DurableAgentStateUsage and back."""
|
|
original: UsageDetails = {
|
|
"input_token_count": 100,
|
|
"output_token_count": 200,
|
|
"total_token_count": 300,
|
|
}
|
|
|
|
usage = DurableAgentStateUsage.from_usage(original)
|
|
assert usage is not None
|
|
restored = usage.to_usage_details()
|
|
|
|
assert restored.get("input_token_count") == original.get("input_token_count")
|
|
assert restored.get("output_token_count") == original.get("output_token_count")
|
|
assert restored.get("total_token_count") == original.get("total_token_count")
|
|
|
|
|
|
class TestDurableAgentStateUnknownContent:
|
|
"""Test suite for DurableAgentStateUnknownContent serialization."""
|
|
|
|
def test_unknown_content_from_content_object_produces_serializable_dict(self) -> None:
|
|
"""Test that from_unknown_content serializes Content objects to dicts."""
|
|
content = Content.from_mcp_server_tool_call(
|
|
call_id="call-1",
|
|
tool_name="search",
|
|
server_name="learn-mcp",
|
|
arguments={"query": "azure functions"},
|
|
)
|
|
|
|
unknown = DurableAgentStateUnknownContent.from_unknown_content(content)
|
|
result = unknown.to_dict()
|
|
|
|
# The content field should be a dict, not a Content object
|
|
assert isinstance(result["content"], dict)
|
|
assert result["content"]["type"] == "mcp_server_tool_call"
|
|
|
|
def test_unknown_content_to_dict_is_json_serializable(self) -> None:
|
|
"""Test that to_dict output can be passed to json.dumps without error."""
|
|
content = Content.from_mcp_server_tool_result(
|
|
call_id="call-1",
|
|
output="Azure Functions documentation...",
|
|
)
|
|
|
|
unknown = DurableAgentStateUnknownContent.from_unknown_content(content)
|
|
result = unknown.to_dict()
|
|
|
|
# This must not raise TypeError
|
|
serialized = json.dumps(result)
|
|
assert serialized is not None
|
|
|
|
def test_unknown_content_round_trip_preserves_content(self) -> None:
|
|
"""Test that Content objects survive serialization and deserialization."""
|
|
original = Content.from_mcp_server_tool_call(
|
|
call_id="call-1",
|
|
tool_name="fetch",
|
|
server_name="learn-mcp",
|
|
arguments={"url": "https://example.com"},
|
|
)
|
|
|
|
unknown = DurableAgentStateUnknownContent.from_unknown_content(original)
|
|
restored = unknown.to_ai_content()
|
|
|
|
assert restored.type == "mcp_server_tool_call"
|
|
assert restored.tool_name == "fetch"
|
|
assert restored.server_name == "learn-mcp"
|
|
|
|
def test_unknown_content_from_plain_dict_unchanged(self) -> None:
|
|
"""Test that non-Content values are stored as-is."""
|
|
plain = {"some": "data"}
|
|
|
|
unknown = DurableAgentStateUnknownContent.from_unknown_content(plain)
|
|
|
|
assert unknown.content == {"some": "data"}
|
|
|
|
def test_unknown_content_to_ai_content_fallback_on_invalid_type_dict(self) -> None:
|
|
"""Test that to_ai_content falls back when dict has 'type' but is not valid Content."""
|
|
invalid = {"type": "bogus_not_a_real_content_type", "extra": "stuff"}
|
|
unknown = DurableAgentStateUnknownContent(content=invalid)
|
|
|
|
result = unknown.to_ai_content()
|
|
|
|
assert result.type == "unknown"
|
|
assert result.additional_properties == {"content": invalid}
|
|
|
|
def test_from_ai_content_unknown_type_produces_serializable_state(self) -> None:
|
|
"""Test that unknown content types in message conversion produce JSON-serializable state."""
|
|
content = Content.from_mcp_server_tool_call(
|
|
call_id="call-1",
|
|
tool_name="search",
|
|
server_name="learn-mcp",
|
|
arguments={"query": "create function app"},
|
|
)
|
|
|
|
durable_content = DurableAgentStateContent.from_ai_content(content)
|
|
data = durable_content.to_dict()
|
|
|
|
# Must be fully JSON-serializable
|
|
serialized = json.dumps(data)
|
|
assert serialized is not None
|
|
|
|
def test_state_with_mcp_content_is_json_serializable(self) -> None:
|
|
"""Test that full DurableAgentState with MCP content can be serialized to JSON.
|
|
|
|
This reproduces the scenario from issue #4719 where agent state containing
|
|
MCP tool content could not be serialized by Azure Durable Functions.
|
|
"""
|
|
state = DurableAgentState()
|
|
mcp_content = Content.from_mcp_server_tool_call(
|
|
call_id="call-1",
|
|
tool_name="search",
|
|
server_name="learn-mcp",
|
|
arguments={"query": "azure functions"},
|
|
)
|
|
message = DurableAgentStateMessage.from_chat_message(Message(role="assistant", contents=[mcp_content]))
|
|
state.data.conversation_history.append(
|
|
DurableAgentStateRequest(
|
|
correlation_id="test-mcp",
|
|
created_at=datetime.now(),
|
|
messages=[message],
|
|
)
|
|
)
|
|
|
|
state_dict = state.to_dict()
|
|
|
|
# This simulates what Azure Durable Functions does with entity state
|
|
serialized = json.dumps(state_dict)
|
|
assert serialized is not None
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__, "-v", "--tb=short"])
|