mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
c361ad8d33
* Add checkpoint hooks * Deprecate get_executor_state and set_executor_state * Fix tests and samples * Add doc strings * Add sample * Fix import * Address comments and fix tests * Address comments * conditional import
447 lines
18 KiB
Python
447 lines
18 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Tests for checkpoint-as-conversation-items implementation."""
|
|
|
|
from dataclasses import dataclass
|
|
|
|
import pytest
|
|
from agent_framework import (
|
|
Executor,
|
|
InMemoryCheckpointStorage,
|
|
WorkflowBuilder,
|
|
WorkflowContext,
|
|
handler,
|
|
response_handler,
|
|
)
|
|
|
|
from agent_framework_devui._conversations import (
|
|
CheckpointConversationManager,
|
|
InMemoryConversationStore,
|
|
)
|
|
|
|
|
|
@dataclass
class WorkflowTestData:
    """Input payload consumed by ``WorkflowTestExecutor.process``.

    Attributes:
        value: String that the executor remembers and later returns with an
            ``_approved`` / ``_rejected`` suffix.
    """

    value: str
|
|
|
|
|
|
@dataclass
class WorkflowHILRequest:
    """Payload attached to the human-in-the-loop (HIL) request in the tests.

    Attributes:
        question: Prompt text for the responder, e.g. ``"Approve <value>?"``.
    """

    question: str
|
|
|
|
|
|
class WorkflowTestExecutor(Executor):
    """Executor that pauses for a human-in-the-loop (HIL) decision.

    ``process`` remembers the incoming value and issues a ``WorkflowHILRequest``.
    ``handle_response`` then forwards ``"<value>_approved"`` when the reply is
    ``"yes"`` (case-insensitive) and ``"<value>_rejected"`` otherwise.
    """

    def __init__(self, id: str) -> None:
        super().__init__(id=id)
        # Value captured by ``process``; read back when the HIL response arrives.
        self._data_value: str | None = None

    @handler
    async def process(self, data: WorkflowTestData, ctx: WorkflowContext) -> None:
        """Remember ``data.value`` and ask for approval through ``ctx.request_info``."""
        self._data_value = data.value
        approval_prompt = WorkflowHILRequest(question=f"Approve {data.value}?")
        # The workflow pauses on this request; a checkpoint is expected at this point.
        await ctx.request_info(request_data=approval_prompt, response_type=str)

    @response_handler
    async def handle_response(
        self, original_request: WorkflowHILRequest, response: str, ctx: WorkflowContext[str]
    ) -> None:
        """Send the approval outcome for the remembered value downstream."""
        remembered = self._data_value or ""
        outcome = "approved" if response.lower() == "yes" else "rejected"
        await ctx.send_message(f"{remembered}_{outcome}")
|
|
|
|
|
|
@pytest.fixture
def conversation_store():
    """Provide a fresh ``InMemoryConversationStore`` (function-scoped fixture)."""
    store = InMemoryConversationStore()
    return store
|
|
|
|
|
|
@pytest.fixture
def checkpoint_manager(conversation_store):
    """Provide a ``CheckpointConversationManager`` wrapping the ``conversation_store`` fixture."""
    manager = CheckpointConversationManager(conversation_store)
    return manager
|
|
|
|
|
|
@pytest.fixture
def test_workflow():
    """Build a single-executor workflow with in-memory checkpointing enabled."""
    start_executor = WorkflowTestExecutor(id="test_executor")
    in_memory_storage = InMemoryCheckpointStorage()

    builder = WorkflowBuilder(name="Test Workflow", description="Test checkpoint behavior")
    builder = builder.set_start_executor(start_executor)
    builder = builder.with_checkpointing(in_memory_storage)
    return builder.build()
|
|
|
|
|
|
class TestCheckpointConversationManager:
    """Test CheckpointConversationManager functionality - CONVERSATION-SCOPED.

    Each test creates its own ``workflow_session`` conversation and works with the
    checkpoint storage that ``checkpoint_manager.get_checkpoint_storage`` returns for it.
    """

    @staticmethod
    def _create_session(checkpoint_manager, conversation_id: str, entity_id: str = "test_entity") -> None:
        """Create the ``workflow_session`` conversation that checkpoints will be scoped to."""
        checkpoint_manager.conversation_store.create_conversation(
            metadata={"entity_id": entity_id, "type": "workflow_session"}, conversation_id=conversation_id
        )

    @staticmethod
    def _make_checkpoint(workflow_id: str, shared_state: dict, checkpoint_id: str | None = None):
        """Build a ``WorkflowCheckpoint`` with an empty ``messages`` mapping.

        Args:
            workflow_id: Id of the workflow the checkpoint belongs to.
            shared_state: Shared-state payload stored in the checkpoint.
            checkpoint_id: Explicit checkpoint id; a random UUID4 string when omitted.

        Returns:
            The constructed (not yet saved) checkpoint.
        """
        # WorkflowCheckpoint lives in a private framework module; imported here, not at file level.
        import uuid

        from agent_framework._workflows._checkpoint import WorkflowCheckpoint

        return WorkflowCheckpoint(
            checkpoint_id=str(uuid.uuid4()) if checkpoint_id is None else checkpoint_id,
            workflow_id=workflow_id,
            messages={},
            shared_state=shared_state,
        )

    @pytest.mark.asyncio
    async def test_conversation_scoped_checkpoint_save(self, checkpoint_manager, test_workflow):
        """Test checkpoint save in a specific conversation."""
        entity_id = "test_entity"
        conversation_id = f"conv_{entity_id}_test123"
        self._create_session(checkpoint_manager, conversation_id, entity_id)

        checkpoint = self._make_checkpoint(test_workflow.id, {"test": "data"})

        # Get checkpoint storage for this conversation and save
        storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
        checkpoint_id = await storage.save_checkpoint(checkpoint)
        assert checkpoint_id == checkpoint.checkpoint_id

        # Exactly this checkpoint is stored in THIS conversation
        checkpoints = await storage.list_checkpoints()
        assert [cp.checkpoint_id for cp in checkpoints] == [checkpoint.checkpoint_id]

    @pytest.mark.asyncio
    async def test_conversation_isolation(self, checkpoint_manager, test_workflow):
        """Test that conversations are isolated - checkpoints don't leak between conversations."""
        entity_id = "test_entity"
        conv_a = f"conv_{entity_id}_aaa"
        conv_b = f"conv_{entity_id}_bbb"
        self._create_session(checkpoint_manager, conv_a, entity_id)
        self._create_session(checkpoint_manager, conv_b, entity_id)

        # Save checkpoint to conversation A
        checkpoint_a = self._make_checkpoint(test_workflow.id, {"conversation": "A"})
        storage_a = checkpoint_manager.get_checkpoint_storage(conv_a)
        await storage_a.save_checkpoint(checkpoint_a)

        # Conversation A holds exactly the checkpoint saved to it
        checkpoints_a = await storage_a.list_checkpoints()
        assert [cp.checkpoint_id for cp in checkpoints_a] == [checkpoint_a.checkpoint_id]

        # Conversation B has NO checkpoints (isolation)
        storage_b = checkpoint_manager.get_checkpoint_storage(conv_b)
        checkpoints_b = await storage_b.list_checkpoints()
        assert len(checkpoints_b) == 0

    @pytest.mark.asyncio
    async def test_list_checkpoints_in_session(self, checkpoint_manager, test_workflow):
        """Test listing checkpoints within a session."""
        entity_id = "test_entity"
        conversation_id = f"session_{entity_id}_test456"
        self._create_session(checkpoint_manager, conversation_id, entity_id)

        # Save multiple checkpoints
        storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
        checkpoint_ids = []
        for i in range(3):
            checkpoint = self._make_checkpoint(test_workflow.id, {"iteration": i})
            checkpoint_ids.append(await storage.save_checkpoint(checkpoint))

        checkpoints_list = await storage.list_checkpoints()
        assert len(checkpoints_list) == 3

        # Every saved checkpoint is listed exactly once (listing order is not asserted)
        assert sorted(cp.checkpoint_id for cp in checkpoints_list) == sorted(checkpoint_ids)

    @pytest.mark.asyncio
    async def test_checkpoints_appear_as_conversation_items(self, checkpoint_manager, test_workflow):
        """Test that checkpoints appear as conversation items through the standard API."""
        entity_id = "test_entity"
        conversation_id = f"session_{entity_id}_items_test"
        self._create_session(checkpoint_manager, conversation_id, entity_id)

        # Save multiple checkpoints with predictable ids
        storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
        checkpoint_ids = []
        for i in range(2):
            checkpoint = self._make_checkpoint(test_workflow.id, {"iteration": i}, checkpoint_id=f"checkpoint_{i}")
            checkpoint_ids.append(await storage.save_checkpoint(checkpoint))

        # List conversation items - should include checkpoints (the has_more flag is not needed here)
        items, _has_more = await checkpoint_manager.conversation_store.list_items(conversation_id)

        checkpoint_items = [item for item in items if isinstance(item, dict) and item.get("type") == "checkpoint"]
        assert len(checkpoint_items) == 2, f"Expected 2 checkpoint items, got {len(checkpoint_items)}"

        # Each saved checkpoint surfaces as its own item (not one checkpoint listed twice)
        assert {item.get("checkpoint_id") for item in checkpoint_items} == set(checkpoint_ids)

        # Verify checkpoint items have correct structure
        for item in checkpoint_items:
            assert item.get("workflow_id") == test_workflow.id
            assert "timestamp" in item
            # ID format: checkpoint_{checkpoint_id}; a missing id fails the assert instead of raising
            assert str(item.get("id", "")).startswith("checkpoint_")

    @pytest.mark.asyncio
    async def test_load_checkpoint_from_session(self, checkpoint_manager, test_workflow):
        """Test loading checkpoint from a specific session."""
        entity_id = "test_entity"
        conversation_id = f"session_{entity_id}_test789"
        self._create_session(checkpoint_manager, conversation_id, entity_id)

        original_checkpoint = self._make_checkpoint(test_workflow.id, {"test_key": "test_value"})

        # Save to this session, then load it back
        storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
        await storage.save_checkpoint(original_checkpoint)
        loaded_checkpoint = await storage.load_checkpoint(original_checkpoint.checkpoint_id)

        assert loaded_checkpoint is not None
        assert loaded_checkpoint.checkpoint_id == original_checkpoint.checkpoint_id
        assert loaded_checkpoint.workflow_id == original_checkpoint.workflow_id
        assert loaded_checkpoint.shared_state == {"test_key": "test_value"}
