mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
19676978e9
* Add workflow checkpointing functionality. * Reintroduce protocol that went missing during merge * Checkpoint updates * Fix ordering of checkpointing * Cleanup * Cleanup - thanks Copilot * Cleanup - thanks Copilot * State reset updates * State reset updates 2 * Workflow fixes and updates. Addressed PR feedback * A few updates
335 lines
14 KiB
Python
335 lines
14 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
import json
|
|
import tempfile
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from agent_framework_workflow._checkpoint import (
|
|
FileCheckpointStorage,
|
|
InMemoryCheckpointStorage,
|
|
WorkflowCheckpoint,
|
|
)
|
|
|
|
|
|
def test_workflow_checkpoint_default_values():
|
|
checkpoint = WorkflowCheckpoint()
|
|
|
|
assert checkpoint.checkpoint_id != ""
|
|
assert checkpoint.workflow_id == ""
|
|
assert checkpoint.timestamp != ""
|
|
assert checkpoint.messages == {}
|
|
assert checkpoint.shared_state == {}
|
|
assert checkpoint.executor_states == {}
|
|
assert checkpoint.iteration_count == 0
|
|
assert checkpoint.max_iterations == 100
|
|
assert checkpoint.metadata == {}
|
|
assert checkpoint.version == "1.0"
|
|
|
|
|
|
def test_workflow_checkpoint_custom_values():
|
|
custom_timestamp = datetime.now(timezone.utc).isoformat()
|
|
checkpoint = WorkflowCheckpoint(
|
|
checkpoint_id="test-checkpoint-123",
|
|
workflow_id="test-workflow-456",
|
|
timestamp=custom_timestamp,
|
|
messages={"executor1": [{"data": "test"}]},
|
|
shared_state={"key": "value"},
|
|
executor_states={"executor1": {"state": "active"}},
|
|
iteration_count=5,
|
|
max_iterations=50,
|
|
metadata={"test": True},
|
|
version="2.0",
|
|
)
|
|
|
|
assert checkpoint.checkpoint_id == "test-checkpoint-123"
|
|
assert checkpoint.workflow_id == "test-workflow-456"
|
|
assert checkpoint.timestamp == custom_timestamp
|
|
assert checkpoint.messages == {"executor1": [{"data": "test"}]}
|
|
assert checkpoint.shared_state == {"key": "value"}
|
|
assert checkpoint.executor_states == {"executor1": {"state": "active"}}
|
|
assert checkpoint.iteration_count == 5
|
|
assert checkpoint.max_iterations == 50
|
|
assert checkpoint.metadata == {"test": True}
|
|
assert checkpoint.version == "2.0"
|
|
|
|
|
|
async def test_memory_checkpoint_storage_save_and_load():
|
|
storage = InMemoryCheckpointStorage()
|
|
checkpoint = WorkflowCheckpoint(workflow_id="test-workflow", messages={"executor1": [{"data": "hello"}]})
|
|
|
|
# Save checkpoint
|
|
saved_id = await storage.save_checkpoint(checkpoint)
|
|
assert saved_id == checkpoint.checkpoint_id
|
|
|
|
# Load checkpoint
|
|
loaded_checkpoint = await storage.load_checkpoint(checkpoint.checkpoint_id)
|
|
assert loaded_checkpoint is not None
|
|
assert loaded_checkpoint.checkpoint_id == checkpoint.checkpoint_id
|
|
assert loaded_checkpoint.workflow_id == checkpoint.workflow_id
|
|
assert loaded_checkpoint.messages == checkpoint.messages
|
|
|
|
|
|
async def test_memory_checkpoint_storage_load_nonexistent():
|
|
storage = InMemoryCheckpointStorage()
|
|
|
|
result = await storage.load_checkpoint("nonexistent-id")
|
|
assert result is None
|
|
|
|
|
|
async def test_memory_checkpoint_storage_list_checkpoints():
|
|
storage = InMemoryCheckpointStorage()
|
|
|
|
# Create checkpoints for different workflows
|
|
checkpoint1 = WorkflowCheckpoint(workflow_id="workflow-1")
|
|
checkpoint2 = WorkflowCheckpoint(workflow_id="workflow-1")
|
|
checkpoint3 = WorkflowCheckpoint(workflow_id="workflow-2")
|
|
|
|
await storage.save_checkpoint(checkpoint1)
|
|
await storage.save_checkpoint(checkpoint2)
|
|
await storage.save_checkpoint(checkpoint3)
|
|
|
|
# Test list_checkpoint_ids for workflow-1
|
|
workflow1_checkpoint_ids = await storage.list_checkpoint_ids("workflow-1")
|
|
assert len(workflow1_checkpoint_ids) == 2
|
|
assert checkpoint1.checkpoint_id in workflow1_checkpoint_ids
|
|
assert checkpoint2.checkpoint_id in workflow1_checkpoint_ids
|
|
|
|
# Test list_checkpoints for workflow-1 (returns objects)
|
|
workflow1_checkpoints = await storage.list_checkpoints("workflow-1")
|
|
assert len(workflow1_checkpoints) == 2
|
|
assert all(isinstance(cp, WorkflowCheckpoint) for cp in workflow1_checkpoints)
|
|
assert {cp.checkpoint_id for cp in workflow1_checkpoints} == {checkpoint1.checkpoint_id, checkpoint2.checkpoint_id}
|
|
|
|
# Test list_checkpoint_ids for workflow-2
|
|
workflow2_checkpoint_ids = await storage.list_checkpoint_ids("workflow-2")
|
|
assert len(workflow2_checkpoint_ids) == 1
|
|
assert checkpoint3.checkpoint_id in workflow2_checkpoint_ids
|
|
|
|
# Test list_checkpoints for workflow-2 (returns objects)
|
|
workflow2_checkpoints = await storage.list_checkpoints("workflow-2")
|
|
assert len(workflow2_checkpoints) == 1
|
|
assert workflow2_checkpoints[0].checkpoint_id == checkpoint3.checkpoint_id
|
|
|
|
# Test list_checkpoint_ids for non-existent workflow
|
|
empty_checkpoint_ids = await storage.list_checkpoint_ids("nonexistent-workflow")
|
|
assert len(empty_checkpoint_ids) == 0
|
|
|
|
# Test list_checkpoints for non-existent workflow
|
|
empty_checkpoints = await storage.list_checkpoints("nonexistent-workflow")
|
|
assert len(empty_checkpoints) == 0
|
|
|
|
# Test list_checkpoint_ids without workflow filter (all checkpoints)
|
|
all_checkpoint_ids = await storage.list_checkpoint_ids()
|
|
assert len(all_checkpoint_ids) == 3
|
|
expected_ids = {checkpoint1.checkpoint_id, checkpoint2.checkpoint_id, checkpoint3.checkpoint_id}
|
|
assert expected_ids.issubset(set(all_checkpoint_ids))
|
|
|
|
# Test list_checkpoints without workflow filter (all checkpoints)
|
|
all_checkpoints = await storage.list_checkpoints()
|
|
assert len(all_checkpoints) == 3
|
|
assert all(isinstance(cp, WorkflowCheckpoint) for cp in all_checkpoints)
|
|
|
|
|
|
async def test_memory_checkpoint_storage_delete():
|
|
storage = InMemoryCheckpointStorage()
|
|
checkpoint = WorkflowCheckpoint(workflow_id="test-workflow")
|
|
|
|
# Save checkpoint
|
|
await storage.save_checkpoint(checkpoint)
|
|
assert await storage.load_checkpoint(checkpoint.checkpoint_id) is not None
|
|
|
|
# Delete checkpoint
|
|
result = await storage.delete_checkpoint(checkpoint.checkpoint_id)
|
|
assert result is True
|
|
|
|
# Verify deletion
|
|
assert await storage.load_checkpoint(checkpoint.checkpoint_id) is None
|
|
|
|
# Try to delete again
|
|
result = await storage.delete_checkpoint(checkpoint.checkpoint_id)
|
|
assert result is False
|
|
|
|
|
|
async def test_file_checkpoint_storage_save_and_load():
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
storage = FileCheckpointStorage(temp_dir)
|
|
checkpoint = WorkflowCheckpoint(
|
|
workflow_id="test-workflow",
|
|
messages={"executor1": [{"data": "hello", "source_id": "test", "target_id": None}]},
|
|
shared_state={"key": "value"},
|
|
)
|
|
|
|
# Save checkpoint
|
|
saved_id = await storage.save_checkpoint(checkpoint)
|
|
assert saved_id == checkpoint.checkpoint_id
|
|
|
|
# Verify file was created
|
|
file_path = Path(temp_dir) / f"{checkpoint.checkpoint_id}.json"
|
|
assert file_path.exists()
|
|
|
|
# Load checkpoint
|
|
loaded_checkpoint = await storage.load_checkpoint(checkpoint.checkpoint_id)
|
|
assert loaded_checkpoint is not None
|
|
assert loaded_checkpoint.checkpoint_id == checkpoint.checkpoint_id
|
|
assert loaded_checkpoint.workflow_id == checkpoint.workflow_id
|
|
assert loaded_checkpoint.messages == checkpoint.messages
|
|
assert loaded_checkpoint.shared_state == checkpoint.shared_state
|
|
|
|
|
|
async def test_file_checkpoint_storage_load_nonexistent():
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
storage = FileCheckpointStorage(temp_dir)
|
|
|
|
result = await storage.load_checkpoint("nonexistent-id")
|
|
assert result is None
|
|
|
|
|
|
async def test_file_checkpoint_storage_list_checkpoints():
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
storage = FileCheckpointStorage(temp_dir)
|
|
|
|
# Create checkpoints for different workflows
|
|
checkpoint1 = WorkflowCheckpoint(workflow_id="workflow-1")
|
|
checkpoint2 = WorkflowCheckpoint(workflow_id="workflow-1")
|
|
checkpoint3 = WorkflowCheckpoint(workflow_id="workflow-2")
|
|
|
|
await storage.save_checkpoint(checkpoint1)
|
|
await storage.save_checkpoint(checkpoint2)
|
|
await storage.save_checkpoint(checkpoint3)
|
|
|
|
# Test list_checkpoint_ids for workflow-1
|
|
workflow1_checkpoint_ids = await storage.list_checkpoint_ids("workflow-1")
|
|
assert len(workflow1_checkpoint_ids) == 2
|
|
assert checkpoint1.checkpoint_id in workflow1_checkpoint_ids
|
|
assert checkpoint2.checkpoint_id in workflow1_checkpoint_ids
|
|
|
|
# Test list_checkpoints for workflow-1 (returns objects)
|
|
workflow1_checkpoints = await storage.list_checkpoints("workflow-1")
|
|
assert len(workflow1_checkpoints) == 2
|
|
assert all(isinstance(cp, WorkflowCheckpoint) for cp in workflow1_checkpoints)
|
|
checkpoint_ids = {cp.checkpoint_id for cp in workflow1_checkpoints}
|
|
assert checkpoint_ids == {checkpoint1.checkpoint_id, checkpoint2.checkpoint_id}
|
|
|
|
# Test list_checkpoint_ids for workflow-2
|
|
workflow2_checkpoint_ids = await storage.list_checkpoint_ids("workflow-2")
|
|
assert len(workflow2_checkpoint_ids) == 1
|
|
assert checkpoint3.checkpoint_id in workflow2_checkpoint_ids
|
|
|
|
# Test list_checkpoints for workflow-2 (returns objects)
|
|
workflow2_checkpoints = await storage.list_checkpoints("workflow-2")
|
|
assert len(workflow2_checkpoints) == 1
|
|
assert workflow2_checkpoints[0].checkpoint_id == checkpoint3.checkpoint_id
|
|
|
|
# Test list all checkpoints
|
|
all_checkpoint_ids = await storage.list_checkpoint_ids()
|
|
assert len(all_checkpoint_ids) == 3
|
|
|
|
all_checkpoints = await storage.list_checkpoints()
|
|
assert len(all_checkpoints) == 3
|
|
assert all(isinstance(cp, WorkflowCheckpoint) for cp in all_checkpoints)
|
|
|
|
|
|
async def test_file_checkpoint_storage_delete():
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
storage = FileCheckpointStorage(temp_dir)
|
|
checkpoint = WorkflowCheckpoint(workflow_id="test-workflow")
|
|
|
|
# Save checkpoint
|
|
await storage.save_checkpoint(checkpoint)
|
|
file_path = Path(temp_dir) / f"{checkpoint.checkpoint_id}.json"
|
|
assert file_path.exists()
|
|
|
|
# Delete checkpoint
|
|
result = await storage.delete_checkpoint(checkpoint.checkpoint_id)
|
|
assert result is True
|
|
assert not file_path.exists()
|
|
|
|
# Try to delete again
|
|
result = await storage.delete_checkpoint(checkpoint.checkpoint_id)
|
|
assert result is False
|
|
|
|
|
|
async def test_file_checkpoint_storage_directory_creation():
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
nested_path = Path(temp_dir) / "nested" / "checkpoint" / "storage"
|
|
storage = FileCheckpointStorage(nested_path)
|
|
|
|
# Directory should be created
|
|
assert nested_path.exists()
|
|
assert nested_path.is_dir()
|
|
|
|
# Should be able to save checkpoints
|
|
checkpoint = WorkflowCheckpoint(workflow_id="test")
|
|
await storage.save_checkpoint(checkpoint)
|
|
|
|
file_path = nested_path / f"{checkpoint.checkpoint_id}.json"
|
|
assert file_path.exists()
|
|
|
|
|
|
async def test_file_checkpoint_storage_corrupted_file():
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
storage = FileCheckpointStorage(temp_dir)
|
|
|
|
# Create a corrupted JSON file
|
|
corrupted_file = Path(temp_dir) / "corrupted.json"
|
|
with open(corrupted_file, "w") as f: # noqa: ASYNC230
|
|
f.write("{ invalid json }")
|
|
|
|
# list_checkpoints should handle the corrupted file gracefully
|
|
checkpoints = await storage.list_checkpoints("any-workflow")
|
|
assert checkpoints == []
|
|
|
|
|
|
async def test_file_checkpoint_storage_json_serialization():
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
storage = FileCheckpointStorage(temp_dir)
|
|
|
|
# Create checkpoint with complex nested data
|
|
checkpoint = WorkflowCheckpoint(
|
|
workflow_id="complex-workflow",
|
|
messages={"executor1": [{"data": {"nested": {"value": 42}}, "source_id": "test", "target_id": None}]},
|
|
shared_state={"list": [1, 2, 3], "dict": {"a": "b", "c": {"d": "e"}}, "bool": True, "null": None},
|
|
executor_states={"executor1": {"state": "active", "config": {"timeout": 30, "retries": 3}}},
|
|
)
|
|
|
|
# Save and load
|
|
await storage.save_checkpoint(checkpoint)
|
|
loaded = await storage.load_checkpoint(checkpoint.checkpoint_id)
|
|
|
|
assert loaded is not None
|
|
assert loaded.messages == checkpoint.messages
|
|
assert loaded.shared_state == checkpoint.shared_state
|
|
assert loaded.executor_states == checkpoint.executor_states
|
|
|
|
# Verify the JSON file is properly formatted
|
|
file_path = Path(temp_dir) / f"{checkpoint.checkpoint_id}.json"
|
|
with open(file_path) as f: # noqa: ASYNC230
|
|
data = json.load(f)
|
|
|
|
assert data["messages"]["executor1"][0]["data"]["nested"]["value"] == 42
|
|
assert data["shared_state"]["list"] == [1, 2, 3]
|
|
assert data["shared_state"]["bool"] is True
|
|
assert data["shared_state"]["null"] is None
|
|
|
|
|
|
def test_checkpoint_storage_protocol_compliance():
|
|
# This test ensures both implementations have all required methods
|
|
memory_storage = InMemoryCheckpointStorage()
|
|
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
file_storage = FileCheckpointStorage(temp_dir)
|
|
|
|
for storage in [memory_storage, file_storage]:
|
|
# Test that all protocol methods exist and are callable
|
|
assert hasattr(storage, "save_checkpoint")
|
|
assert callable(storage.save_checkpoint)
|
|
assert hasattr(storage, "load_checkpoint")
|
|
assert callable(storage.load_checkpoint)
|
|
assert hasattr(storage, "list_checkpoint_ids")
|
|
assert callable(storage.list_checkpoint_ids)
|
|
assert hasattr(storage, "list_checkpoints")
|
|
assert callable(storage.list_checkpoints)
|
|
assert hasattr(storage, "delete_checkpoint")
|
|
assert callable(storage.delete_checkpoint)
|