Python: Fix streamed workflow agent continuation context by finalizing AgentExecutor streams (#3882)

* Fix streamed workflow agent continuation context by finalizing AgentExecutor streams

* Fix stream handling

* Fixes

* Fix DevUI and tests
This commit is contained in:
Evan Mattson
2026-02-13 07:45:46 +09:00
committed by GitHub
Unverified
parent 2203fa0f8b
commit a276c1295a
17 changed files with 359 additions and 267 deletions
@@ -14,7 +14,7 @@ from abc import ABC, abstractmethod
from typing import Any, Literal, cast
from agent_framework import AgentSession, Message
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage, WorkflowCheckpoint
from openai.types.conversations import Conversation, ConversationDeletedResource
from openai.types.conversations.conversation_item import ConversationItem
from openai.types.conversations.message import Message as OpenAIMessage
@@ -480,7 +480,7 @@ class InMemoryConversationStore(ConversationStore):
checkpoint_storage = conv_data.get("checkpoint_storage")
if checkpoint_storage:
# Get all checkpoints for this conversation
checkpoints = await checkpoint_storage.list_checkpoints()
checkpoints = self._list_all_checkpoints(checkpoint_storage)
for checkpoint in checkpoints:
# Create a conversation item for each checkpoint with summary metadata
# Full checkpoint state is NOT included here (too large for list view)
@@ -495,7 +495,9 @@ class InMemoryConversationStore(ConversationStore):
"id": f"checkpoint_{checkpoint.checkpoint_id}",
"type": "checkpoint",
"checkpoint_id": checkpoint.checkpoint_id,
"workflow_id": checkpoint.workflow_id,
# Keep workflow_id for backward compatibility with existing UI payloads.
"workflow_id": checkpoint.workflow_name,
"workflow_name": checkpoint.workflow_name,
"timestamp": checkpoint.timestamp,
"status": "completed",
"metadata": {
@@ -506,6 +508,7 @@ class InMemoryConversationStore(ConversationStore):
"message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
"size_bytes": checkpoint_size,
"version": checkpoint.version,
"graph_signature_hash": checkpoint.graph_signature_hash,
},
}
items.append(cast(ConversationItem, checkpoint_item))
@@ -551,8 +554,9 @@ class InMemoryConversationStore(ConversationStore):
return None
# Load full checkpoint from storage
checkpoint = await checkpoint_storage.load_checkpoint(checkpoint_id)
if not checkpoint:
try:
checkpoint = await checkpoint_storage.load(checkpoint_id)
except Exception:
return None
# Calculate size of checkpoint
@@ -566,7 +570,9 @@ class InMemoryConversationStore(ConversationStore):
"id": item_id,
"type": "checkpoint",
"checkpoint_id": checkpoint.checkpoint_id,
"workflow_id": checkpoint.workflow_id,
# Keep workflow_id for backward compatibility with existing UI payloads.
"workflow_id": checkpoint.workflow_name,
"workflow_name": checkpoint.workflow_name,
"timestamp": checkpoint.timestamp,
"status": "completed",
"metadata": {
@@ -577,6 +583,7 @@ class InMemoryConversationStore(ConversationStore):
"message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
"size_bytes": checkpoint_size,
"version": checkpoint.version,
"graph_signature_hash": checkpoint.graph_signature_hash,
# 🔥 FULL checkpoint state (lazy loaded)
"full_checkpoint": checkpoint.to_dict(),
},
@@ -631,8 +638,8 @@ class InMemoryConversationStore(ConversationStore):
if conv_meta.get("type") == "workflow_session":
checkpoint_storage = conv_data.get("checkpoint_storage")
if checkpoint_storage:
checkpoints = await checkpoint_storage.list_checkpoints()
latest = checkpoints[0] if checkpoints else None
checkpoints = self._list_all_checkpoints(checkpoint_storage)
latest = max(checkpoints, key=lambda cp: cp.timestamp) if checkpoints else None
conv_meta["checkpoint_summary"] = {
"count": len(checkpoints),
"latest_iteration": latest.iteration_count if latest else 0,
@@ -654,6 +661,19 @@ class InMemoryConversationStore(ConversationStore):
return results
@staticmethod
def _list_all_checkpoints(checkpoint_storage: Any) -> list[WorkflowCheckpoint]:
    """Collect every checkpoint held by a conversation-scoped storage object.

    DevUI keeps a separate checkpoint storage per conversation, while the core
    storage listing APIs now expect a workflow_name filter. To give UI views a
    conversation-wide listing, this helper bypasses those APIs and reads the
    in-memory storage's private ``_checkpoints`` mapping directly.

    Returns:
        A new list with the values of ``checkpoint_storage._checkpoints``, or
        an empty list when that attribute is absent or is not a ``dict``.
    """
    stored = getattr(checkpoint_storage, "_checkpoints", None)
    # Only a dict-backed store can be enumerated this way; anything else
    # yields an empty listing.
    if not isinstance(stored, dict):
        return []
    by_id = cast(dict[str, WorkflowCheckpoint], stored)
    return [*by_id.values()]
class CheckpointConversationManager:
"""Manages checkpoint storage for workflow sessions - SESSION-SCOPED.
@@ -104,17 +104,21 @@ class TestCheckpointConversationManager:
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"test": "data"}
checkpoint_id=str(uuid.uuid4()),
workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={},
state={"test": "data"},
)
# Get checkpoint storage for this conversation and save
storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
checkpoint_id = await storage.save_checkpoint(checkpoint)
checkpoint_id = await storage.save(checkpoint)
assert checkpoint_id == checkpoint.checkpoint_id
# Verify checkpoint stored in THIS conversation only
checkpoints = await storage.list_checkpoints()
checkpoints = await storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints) == 1
assert checkpoints[0].checkpoint_id == checkpoint.checkpoint_id
@@ -140,20 +144,21 @@ class TestCheckpointConversationManager:
checkpoint_a = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()),
workflow_id=test_workflow.id,
workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={},
state={"conversation": "A"},
)
storage_a = checkpoint_manager.get_checkpoint_storage(conv_a)
await storage_a.save_checkpoint(checkpoint_a)
await storage_a.save(checkpoint_a)
# Verify conversation A has checkpoint
checkpoints_a = await storage_a.list_checkpoints()
checkpoints_a = await storage_a.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_a) == 1
# Verify conversation B has NO checkpoints (isolation)
storage_b = checkpoint_manager.get_checkpoint_storage(conv_b)
checkpoints_b = await storage_b.list_checkpoints()
checkpoints_b = await storage_b.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_b) == 0
@pytest.mark.asyncio
@@ -177,15 +182,16 @@ class TestCheckpointConversationManager:
for i in range(3):
checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()),
workflow_id=test_workflow.id,
workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={},
state={"iteration": i},
)
saved_id = await storage.save_checkpoint(checkpoint)
saved_id = await storage.save(checkpoint)
checkpoint_ids.append(saved_id)
# List checkpoints using the storage
checkpoints_list = await storage.list_checkpoints()
checkpoints_list = await storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_list) == 3
# Verify all checkpoint IDs are present
@@ -213,11 +219,12 @@ class TestCheckpointConversationManager:
for i in range(2):
checkpoint = WorkflowCheckpoint(
checkpoint_id=f"checkpoint_{i}",
workflow_id=test_workflow.id,
workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={},
state={"iteration": i},
)
saved_id = await storage.save_checkpoint(checkpoint)
saved_id = await storage.save(checkpoint)
checkpoint_ids.append(saved_id)
# List conversation items - should include checkpoints
@@ -233,7 +240,7 @@ class TestCheckpointConversationManager:
for item in checkpoint_items:
assert item.get("type") == "checkpoint"
assert item.get("checkpoint_id") in checkpoint_ids
assert item.get("workflow_id") == test_workflow.id
assert item.get("workflow_name") == test_workflow.name
assert "timestamp" in item
assert item.get("id").startswith("checkpoint_") # ID format: checkpoint_{checkpoint_id}
@@ -255,21 +262,22 @@ class TestCheckpointConversationManager:
original_checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()),
workflow_id=test_workflow.id,
workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={},
state={"test_key": "test_value"},
)
# Save to this session
storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
await storage.save_checkpoint(original_checkpoint)
await storage.save(original_checkpoint)
# Load checkpoint from this session
loaded_checkpoint = await storage.load_checkpoint(original_checkpoint.checkpoint_id)
loaded_checkpoint = await storage.load(original_checkpoint.checkpoint_id)
assert loaded_checkpoint is not None
assert loaded_checkpoint.checkpoint_id == original_checkpoint.checkpoint_id
assert loaded_checkpoint.workflow_id == original_checkpoint.workflow_id
assert loaded_checkpoint.workflow_name == original_checkpoint.workflow_name
assert loaded_checkpoint.state == {"test_key": "test_value"}
@@ -296,24 +304,28 @@ class TestCheckpointStorage:
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"test": "data"}
checkpoint_id=str(uuid.uuid4()),
workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={},
state={"test": "data"},
)
# Test save_checkpoint
checkpoint_id = await storage.save_checkpoint(checkpoint)
# Test save
checkpoint_id = await storage.save(checkpoint)
assert checkpoint_id == checkpoint.checkpoint_id
# Test load_checkpoint
loaded = await storage.load_checkpoint(checkpoint_id)
# Test load
loaded = await storage.load(checkpoint_id)
assert loaded is not None
assert loaded.checkpoint_id == checkpoint_id
# Test list_checkpoint_ids
ids = await storage.list_checkpoint_ids(workflow_id=test_workflow.id)
ids = await storage.list_checkpoint_ids(workflow_name=test_workflow.name)
assert checkpoint_id in ids
# Test list_checkpoints
checkpoints_list = await storage.list_checkpoints(workflow_id=test_workflow.id)
checkpoints_list = await storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_list) >= 1
assert any(cp.checkpoint_id == checkpoint_id for cp in checkpoints_list)
@@ -346,12 +358,16 @@ class TestIntegration:
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"injected": True}
checkpoint_id=str(uuid.uuid4()),
workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={},
state={"injected": True},
)
await checkpoint_storage.save_checkpoint(checkpoint)
await checkpoint_storage.save(checkpoint)
# Verify checkpoint is accessible via storage (in this session)
storage_checkpoints = await checkpoint_storage.list_checkpoints()
storage_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(storage_checkpoints) > 0
assert storage_checkpoints[0].checkpoint_id == checkpoint.checkpoint_id
@@ -377,20 +393,21 @@ class TestIntegration:
checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()),
workflow_id=test_workflow.id,
workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={},
state={"ready_to_resume": True},
)
checkpoint_id = await checkpoint_storage.save_checkpoint(checkpoint)
checkpoint_id = await checkpoint_storage.save(checkpoint)
# Verify checkpoint can be loaded for resume
loaded = await checkpoint_storage.load_checkpoint(checkpoint_id)
loaded = await checkpoint_storage.load(checkpoint_id)
assert loaded is not None
assert loaded.checkpoint_id == checkpoint_id
assert loaded.state == {"ready_to_resume": True}
# Verify checkpoint is accessible via storage (for UI to list checkpoints)
checkpoints = await checkpoint_storage.list_checkpoints()
checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints) > 0
assert checkpoints[0].checkpoint_id == checkpoint_id
@@ -420,7 +437,7 @@ class TestIntegration:
test_workflow._runner.context._checkpoint_storage = checkpoint_storage
# Verify no checkpoints initially
checkpoints_before = await checkpoint_storage.list_checkpoints()
checkpoints_before = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_before) == 0
# Run workflow until it reaches IDLE_WITH_PENDING_REQUESTS (after checkpoint is created)
@@ -435,9 +452,9 @@ class TestIntegration:
assert saw_request_event, "Test workflow should have emitted request_info event (type='request_info')"
# Verify checkpoint was AUTOMATICALLY saved to our storage by the framework
checkpoints_after = await checkpoint_storage.list_checkpoints()
checkpoints_after = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_after) > 0, "Workflow should have auto-saved checkpoint at HIL pause"
# Verify checkpoint has correct workflow_id
# Verify checkpoint has correct workflow identity
checkpoint = checkpoints_after[0]
assert checkpoint.workflow_id == test_workflow.id
assert checkpoint.workflow_name == test_workflow.name
@@ -379,26 +379,27 @@ async def test_checkpoint_api_endpoints(test_entities_dir):
storage = executor.checkpoint_manager.get_checkpoint_storage(conv_id)
checkpoint = WorkflowCheckpoint(
checkpoint_id="test_checkpoint_1",
workflow_id="test_workflow",
workflow_name="test_workflow",
graph_signature_hash="test_graph_hash",
state={"key": "value"},
iteration_count=1,
)
await storage.save_checkpoint(checkpoint)
await storage.save(checkpoint)
# Test list checkpoints endpoint
checkpoints = await storage.list_checkpoints()
checkpoints = await storage.list_checkpoints(workflow_name="test_workflow")
assert len(checkpoints) == 1
assert checkpoints[0].checkpoint_id == "test_checkpoint_1"
assert checkpoints[0].workflow_id == "test_workflow"
assert checkpoints[0].workflow_name == "test_workflow"
# Test delete checkpoint endpoint
deleted = await storage.delete_checkpoint("test_checkpoint_1")
deleted = await storage.delete("test_checkpoint_1")
assert deleted is True
# Verify checkpoint was deleted
remaining = await storage.list_checkpoints()
remaining = await storage.list_checkpoints(workflow_name="test_workflow")
assert len(remaining) == 0
# Test delete non-existent checkpoint
deleted = await storage.delete_checkpoint("nonexistent")
deleted = await storage.delete("nonexistent")
assert deleted is False