|
|
|
|
|
|
class TestCheckpointStorage:
    """Check the per-conversation storage returned by ``get_checkpoint_storage`` - SESSION-SCOPED."""

    @pytest.mark.asyncio
    async def test_checkpoint_storage_protocol(self, checkpoint_manager, test_workflow):
        """Round-trip one checkpoint through save, load and both listing calls of the session storage."""
        import uuid

        from agent_framework._workflows._checkpoint import WorkflowCheckpoint

        entity = "test_entity"
        session_id = f"session_{entity}_adapter_test"
        checkpoint_manager.conversation_store.create_conversation(
            metadata={"entity_id": entity, "type": "workflow_session"}, conversation_id=session_id
        )
        adapter = checkpoint_manager.get_checkpoint_storage(session_id)

        original = WorkflowCheckpoint(
            checkpoint_id=str(uuid.uuid4()),
            workflow_id=test_workflow.id,
            messages={},
            shared_state={"test": "data"},
        )

        # save_checkpoint returns the id of the stored checkpoint.
        saved_id = await adapter.save_checkpoint(original)
        assert saved_id == original.checkpoint_id

        reloaded = await adapter.load_checkpoint(saved_id)
        assert reloaded is not None
        assert reloaded.checkpoint_id == saved_id

        # Both listing calls are filtered by workflow id and must include the saved checkpoint.
        listed_ids = await adapter.list_checkpoint_ids(workflow_id=test_workflow.id)
        assert saved_id in listed_ids

        listed = await adapter.list_checkpoints(workflow_id=test_workflow.id)
        assert len(listed) >= 1
        assert any(cp.checkpoint_id == saved_id for cp in listed)
