Files
agent-framework/python/packages/durabletask/tests/test_durable_agent_state.py
T
Eduard van Valkenburg 705ed47a0b Python: Fix missing methods on the Content class in durable tasks (#4738)
* Fix Content serialization in DurableAgentStateUnknownContent (#4719)

DurableAgentStateUnknownContent.from_unknown_content() stored raw Content
objects without converting them to dicts, causing json.dumps to fail in
Azure Durable Functions' entity state serialization. This affected content
types not explicitly handled (e.g., mcp_server_tool_call/result).

The fix converts Content objects to dicts via to_dict() when storing in
DurableAgentStateUnknownContent, and restores them via Content.from_dict()
in to_ai_content().

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Add to_json and from_json methods to Content class (#4719)

Add to_json() and from_json() methods to the Content class to match the
serialization interface provided by SerializationMixin on other model classes.
Also fix pre-existing pyright type errors in durabletask's
DurableAgentStateUnknownContent.to_ai_content().

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address PR review: add type guard, remove to_json, add fallback, and tests

- Remove Content.to_json() per reviewer request (comment 3)
- Add type guard in Content.from_json() for non-dict JSON (comments 1, 4)
- Wrap json.JSONDecodeError as ValueError for consistent exception contract
- Add try/except fallback in to_ai_content() for invalid Content dicts (comment 5)
- Add test_content_to_dict_exclude_none and test_content_to_dict_exclude_fields (comment 2)
- Add test_unknown_content_to_ai_content_fallback_on_invalid_type_dict (comment 5)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Apply pre-commit auto-fixes

* Address review feedback for #4719: review comment fixes

* Remove Content.from_json, move logic to consuming code (#4719)

Remove the from_json convenience method from Content class per review
feedback. This is the same trivial json.loads + from_dict wrapper as
to_json which was already removed. Consumers should call json.loads
and Content.from_dict directly.

Update tests to use Content.from_dict(json.loads(...)) pattern and
remove from_json-specific error handling tests (those errors are
already covered by json.loads and Content.from_dict).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-18 08:08:44 +00:00

493 lines
18 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for DurableAgentState and related classes."""
import json
from datetime import datetime
import pytest
from agent_framework import Content, Message, UsageDetails
from agent_framework_durabletask._durable_agent_state import (
DurableAgentState,
DurableAgentStateContent,
DurableAgentStateMessage,
DurableAgentStateRequest,
DurableAgentStateTextContent,
DurableAgentStateUnknownContent,
DurableAgentStateUsage,
)
from agent_framework_durabletask._models import RunRequest
class TestDurableAgentStateRequestOrchestrationId:
"""Test suite for DurableAgentStateRequest orchestration_id field."""
def test_request_with_orchestration_id(self) -> None:
"""Test creating a request with an orchestration_id."""
request = DurableAgentStateRequest(
correlation_id="corr-123",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="user",
contents=[DurableAgentStateTextContent(text="test")],
)
],
orchestration_id="orch-456",
)
assert request.orchestration_id == "orch-456"
def test_request_to_dict_includes_orchestration_id(self) -> None:
"""Test that to_dict includes orchestrationId when set."""
request = DurableAgentStateRequest(
correlation_id="corr-123",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="user",
contents=[DurableAgentStateTextContent(text="test")],
)
],
orchestration_id="orch-789",
)
data = request.to_dict()
assert "orchestrationId" in data
assert data["orchestrationId"] == "orch-789"
def test_request_to_dict_excludes_orchestration_id_when_none(self) -> None:
"""Test that to_dict excludes orchestrationId when not set."""
request = DurableAgentStateRequest(
correlation_id="corr-123",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="user",
contents=[DurableAgentStateTextContent(text="test")],
)
],
)
data = request.to_dict()
assert "orchestrationId" not in data
def test_request_from_dict_with_orchestration_id(self) -> None:
"""Test from_dict correctly parses orchestrationId."""
data = {
"$type": "request",
"correlationId": "corr-123",
"createdAt": "2024-01-01T00:00:00Z",
"messages": [{"role": "user", "contents": [{"$type": "text", "text": "test"}]}],
"orchestrationId": "orch-from-dict",
}
request = DurableAgentStateRequest.from_dict(data)
assert request.orchestration_id == "orch-from-dict"
def test_request_from_run_request_with_orchestration_id(self) -> None:
"""Test from_run_request correctly transfers orchestration_id."""
run_request = RunRequest(
message="test message",
correlation_id="corr-run",
orchestration_id="orch-from-run-request",
)
durable_request = DurableAgentStateRequest.from_run_request(run_request)
assert durable_request.orchestration_id == "orch-from-run-request"
def test_request_from_run_request_without_orchestration_id(self) -> None:
"""Test from_run_request correctly handles missing orchestration_id."""
run_request = RunRequest(
message="test message",
correlation_id="corr-run",
)
durable_request = DurableAgentStateRequest.from_run_request(run_request)
assert durable_request.orchestration_id is None
class TestDurableAgentStateMessageCreatedAt:
"""Test suite for DurableAgentStateMessage created_at field handling."""
def test_message_from_run_request_without_created_at_preserves_none(self) -> None:
"""Test from_run_request handles auto-populated created_at from RunRequest.
When a RunRequest is created with None for created_at, RunRequest defaults it to
current UTC time. The resulting DurableAgentStateMessage should have this timestamp.
"""
run_request = RunRequest(
message="test message",
correlation_id="corr-run",
created_at=None, # RunRequest will default this to current time
)
durable_message = DurableAgentStateMessage.from_run_request(run_request)
# RunRequest auto-populates created_at, so it should not be None
assert durable_message.created_at is not None
def test_message_from_run_request_with_created_at_parses_correctly(self) -> None:
"""Test from_run_request correctly parses a valid created_at timestamp."""
run_request = RunRequest(
message="test message",
correlation_id="corr-run",
created_at=datetime(2024, 1, 15, 10, 30, 0),
)
durable_message = DurableAgentStateMessage.from_run_request(run_request)
assert durable_message.created_at is not None
assert durable_message.created_at.year == 2024
assert durable_message.created_at.month == 1
assert durable_message.created_at.day == 15
class TestDurableAgentState:
"""Test suite for DurableAgentState."""
def test_schema_version(self) -> None:
"""Test that schema version is set correctly."""
state = DurableAgentState()
assert state.schema_version == "1.1.0"
def test_to_dict_serialization(self) -> None:
"""Test that to_dict produces correct structure."""
state = DurableAgentState()
data = state.to_dict()
assert "schemaVersion" in data
assert "data" in data
assert data["schemaVersion"] == "1.1.0"
assert "conversationHistory" in data["data"]
def test_from_dict_deserialization(self) -> None:
"""Test that from_dict restores state correctly."""
original_data = {
"schemaVersion": "1.1.0",
"data": {
"conversationHistory": [
{
"$type": "request",
"correlationId": "test-123",
"createdAt": "2024-01-01T00:00:00Z",
"messages": [
{
"role": "user",
"contents": [{"$type": "text", "text": "Hello"}],
}
],
}
]
},
}
state = DurableAgentState.from_dict(original_data)
assert state.schema_version == "1.1.0"
assert len(state.data.conversation_history) == 1
assert isinstance(state.data.conversation_history[0], DurableAgentStateRequest)
def test_round_trip_serialization(self) -> None:
"""Test that round-trip serialization preserves data."""
state = DurableAgentState()
state.data.conversation_history.append(
DurableAgentStateRequest(
correlation_id="test-456",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="user",
contents=[DurableAgentStateTextContent(text="Test message")],
)
],
)
)
data = state.to_dict()
restored = DurableAgentState.from_dict(data)
assert restored.schema_version == state.schema_version
assert len(restored.data.conversation_history) == len(state.data.conversation_history)
assert restored.data.conversation_history[0].correlation_id == "test-456"
class TestDurableAgentStateUsage:
"""Test suite for DurableAgentStateUsage."""
def test_usage_init_with_defaults(self) -> None:
"""Test creating usage with default values."""
usage = DurableAgentStateUsage()
assert usage.input_token_count is None
assert usage.output_token_count is None
assert usage.total_token_count is None
assert usage.extensionData is None
def test_usage_init_with_values(self) -> None:
"""Test creating usage with specific values."""
usage = DurableAgentStateUsage(
input_token_count=100,
output_token_count=200,
total_token_count=300,
extensionData={"custom_field": "value"},
)
assert usage.input_token_count == 100
assert usage.output_token_count == 200
assert usage.total_token_count == 300
assert usage.extensionData == {"custom_field": "value"}
def test_usage_to_dict(self) -> None:
"""Test that to_dict produces correct structure."""
usage = DurableAgentStateUsage(
input_token_count=50,
output_token_count=75,
total_token_count=125,
)
data = usage.to_dict()
assert data["inputTokenCount"] == 50
assert data["outputTokenCount"] == 75
assert data["totalTokenCount"] == 125
def test_usage_to_dict_with_extension_data(self) -> None:
"""Test that to_dict includes extensionData when present."""
usage = DurableAgentStateUsage(
input_token_count=10,
output_token_count=20,
total_token_count=30,
extensionData={"provider_specific": 123},
)
data = usage.to_dict()
assert "extensionData" in data
assert data["extensionData"] == {"provider_specific": 123}
def test_usage_from_dict(self) -> None:
"""Test that from_dict restores usage correctly."""
data = {
"inputTokenCount": 100,
"outputTokenCount": 200,
"totalTokenCount": 300,
"extensionData": {"extra": "data"},
}
usage = DurableAgentStateUsage.from_dict(data)
assert usage.input_token_count == 100
assert usage.output_token_count == 200
assert usage.total_token_count == 300
assert usage.extensionData == {"extra": "data"}
def test_usage_from_usage_details(self) -> None:
"""Test creating DurableAgentStateUsage from UsageDetails."""
usage_details: UsageDetails = {
"input_token_count": 150,
"output_token_count": 250,
"total_token_count": 400,
}
usage = DurableAgentStateUsage.from_usage(usage_details)
assert usage is not None
assert usage.input_token_count == 150
assert usage.output_token_count == 250
assert usage.total_token_count == 400
def test_usage_from_usage_details_with_extension_fields(self) -> None:
"""Test that non-standard fields are captured in extensionData."""
usage_details: UsageDetails = {
"input_token_count": 100,
"output_token_count": 200,
"total_token_count": 300,
}
# Add provider-specific fields (UsageDetails is a TypedDict but allows extra keys)
usage_details["prompt_tokens"] = 100 # type: ignore[typeddict-unknown-key]
usage_details["completion_tokens"] = 200 # type: ignore[typeddict-unknown-key]
usage = DurableAgentStateUsage.from_usage(usage_details)
assert usage is not None
assert usage.extensionData is not None
assert usage.extensionData["prompt_tokens"] == 100
assert usage.extensionData["completion_tokens"] == 200
def test_usage_from_usage_none(self) -> None:
"""Test that from_usage returns None for None input."""
usage = DurableAgentStateUsage.from_usage(None)
assert usage is None
def test_usage_to_usage_details(self) -> None:
"""Test converting back to UsageDetails."""
usage = DurableAgentStateUsage(
input_token_count=100,
output_token_count=200,
total_token_count=300,
)
details = usage.to_usage_details()
assert details.get("input_token_count") == 100
assert details.get("output_token_count") == 200
assert details.get("total_token_count") == 300
def test_usage_to_usage_details_with_extension_data(self) -> None:
"""Test that extensionData is merged into UsageDetails."""
usage = DurableAgentStateUsage(
input_token_count=50,
output_token_count=75,
total_token_count=125,
extensionData={"prompt_tokens": 50, "completion_tokens": 75},
)
details = usage.to_usage_details()
assert details.get("input_token_count") == 50
assert details.get("output_token_count") == 75
assert details.get("total_token_count") == 125
# Extension data should be merged into the result
assert details.get("prompt_tokens") == 50
assert details.get("completion_tokens") == 75
def test_usage_round_trip(self) -> None:
"""Test round-trip conversion from UsageDetails to DurableAgentStateUsage and back."""
original: UsageDetails = {
"input_token_count": 100,
"output_token_count": 200,
"total_token_count": 300,
}
usage = DurableAgentStateUsage.from_usage(original)
assert usage is not None
restored = usage.to_usage_details()
assert restored.get("input_token_count") == original.get("input_token_count")
assert restored.get("output_token_count") == original.get("output_token_count")
assert restored.get("total_token_count") == original.get("total_token_count")
class TestDurableAgentStateUnknownContent:
"""Test suite for DurableAgentStateUnknownContent serialization."""
def test_unknown_content_from_content_object_produces_serializable_dict(self) -> None:
"""Test that from_unknown_content serializes Content objects to dicts."""
content = Content.from_mcp_server_tool_call(
call_id="call-1",
tool_name="search",
server_name="learn-mcp",
arguments={"query": "azure functions"},
)
unknown = DurableAgentStateUnknownContent.from_unknown_content(content)
result = unknown.to_dict()
# The content field should be a dict, not a Content object
assert isinstance(result["content"], dict)
assert result["content"]["type"] == "mcp_server_tool_call"
def test_unknown_content_to_dict_is_json_serializable(self) -> None:
"""Test that to_dict output can be passed to json.dumps without error."""
content = Content.from_mcp_server_tool_result(
call_id="call-1",
output="Azure Functions documentation...",
)
unknown = DurableAgentStateUnknownContent.from_unknown_content(content)
result = unknown.to_dict()
# This must not raise TypeError
serialized = json.dumps(result)
assert serialized is not None
def test_unknown_content_round_trip_preserves_content(self) -> None:
"""Test that Content objects survive serialization and deserialization."""
original = Content.from_mcp_server_tool_call(
call_id="call-1",
tool_name="fetch",
server_name="learn-mcp",
arguments={"url": "https://example.com"},
)
unknown = DurableAgentStateUnknownContent.from_unknown_content(original)
restored = unknown.to_ai_content()
assert restored.type == "mcp_server_tool_call"
assert restored.tool_name == "fetch"
assert restored.server_name == "learn-mcp"
def test_unknown_content_from_plain_dict_unchanged(self) -> None:
"""Test that non-Content values are stored as-is."""
plain = {"some": "data"}
unknown = DurableAgentStateUnknownContent.from_unknown_content(plain)
assert unknown.content == {"some": "data"}
def test_unknown_content_to_ai_content_fallback_on_invalid_type_dict(self) -> None:
"""Test that to_ai_content falls back when dict has 'type' but is not valid Content."""
invalid = {"type": "bogus_not_a_real_content_type", "extra": "stuff"}
unknown = DurableAgentStateUnknownContent(content=invalid)
result = unknown.to_ai_content()
assert result.type == "unknown"
assert result.additional_properties == {"content": invalid}
def test_from_ai_content_unknown_type_produces_serializable_state(self) -> None:
"""Test that unknown content types in message conversion produce JSON-serializable state."""
content = Content.from_mcp_server_tool_call(
call_id="call-1",
tool_name="search",
server_name="learn-mcp",
arguments={"query": "create function app"},
)
durable_content = DurableAgentStateContent.from_ai_content(content)
data = durable_content.to_dict()
# Must be fully JSON-serializable
serialized = json.dumps(data)
assert serialized is not None
def test_state_with_mcp_content_is_json_serializable(self) -> None:
"""Test that full DurableAgentState with MCP content can be serialized to JSON.
This reproduces the scenario from issue #4719 where agent state containing
MCP tool content could not be serialized by Azure Durable Functions.
"""
state = DurableAgentState()
mcp_content = Content.from_mcp_server_tool_call(
call_id="call-1",
tool_name="search",
server_name="learn-mcp",
arguments={"query": "azure functions"},
)
message = DurableAgentStateMessage.from_chat_message(Message(role="assistant", contents=[mcp_content]))
state.data.conversation_history.append(
DurableAgentStateRequest(
correlation_id="test-mcp",
created_at=datetime.now(),
messages=[message],
)
)
state_dict = state.to_dict()
# This simulates what Azure Durable Functions does with entity state
serialized = json.dumps(state_dict)
assert serialized is not None
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])