|
|
|
|
|
|
class TestIntegration:
    """Integration tests for checkpoint workflow execution.

    These tests replace the workflow's build-time checkpoint storage by assigning the
    private ``_runner.context._checkpoint_storage`` attribute directly.
    """

    @pytest.mark.asyncio
    async def test_manual_checkpoint_save_via_injected_storage(self, checkpoint_manager, test_workflow):
        """Test manual checkpoint save via build-time storage injection."""
        import uuid

        from agent_framework._workflows._checkpoint import WorkflowCheckpoint

        entity_id = "test_entity"
        conversation_id = f"session_{entity_id}_integration_test1"

        # Create session conversation
        checkpoint_manager.conversation_store.create_conversation(
            metadata={"entity_id": entity_id, "type": "workflow_session"}, conversation_id=conversation_id
        )

        # Get checkpoint storage for this session
        checkpoint_storage = checkpoint_manager.get_checkpoint_storage(conversation_id)

        # Set build-time storage (equivalent to .with_checkpointing() at build time).
        # Note: In production, DevUI uses runtime injection via run_stream() parameter.
        # Assigned unconditionally (as in the other tests of this class) so that a renamed
        # private attribute fails the test instead of silently skipping the injection.
        test_workflow._runner.context._checkpoint_storage = checkpoint_storage

        # Create and save a checkpoint via injected storage
        checkpoint = WorkflowCheckpoint(
            checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, shared_state={"injected": True}
        )
        await checkpoint_storage.save_checkpoint(checkpoint)

        # Exactly the saved checkpoint is visible in this session's storage
        storage_checkpoints = await checkpoint_storage.list_checkpoints()
        assert [cp.checkpoint_id for cp in storage_checkpoints] == [checkpoint.checkpoint_id]

    @pytest.mark.asyncio
    async def test_checkpoint_roundtrip_via_storage(self, checkpoint_manager, test_workflow):
        """Test checkpoint save/load roundtrip via storage adapter."""
        import uuid

        from agent_framework._workflows._checkpoint import WorkflowCheckpoint

        entity_id = "test_entity"
        conversation_id = f"session_{entity_id}_integration_test2"

        # Create session conversation
        checkpoint_manager.conversation_store.create_conversation(
            metadata={"entity_id": entity_id, "type": "workflow_session"}, conversation_id=conversation_id
        )

        # Set build-time storage for testing
        checkpoint_storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
        test_workflow._runner.context._checkpoint_storage = checkpoint_storage

        # Create checkpoint
        checkpoint = WorkflowCheckpoint(
            checkpoint_id=str(uuid.uuid4()),
            workflow_id=test_workflow.id,
            messages={},
            shared_state={"ready_to_resume": True},
        )
        checkpoint_id = await checkpoint_storage.save_checkpoint(checkpoint)

        # Verify checkpoint can be loaded for resume
        loaded = await checkpoint_storage.load_checkpoint(checkpoint_id)
        assert loaded is not None
        assert loaded.checkpoint_id == checkpoint_id
        assert loaded.shared_state == {"ready_to_resume": True}

        # Exactly this checkpoint is listed (what the UI shows for the session)
        checkpoints = await checkpoint_storage.list_checkpoints()
        assert [cp.checkpoint_id for cp in checkpoints] == [checkpoint_id]

    @pytest.mark.asyncio
    async def test_workflow_auto_saves_checkpoints_to_injected_storage(self, checkpoint_manager, test_workflow):
        """Test that workflows automatically save checkpoints to our conversation-backed storage.

        This is the critical end-to-end test that verifies the entire checkpoint flow:
        1. Storage is set as build-time storage (simulates .with_checkpointing())
        2. Workflow runs and pauses at HIL point (IDLE_WITH_PENDING_REQUESTS status)
        3. Framework automatically saves checkpoint to our storage
        4. Checkpoint is accessible via manager for UI to list/resume

        Note: In production, DevUI passes checkpoint_storage to run_stream() as runtime parameter.
        This test uses build-time injection to verify framework's checkpoint auto-save behavior.
        """
        entity_id = "test_entity"
        conversation_id = f"session_{entity_id}_integration_test3"

        # Create session conversation
        checkpoint_manager.conversation_store.create_conversation(
            metadata={"entity_id": entity_id, "type": "workflow_session"}, conversation_id=conversation_id
        )

        # Set build-time storage to test automatic checkpoint saves
        checkpoint_storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
        test_workflow._runner.context._checkpoint_storage = checkpoint_storage

        # Verify no checkpoints initially
        checkpoints_before = await checkpoint_storage.list_checkpoints()
        assert len(checkpoints_before) == 0

        # Run workflow until it reaches IDLE_WITH_PENDING_REQUESTS (after checkpoint is created).
        # Events are matched by class name so this test needs no extra framework imports.
        saw_request_event = False
        async for event in test_workflow.run_stream(WorkflowTestData(value="test")):
            event_type = type(event).__name__
            if event_type == "RequestInfoEvent":
                saw_request_event = True
            has_pending_status = "IDLE_WITH_PENDING_REQUESTS" in str(getattr(event, "status", ""))
            if event_type == "WorkflowStatusEvent" and has_pending_status:
                break

        assert saw_request_event, "Test workflow should have emitted RequestInfoEvent"

        # Verify checkpoint was AUTOMATICALLY saved to our storage by the framework
        checkpoints_after = await checkpoint_storage.list_checkpoints()
        assert len(checkpoints_after) > 0, "Workflow should have auto-saved checkpoint at HIL pause"

        # Every auto-saved checkpoint in this session's storage belongs to this workflow
        assert all(cp.workflow_id == test_workflow.id for cp in checkpoints_after)
|