Python: Fix streamed workflow agent continuation context by finalizing AgentExecutor streams (#3882)

* Fix streamed workflow agent continuation context by finalizing AgentExecutor streams

* Fix stream handling

* Fixes

* Fix DevUI and tests
This commit is contained in:
Evan Mattson
2026-02-13 07:45:46 +09:00
committed by GitHub
Unverified
parent 2203fa0f8b
commit a276c1295a
17 changed files with 359 additions and 267 deletions
@@ -911,19 +911,21 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
if ctx is None: if ctx is None:
return # No context available (shouldn't happen in normal flow) return # No context available (shouldn't happen in normal flow)
# Update thread with conversation_id derived from streaming raw updates.
# Using response_id here can break function-call continuation for APIs
# where response IDs are not valid conversation handles.
conversation_id = self._extract_conversation_id_from_streaming_response(response)
# Ensure author names are set for all messages # Ensure author names are set for all messages
for message in response.messages: for message in response.messages:
if message.author_name is None: if message.author_name is None:
message.author_name = ctx["agent_name"] message.author_name = ctx["agent_name"]
# Propagate conversation_id back to session from streaming updates # Propagate conversation_id back to session from streaming updates.
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
# so refresh when a newer value is returned.
sess = ctx["session"] sess = ctx["session"]
if sess and not sess.service_session_id and response.raw_representation: if sess and conversation_id and sess.service_session_id != conversation_id:
raw_items = response.raw_representation if isinstance(response.raw_representation, list) else [] sess.service_session_id = conversation_id
for item in raw_items:
if hasattr(item, "conversation_id") and item.conversation_id:
sess.service_session_id = item.conversation_id
break
# Run after_run providers (reverse order) # Run after_run providers (reverse order)
session_context = ctx["session_context"] session_context = ctx["session_context"]
@@ -974,6 +976,27 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
output_format_type = response_format if isinstance(response_format, type) else None output_format_type = response_format if isinstance(response_format, type) else None
return AgentResponse.from_updates(updates, output_format_type=output_format_type) return AgentResponse.from_updates(updates, output_format_type=output_format_type)
@staticmethod
def _extract_conversation_id_from_streaming_response(response: AgentResponse[Any]) -> str | None:
"""Extract conversation_id from streaming raw updates, if present."""
raw = response.raw_representation
if raw is None:
return None
raw_items: list[Any] = raw if isinstance(raw, list) else [raw]
for item in reversed(raw_items):
if isinstance(item, Mapping):
value = item.get("conversation_id")
if isinstance(value, str) and value:
return value
continue
value = getattr(item, "conversation_id", None)
if isinstance(value, str) and value:
return value
return None
async def _prepare_run_context( async def _prepare_run_context(
self, self,
*, *,
@@ -1100,8 +1123,10 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
if message.author_name is None: if message.author_name is None:
message.author_name = agent_name message.author_name = agent_name
# Propagate conversation_id back to session (e.g. thread ID from Assistants API) # Propagate conversation_id back to session (e.g. thread ID from Assistants API).
if session and response.conversation_id and not session.service_session_id: # For Responses-style APIs this can rotate every turn (response_id-based continuation),
# so refresh when a newer value is returned.
if session and response.conversation_id and session.service_session_id != response.conversation_id:
session.service_session_id = response.conversation_id session.service_session_id = response.conversation_id
# Set the response on the context for after_run providers # Set the response on the context for after_run providers
+10 -1
View File
@@ -872,7 +872,16 @@ class MCPTool:
k: v k: v
for k, v in kwargs.items() for k, v in kwargs.items()
if k if k
not in {"chat_options", "tools", "tool_choice", "session", "thread", "conversation_id", "options", "response_format"} not in {
"chat_options",
"tools",
"tool_choice",
"session",
"thread",
"conversation_id",
"options",
"response_format",
}
} }
parser = self.parse_tool_results or _parse_tool_result_from_mcp parser = self.parse_tool_results or _parse_tool_result_from_mcp
@@ -2,7 +2,7 @@
import logging import logging
import sys import sys
from collections.abc import Mapping from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, cast from typing import Any, cast
@@ -358,22 +358,31 @@ class AgentExecutor(Executor):
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {}) run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {})
updates: list[AgentResponseUpdate] = [] updates: list[AgentResponseUpdate] = []
user_input_requests: list[Content] = [] streamed_user_input_requests: list[Content] = []
async for update in self._agent.run( stream = self._agent.run(
self._cache, self._cache,
stream=True, stream=True,
session=self._session, session=self._session,
options=options, options=options,
**run_kwargs, **run_kwargs,
): )
async for update in stream:
updates.append(update) updates.append(update)
await ctx.yield_output(update) await ctx.yield_output(update)
if update.user_input_requests: if update.user_input_requests:
user_input_requests.extend(update.user_input_requests) streamed_user_input_requests.extend(update.user_input_requests)
# Build the final AgentResponse from the collected updates # Prefer stream finalization when available so result hooks run
if is_chat_agent(self._agent): # (e.g., thread conversation updates). Fall back to reconstructing from updates
# for legacy/custom agents that return a plain async iterable.
# TODO(evmattso): Integrate workflow agent run handling around ResponseStream so
# AgentExecutor does not need this conditional stream-finalization branch.
maybe_get_final_response = getattr(stream, "get_final_response", None)
get_final_response = maybe_get_final_response if callable(maybe_get_final_response) else None
response: AgentResponse[Any]
if get_final_response is not None:
response = await cast(Callable[[], Awaitable[AgentResponse[Any]]], get_final_response)()
elif is_chat_agent(self._agent):
response_format = self._agent.default_options.get("response_format") response_format = self._agent.default_options.get("response_format")
response = AgentResponse.from_updates( response = AgentResponse.from_updates(
updates, updates,
@@ -383,6 +392,16 @@ class AgentExecutor(Executor):
response = AgentResponse.from_updates(updates) response = AgentResponse.from_updates(updates)
# Handle any user input requests after the streaming completes # Handle any user input requests after the streaming completes
user_input_requests: list[Content] = []
seen_request_ids: set[str] = set()
for user_input_request in [*streamed_user_input_requests, *response.user_input_requests]:
request_id = getattr(user_input_request, "id", None)
if isinstance(request_id, str) and request_id:
if request_id in seen_request_ids:
continue
seen_request_ids.add(request_id)
user_input_requests.append(user_input_request)
if user_input_requests: if user_input_requests:
for user_input_request in user_input_requests: for user_input_request in user_input_requests:
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index] self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
@@ -17,6 +17,7 @@ from agent_framework import (
BaseContextProvider, BaseContextProvider,
ChatOptions, ChatOptions,
ChatResponse, ChatResponse,
ChatResponseUpdate,
Content, Content,
FunctionTool, FunctionTool,
Message, Message,
@@ -154,6 +155,111 @@ async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChat
assert session.service_session_id == "123" assert session.service_session_id == "123"
async def test_chat_client_agent_updates_existing_session_id_non_streaming(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.run_responses = [
ChatResponse(
messages=[Message(role="assistant", contents=[Content.from_text("test response")])],
conversation_id="resp_new_123",
)
]
agent = Agent(client=chat_client_base)
session = agent.get_session(service_session_id="resp_old_123")
await agent.run("Hello", session=session)
assert session.service_session_id == "resp_new_123"
async def test_chat_client_agent_update_session_id_streaming_uses_conversation_id(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.streaming_responses = [
[
ChatResponseUpdate(
contents=[Content.from_text("stream part 1")],
role="assistant",
response_id="resp_stream_123",
conversation_id="conv_stream_456",
),
ChatResponseUpdate(
contents=[Content.from_text(" stream part 2")],
role="assistant",
response_id="resp_stream_123",
conversation_id="conv_stream_456",
finish_reason="stop",
),
]
]
agent = Agent(client=chat_client_base)
session = agent.create_session()
stream = agent.run("Hello", session=session, stream=True)
async for _ in stream:
pass
result = await stream.get_final_response()
assert result.text == "stream part 1 stream part 2"
assert session.service_session_id == "conv_stream_456"
async def test_chat_client_agent_updates_existing_session_id_streaming(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.streaming_responses = [
[
ChatResponseUpdate(
contents=[Content.from_text("stream part 1")],
role="assistant",
response_id="resp_stream_123",
conversation_id="resp_new_456",
),
ChatResponseUpdate(
contents=[Content.from_text(" stream part 2")],
role="assistant",
response_id="resp_stream_123",
conversation_id="resp_new_456",
finish_reason="stop",
),
]
]
agent = Agent(client=chat_client_base)
session = agent.get_session(service_session_id="resp_old_456")
stream = agent.run("Hello", session=session, stream=True)
async for _ in stream:
pass
await stream.get_final_response()
assert session.service_session_id == "resp_new_456"
async def test_chat_client_agent_update_session_id_streaming_does_not_use_response_id(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.streaming_responses = [
[
ChatResponseUpdate(
contents=[Content.from_text("stream response without conversation id")],
role="assistant",
response_id="resp_only_123",
finish_reason="stop",
),
]
]
agent = Agent(client=chat_client_base)
session = agent.create_session()
stream = agent.run("Hello", session=session, stream=True)
async for _ in stream:
pass
result = await stream.get_final_response()
assert result.text == "stream response without conversation id"
assert session.service_session_id is None
async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None: async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None:
agent = Agent(client=client) agent = Agent(client=client)
session = agent.create_session() session = agent.create_session()
@@ -0,0 +1 @@
# Copyright (c) Microsoft. All rights reserved.
@@ -50,6 +50,57 @@ class _CountingAgent(BaseAgent):
return _run() return _run()
class _StreamingHookAgent(BaseAgent):
"""Agent that exposes whether its streaming result hook was executed."""
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.result_hook_called = False
def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
*,
stream: bool = False,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
if stream:
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
yield AgentResponseUpdate(
contents=[Content.from_text(text="hook test")],
role="assistant",
)
async def _mark_result_hook_called(response: AgentResponse) -> AgentResponse:
self.result_hook_called = True
return response
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook(
_mark_result_hook_called
)
async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", ["hook test"])])
return _run()
async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None:
"""AgentExecutor should call get_final_response() so stream result hooks execute."""
agent = _StreamingHookAgent(id="hook_agent", name="HookAgent")
executor = AgentExecutor(agent, id="hook_exec")
workflow = SequentialBuilder(participants=[executor]).build()
output_events: list[Any] = []
async for event in workflow.run("run hook test", stream=True):
if event.type == "output":
output_events.append(event)
assert output_events
assert agent.result_hook_called
async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
"""Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly.""" """Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly."""
storage = InMemoryCheckpointStorage() storage = InMemoryCheckpointStorage()
@@ -12,7 +12,7 @@ from agent_framework import (
WorkflowRunState, WorkflowRunState,
) )
from agent_framework._workflows._checkpoint_encoding import ( from agent_framework._workflows._checkpoint_encoding import (
_PICKLE_MARKER, _PICKLE_MARKER, # type: ignore
encode_checkpoint_value, encode_checkpoint_value,
) )
from agent_framework._workflows._events import WorkflowEvent from agent_framework._workflows._events import WorkflowEvent
@@ -14,7 +14,7 @@ from abc import ABC, abstractmethod
from typing import Any, Literal, cast from typing import Any, Literal, cast
from agent_framework import AgentSession, Message from agent_framework import AgentSession, Message
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage, WorkflowCheckpoint
from openai.types.conversations import Conversation, ConversationDeletedResource from openai.types.conversations import Conversation, ConversationDeletedResource
from openai.types.conversations.conversation_item import ConversationItem from openai.types.conversations.conversation_item import ConversationItem
from openai.types.conversations.message import Message as OpenAIMessage from openai.types.conversations.message import Message as OpenAIMessage
@@ -480,7 +480,7 @@ class InMemoryConversationStore(ConversationStore):
checkpoint_storage = conv_data.get("checkpoint_storage") checkpoint_storage = conv_data.get("checkpoint_storage")
if checkpoint_storage: if checkpoint_storage:
# Get all checkpoints for this conversation # Get all checkpoints for this conversation
checkpoints = await checkpoint_storage.list_checkpoints() checkpoints = self._list_all_checkpoints(checkpoint_storage)
for checkpoint in checkpoints: for checkpoint in checkpoints:
# Create a conversation item for each checkpoint with summary metadata # Create a conversation item for each checkpoint with summary metadata
# Full checkpoint state is NOT included here (too large for list view) # Full checkpoint state is NOT included here (too large for list view)
@@ -495,7 +495,9 @@ class InMemoryConversationStore(ConversationStore):
"id": f"checkpoint_{checkpoint.checkpoint_id}", "id": f"checkpoint_{checkpoint.checkpoint_id}",
"type": "checkpoint", "type": "checkpoint",
"checkpoint_id": checkpoint.checkpoint_id, "checkpoint_id": checkpoint.checkpoint_id,
"workflow_id": checkpoint.workflow_id, # Keep workflow_id for backward compatibility with existing UI payloads.
"workflow_id": checkpoint.workflow_name,
"workflow_name": checkpoint.workflow_name,
"timestamp": checkpoint.timestamp, "timestamp": checkpoint.timestamp,
"status": "completed", "status": "completed",
"metadata": { "metadata": {
@@ -506,6 +508,7 @@ class InMemoryConversationStore(ConversationStore):
"message_count": sum(len(msgs) for msgs in checkpoint.messages.values()), "message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
"size_bytes": checkpoint_size, "size_bytes": checkpoint_size,
"version": checkpoint.version, "version": checkpoint.version,
"graph_signature_hash": checkpoint.graph_signature_hash,
}, },
} }
items.append(cast(ConversationItem, checkpoint_item)) items.append(cast(ConversationItem, checkpoint_item))
@@ -551,8 +554,9 @@ class InMemoryConversationStore(ConversationStore):
return None return None
# Load full checkpoint from storage # Load full checkpoint from storage
checkpoint = await checkpoint_storage.load_checkpoint(checkpoint_id) try:
if not checkpoint: checkpoint = await checkpoint_storage.load(checkpoint_id)
except Exception:
return None return None
# Calculate size of checkpoint # Calculate size of checkpoint
@@ -566,7 +570,9 @@ class InMemoryConversationStore(ConversationStore):
"id": item_id, "id": item_id,
"type": "checkpoint", "type": "checkpoint",
"checkpoint_id": checkpoint.checkpoint_id, "checkpoint_id": checkpoint.checkpoint_id,
"workflow_id": checkpoint.workflow_id, # Keep workflow_id for backward compatibility with existing UI payloads.
"workflow_id": checkpoint.workflow_name,
"workflow_name": checkpoint.workflow_name,
"timestamp": checkpoint.timestamp, "timestamp": checkpoint.timestamp,
"status": "completed", "status": "completed",
"metadata": { "metadata": {
@@ -577,6 +583,7 @@ class InMemoryConversationStore(ConversationStore):
"message_count": sum(len(msgs) for msgs in checkpoint.messages.values()), "message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
"size_bytes": checkpoint_size, "size_bytes": checkpoint_size,
"version": checkpoint.version, "version": checkpoint.version,
"graph_signature_hash": checkpoint.graph_signature_hash,
# 🔥 FULL checkpoint state (lazy loaded) # 🔥 FULL checkpoint state (lazy loaded)
"full_checkpoint": checkpoint.to_dict(), "full_checkpoint": checkpoint.to_dict(),
}, },
@@ -631,8 +638,8 @@ class InMemoryConversationStore(ConversationStore):
if conv_meta.get("type") == "workflow_session": if conv_meta.get("type") == "workflow_session":
checkpoint_storage = conv_data.get("checkpoint_storage") checkpoint_storage = conv_data.get("checkpoint_storage")
if checkpoint_storage: if checkpoint_storage:
checkpoints = await checkpoint_storage.list_checkpoints() checkpoints = self._list_all_checkpoints(checkpoint_storage)
latest = checkpoints[0] if checkpoints else None latest = max(checkpoints, key=lambda cp: cp.timestamp) if checkpoints else None
conv_meta["checkpoint_summary"] = { conv_meta["checkpoint_summary"] = {
"count": len(checkpoints), "count": len(checkpoints),
"latest_iteration": latest.iteration_count if latest else 0, "latest_iteration": latest.iteration_count if latest else 0,
@@ -654,6 +661,19 @@ class InMemoryConversationStore(ConversationStore):
return results return results
@staticmethod
def _list_all_checkpoints(checkpoint_storage: Any) -> list[WorkflowCheckpoint]:
"""Return all checkpoints from a conversation-scoped storage instance.
DevUI uses one checkpoint storage per conversation. Core storage APIs now
require workflow_name filters, so we gather directly from in-memory storage
internals to provide conversation-wide listing for UI views.
"""
checkpoint_map = getattr(checkpoint_storage, "_checkpoints", None)
if isinstance(checkpoint_map, dict):
return list(cast(dict[str, WorkflowCheckpoint], checkpoint_map).values())
return []
class CheckpointConversationManager: class CheckpointConversationManager:
"""Manages checkpoint storage for workflow sessions - SESSION-SCOPED. """Manages checkpoint storage for workflow sessions - SESSION-SCOPED.
@@ -104,17 +104,21 @@ class TestCheckpointConversationManager:
from agent_framework._workflows._checkpoint import WorkflowCheckpoint from agent_framework._workflows._checkpoint import WorkflowCheckpoint
checkpoint = WorkflowCheckpoint( checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"test": "data"} checkpoint_id=str(uuid.uuid4()),
workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={},
state={"test": "data"},
) )
# Get checkpoint storage for this conversation and save # Get checkpoint storage for this conversation and save
storage = checkpoint_manager.get_checkpoint_storage(conversation_id) storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
checkpoint_id = await storage.save_checkpoint(checkpoint) checkpoint_id = await storage.save(checkpoint)
assert checkpoint_id == checkpoint.checkpoint_id assert checkpoint_id == checkpoint.checkpoint_id
# Verify checkpoint stored in THIS conversation only # Verify checkpoint stored in THIS conversation only
checkpoints = await storage.list_checkpoints() checkpoints = await storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints) == 1 assert len(checkpoints) == 1
assert checkpoints[0].checkpoint_id == checkpoint.checkpoint_id assert checkpoints[0].checkpoint_id == checkpoint.checkpoint_id
@@ -140,20 +144,21 @@ class TestCheckpointConversationManager:
checkpoint_a = WorkflowCheckpoint( checkpoint_a = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()), checkpoint_id=str(uuid.uuid4()),
workflow_id=test_workflow.id, workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={}, messages={},
state={"conversation": "A"}, state={"conversation": "A"},
) )
storage_a = checkpoint_manager.get_checkpoint_storage(conv_a) storage_a = checkpoint_manager.get_checkpoint_storage(conv_a)
await storage_a.save_checkpoint(checkpoint_a) await storage_a.save(checkpoint_a)
# Verify conversation A has checkpoint # Verify conversation A has checkpoint
checkpoints_a = await storage_a.list_checkpoints() checkpoints_a = await storage_a.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_a) == 1 assert len(checkpoints_a) == 1
# Verify conversation B has NO checkpoints (isolation) # Verify conversation B has NO checkpoints (isolation)
storage_b = checkpoint_manager.get_checkpoint_storage(conv_b) storage_b = checkpoint_manager.get_checkpoint_storage(conv_b)
checkpoints_b = await storage_b.list_checkpoints() checkpoints_b = await storage_b.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_b) == 0 assert len(checkpoints_b) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -177,15 +182,16 @@ class TestCheckpointConversationManager:
for i in range(3): for i in range(3):
checkpoint = WorkflowCheckpoint( checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()), checkpoint_id=str(uuid.uuid4()),
workflow_id=test_workflow.id, workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={}, messages={},
state={"iteration": i}, state={"iteration": i},
) )
saved_id = await storage.save_checkpoint(checkpoint) saved_id = await storage.save(checkpoint)
checkpoint_ids.append(saved_id) checkpoint_ids.append(saved_id)
# List checkpoints using the storage # List checkpoints using the storage
checkpoints_list = await storage.list_checkpoints() checkpoints_list = await storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_list) == 3 assert len(checkpoints_list) == 3
# Verify all checkpoint IDs are present # Verify all checkpoint IDs are present
@@ -213,11 +219,12 @@ class TestCheckpointConversationManager:
for i in range(2): for i in range(2):
checkpoint = WorkflowCheckpoint( checkpoint = WorkflowCheckpoint(
checkpoint_id=f"checkpoint_{i}", checkpoint_id=f"checkpoint_{i}",
workflow_id=test_workflow.id, workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={}, messages={},
state={"iteration": i}, state={"iteration": i},
) )
saved_id = await storage.save_checkpoint(checkpoint) saved_id = await storage.save(checkpoint)
checkpoint_ids.append(saved_id) checkpoint_ids.append(saved_id)
# List conversation items - should include checkpoints # List conversation items - should include checkpoints
@@ -233,7 +240,7 @@ class TestCheckpointConversationManager:
for item in checkpoint_items: for item in checkpoint_items:
assert item.get("type") == "checkpoint" assert item.get("type") == "checkpoint"
assert item.get("checkpoint_id") in checkpoint_ids assert item.get("checkpoint_id") in checkpoint_ids
assert item.get("workflow_id") == test_workflow.id assert item.get("workflow_name") == test_workflow.name
assert "timestamp" in item assert "timestamp" in item
assert item.get("id").startswith("checkpoint_") # ID format: checkpoint_{checkpoint_id} assert item.get("id").startswith("checkpoint_") # ID format: checkpoint_{checkpoint_id}
@@ -255,21 +262,22 @@ class TestCheckpointConversationManager:
original_checkpoint = WorkflowCheckpoint( original_checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()), checkpoint_id=str(uuid.uuid4()),
workflow_id=test_workflow.id, workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={}, messages={},
state={"test_key": "test_value"}, state={"test_key": "test_value"},
) )
# Save to this session # Save to this session
storage = checkpoint_manager.get_checkpoint_storage(conversation_id) storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
await storage.save_checkpoint(original_checkpoint) await storage.save(original_checkpoint)
# Load checkpoint from this session # Load checkpoint from this session
loaded_checkpoint = await storage.load_checkpoint(original_checkpoint.checkpoint_id) loaded_checkpoint = await storage.load(original_checkpoint.checkpoint_id)
assert loaded_checkpoint is not None assert loaded_checkpoint is not None
assert loaded_checkpoint.checkpoint_id == original_checkpoint.checkpoint_id assert loaded_checkpoint.checkpoint_id == original_checkpoint.checkpoint_id
assert loaded_checkpoint.workflow_id == original_checkpoint.workflow_id assert loaded_checkpoint.workflow_name == original_checkpoint.workflow_name
assert loaded_checkpoint.state == {"test_key": "test_value"} assert loaded_checkpoint.state == {"test_key": "test_value"}
@@ -296,24 +304,28 @@ class TestCheckpointStorage:
from agent_framework._workflows._checkpoint import WorkflowCheckpoint from agent_framework._workflows._checkpoint import WorkflowCheckpoint
checkpoint = WorkflowCheckpoint( checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"test": "data"} checkpoint_id=str(uuid.uuid4()),
workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={},
state={"test": "data"},
) )
# Test save_checkpoint # Test save
checkpoint_id = await storage.save_checkpoint(checkpoint) checkpoint_id = await storage.save(checkpoint)
assert checkpoint_id == checkpoint.checkpoint_id assert checkpoint_id == checkpoint.checkpoint_id
# Test load_checkpoint # Test load
loaded = await storage.load_checkpoint(checkpoint_id) loaded = await storage.load(checkpoint_id)
assert loaded is not None assert loaded is not None
assert loaded.checkpoint_id == checkpoint_id assert loaded.checkpoint_id == checkpoint_id
# Test list_checkpoint_ids # Test list_checkpoint_ids
ids = await storage.list_checkpoint_ids(workflow_id=test_workflow.id) ids = await storage.list_checkpoint_ids(workflow_name=test_workflow.name)
assert checkpoint_id in ids assert checkpoint_id in ids
# Test list_checkpoints # Test list_checkpoints
checkpoints_list = await storage.list_checkpoints(workflow_id=test_workflow.id) checkpoints_list = await storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_list) >= 1 assert len(checkpoints_list) >= 1
assert any(cp.checkpoint_id == checkpoint_id for cp in checkpoints_list) assert any(cp.checkpoint_id == checkpoint_id for cp in checkpoints_list)
@@ -346,12 +358,16 @@ class TestIntegration:
from agent_framework._workflows._checkpoint import WorkflowCheckpoint from agent_framework._workflows._checkpoint import WorkflowCheckpoint
checkpoint = WorkflowCheckpoint( checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"injected": True} checkpoint_id=str(uuid.uuid4()),
workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={},
state={"injected": True},
) )
await checkpoint_storage.save_checkpoint(checkpoint) await checkpoint_storage.save(checkpoint)
# Verify checkpoint is accessible via storage (in this session) # Verify checkpoint is accessible via storage (in this session)
storage_checkpoints = await checkpoint_storage.list_checkpoints() storage_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(storage_checkpoints) > 0 assert len(storage_checkpoints) > 0
assert storage_checkpoints[0].checkpoint_id == checkpoint.checkpoint_id assert storage_checkpoints[0].checkpoint_id == checkpoint.checkpoint_id
@@ -377,20 +393,21 @@ class TestIntegration:
checkpoint = WorkflowCheckpoint( checkpoint = WorkflowCheckpoint(
checkpoint_id=str(uuid.uuid4()), checkpoint_id=str(uuid.uuid4()),
workflow_id=test_workflow.id, workflow_name=test_workflow.name,
graph_signature_hash=test_workflow.graph_signature_hash,
messages={}, messages={},
state={"ready_to_resume": True}, state={"ready_to_resume": True},
) )
checkpoint_id = await checkpoint_storage.save_checkpoint(checkpoint) checkpoint_id = await checkpoint_storage.save(checkpoint)
# Verify checkpoint can be loaded for resume # Verify checkpoint can be loaded for resume
loaded = await checkpoint_storage.load_checkpoint(checkpoint_id) loaded = await checkpoint_storage.load(checkpoint_id)
assert loaded is not None assert loaded is not None
assert loaded.checkpoint_id == checkpoint_id assert loaded.checkpoint_id == checkpoint_id
assert loaded.state == {"ready_to_resume": True} assert loaded.state == {"ready_to_resume": True}
# Verify checkpoint is accessible via storage (for UI to list checkpoints) # Verify checkpoint is accessible via storage (for UI to list checkpoints)
checkpoints = await checkpoint_storage.list_checkpoints() checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints) > 0 assert len(checkpoints) > 0
assert checkpoints[0].checkpoint_id == checkpoint_id assert checkpoints[0].checkpoint_id == checkpoint_id
@@ -420,7 +437,7 @@ class TestIntegration:
test_workflow._runner.context._checkpoint_storage = checkpoint_storage test_workflow._runner.context._checkpoint_storage = checkpoint_storage
# Verify no checkpoints initially # Verify no checkpoints initially
checkpoints_before = await checkpoint_storage.list_checkpoints() checkpoints_before = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_before) == 0 assert len(checkpoints_before) == 0
# Run workflow until it reaches IDLE_WITH_PENDING_REQUESTS (after checkpoint is created) # Run workflow until it reaches IDLE_WITH_PENDING_REQUESTS (after checkpoint is created)
@@ -435,9 +452,9 @@ class TestIntegration:
assert saw_request_event, "Test workflow should have emitted request_info event (type='request_info')" assert saw_request_event, "Test workflow should have emitted request_info event (type='request_info')"
# Verify checkpoint was AUTOMATICALLY saved to our storage by the framework # Verify checkpoint was AUTOMATICALLY saved to our storage by the framework
checkpoints_after = await checkpoint_storage.list_checkpoints() checkpoints_after = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
assert len(checkpoints_after) > 0, "Workflow should have auto-saved checkpoint at HIL pause" assert len(checkpoints_after) > 0, "Workflow should have auto-saved checkpoint at HIL pause"
# Verify checkpoint has correct workflow_id # Verify checkpoint has correct workflow identity
checkpoint = checkpoints_after[0] checkpoint = checkpoints_after[0]
assert checkpoint.workflow_id == test_workflow.id assert checkpoint.workflow_name == test_workflow.name
@@ -379,26 +379,27 @@ async def test_checkpoint_api_endpoints(test_entities_dir):
storage = executor.checkpoint_manager.get_checkpoint_storage(conv_id) storage = executor.checkpoint_manager.get_checkpoint_storage(conv_id)
checkpoint = WorkflowCheckpoint( checkpoint = WorkflowCheckpoint(
checkpoint_id="test_checkpoint_1", checkpoint_id="test_checkpoint_1",
workflow_id="test_workflow", workflow_name="test_workflow",
graph_signature_hash="test_graph_hash",
state={"key": "value"}, state={"key": "value"},
iteration_count=1, iteration_count=1,
) )
await storage.save_checkpoint(checkpoint) await storage.save(checkpoint)
# Test list checkpoints endpoint # Test list checkpoints endpoint
checkpoints = await storage.list_checkpoints() checkpoints = await storage.list_checkpoints(workflow_name="test_workflow")
assert len(checkpoints) == 1 assert len(checkpoints) == 1
assert checkpoints[0].checkpoint_id == "test_checkpoint_1" assert checkpoints[0].checkpoint_id == "test_checkpoint_1"
assert checkpoints[0].workflow_id == "test_workflow" assert checkpoints[0].workflow_name == "test_workflow"
# Test delete checkpoint endpoint # Test delete checkpoint endpoint
deleted = await storage.delete_checkpoint("test_checkpoint_1") deleted = await storage.delete("test_checkpoint_1")
assert deleted is True assert deleted is True
# Verify checkpoint was deleted # Verify checkpoint was deleted
remaining = await storage.list_checkpoints() remaining = await storage.list_checkpoints(workflow_name="test_workflow")
assert len(remaining) == 0 assert len(remaining) == 0
# Test delete non-existent checkpoint # Test delete non-existent checkpoint
deleted = await storage.delete_checkpoint("nonexistent") deleted = await storage.delete("nonexistent")
assert deleted is False assert deleted is False
@@ -189,19 +189,29 @@ class TestClientAgentExecutorPollingConfiguration:
# Verify get_entity was called 2 times (max_poll_retries) # Verify get_entity was called 2 times (max_poll_retries)
assert mock_client.get_entity.call_count == 2 assert mock_client.get_entity.call_count == 2
def test_executor_respects_custom_poll_interval(self, mock_client: Mock, sample_run_request: RunRequest) -> None: def test_executor_respects_custom_poll_interval(
self,
mock_client: Mock,
sample_run_request: RunRequest,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Verify executor respects custom poll_interval_seconds during polling.""" """Verify executor respects custom poll_interval_seconds during polling."""
# Create executor with very short interval # Create executor with very short interval
executor = ClientAgentExecutor(mock_client, max_poll_retries=3, poll_interval_seconds=0.01) executor = ClientAgentExecutor(mock_client, max_poll_retries=3, poll_interval_seconds=0.01)
# Measure time taken sleep_calls: list[float] = []
start = time.time()
result = executor.run_durable_agent("test_agent", sample_run_request)
elapsed = time.time() - start
# Should take roughly 3 * 0.01 = 0.03 seconds (plus overhead) def fake_sleep(seconds: float) -> None:
# Be generous with timing to avoid flakiness sleep_calls.append(seconds)
assert elapsed < 0.2 # Should be quick with 0.01 interval
# Use deterministic assertions instead of wall-clock timing to avoid CI flakiness.
monkeypatch.setattr("agent_framework_durabletask._executors.time.sleep", fake_sleep)
result = executor.run_durable_agent("test_agent", sample_run_request)
assert len(sleep_calls) == 3
assert sleep_calls == pytest.approx([0.01, 0.01, 0.01])
assert mock_client.get_entity.call_count == 3
assert isinstance(result, AgentResponse) assert isinstance(result, AgentResponse)
@@ -8,9 +8,11 @@ from typing import Any
from agent_framework import ( from agent_framework import (
Agent, Agent,
AgentResponseUpdate,
Content, Content,
FileCheckpointStorage, FileCheckpointStorage,
Workflow, Workflow,
WorkflowEvent,
tool, tool,
) )
from agent_framework.azure import AzureOpenAIResponsesClient from agent_framework.azure import AzureOpenAIResponsesClient
@@ -183,8 +185,16 @@ async def main() -> None:
initial_request = "Hi, my order 12345 arrived damaged. I need a refund." initial_request = "Hi, my order 12345 arrived damaged. I need a refund."
# Phase 1: Initial run - workflow will pause when it needs user input # Phase 1: Initial run - workflow will pause when it needs user input
results = await workflow.run(message=initial_request) print("Running initial workflow...")
request_events = results.get_request_info_events() results = await workflow.run(message=initial_request, stream=True)
# Iterate through streamed events and collect request_info events
request_events: list[WorkflowEvent] = []
async for event in results:
event: WorkflowEvent
if event.type == "request_info":
request_events.append(event)
if not request_events: if not request_events:
print("Workflow completed without needing user input") print("Workflow completed without needing user input")
return return
@@ -224,8 +234,17 @@ async def main() -> None:
raise RuntimeError("No checkpoints found.") raise RuntimeError("No checkpoints found.")
checkpoint_id = checkpoint.checkpoint_id checkpoint_id = checkpoint.checkpoint_id
results = await workflow.run(responses=responses, checkpoint_id=checkpoint_id) print("Resuming workflow from checkpoint...")
request_events = results.get_request_info_events() results = await workflow.run(responses=responses, checkpoint_id=checkpoint_id, stream=True)
# Iterate through streamed events and collect request_info events
request_events: list[WorkflowEvent] = []
async for event in results:
event: WorkflowEvent
if event.type == "request_info":
request_events.append(event)
elif event.type == "output" and isinstance(event.data, AgentResponseUpdate):
print(event.data.text, end="", flush=True)
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("DEMO COMPLETE") print("DEMO COMPLETE")
Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 MiB

@@ -1,186 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
"""
Handoff Workflow with Code Interpreter File Generation Sample
This sample demonstrates retrieving file IDs from code interpreter output
in a handoff workflow context. A triage agent routes to a code specialist
that generates a text file, and we verify the file_id is captured correctly
from the streaming workflow events.
Verifies GitHub issue #2718: files generated by code interpreter in
HandoffBuilder workflows can be properly retrieved.
Prerequisites:
- AZURE_AI_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint.
- `az login` (Azure CLI authentication)
- AZURE_AI_MODEL_DEPLOYMENT_NAME
"""
import asyncio
import os
from collections.abc import AsyncIterable
from typing import cast
from agent_framework import (
AgentResponseUpdate,
Message,
WorkflowEvent,
WorkflowRunState,
)
from agent_framework.azure import AzureOpenAIResponsesClient
from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder
from azure.identity import AzureCliCredential
async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]:
"""Collect all events from an async stream."""
return [event async for event in stream]
def _handle_events(events: list[WorkflowEvent]) -> tuple[list[WorkflowEvent[HandoffAgentUserRequest]], list[str]]:
"""Process workflow events and extract file IDs and pending requests.
Returns:
Tuple of (pending_requests, file_ids_found)
"""
requests: list[WorkflowEvent[HandoffAgentUserRequest]] = []
file_ids: list[str] = []
for event in events:
if event.type == "handoff_sent":
print(f"\n[Handoff from {event.data.source} to {event.data.target} initiated.]")
elif event.type == "status" and event.state in {
WorkflowRunState.IDLE,
WorkflowRunState.IDLE_WITH_PENDING_REQUESTS,
}:
print(f"[status] {event.state}")
elif event.type == "request_info" and isinstance(event.data, HandoffAgentUserRequest):
requests.append(cast(WorkflowEvent[HandoffAgentUserRequest], event))
elif event.type == "output":
data = event.data
if isinstance(data, AgentResponseUpdate):
for content in data.contents:
if content.type == "hosted_file":
file_ids.append(content.file_id) # type: ignore
print(f"[Found HostedFileContent: file_id={content.file_id}]")
elif content.type == "text" and content.annotations:
for annotation in content.annotations:
file_id = annotation["file_id"] # type: ignore
file_ids.append(file_id)
print(f"[Found file annotation: file_id={file_id}]")
elif isinstance(data, list):
conversation = cast(list[Message], data)
if isinstance(conversation, list):
print("\n=== Final Conversation Snapshot ===")
for message in conversation:
speaker = message.author_name or message.role
print(f"- {speaker}: {message.text or [content.type for content in message.contents]}")
print("===================================")
return requests, file_ids
async def main() -> None:
"""Run a simple handoff workflow with code interpreter file generation."""
print("=== Handoff Workflow with Code Interpreter File Generation ===\n")
client = AzureOpenAIResponsesClient(
project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"],
deployment_name=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
credential=AzureCliCredential(),
)
triage = client.as_agent(
name="triage_agent",
instructions=(
"You are a triage agent. Route code-related requests to the code_specialist. "
"When the user asks to create or generate files, hand off to code_specialist "
"by calling handoff_to_code_specialist."
),
)
code_interpreter_tool = client.get_code_interpreter_tool()
code_specialist = client.as_agent(
name="code_specialist",
instructions=(
"You are a Python code specialist. Use the code interpreter to execute Python code "
"and create files when requested. Always save files to /mnt/data/ directory."
),
tools=[code_interpreter_tool],
)
workflow = (
HandoffBuilder(
termination_condition=lambda conv: sum(1 for msg in conv if msg.role == "user") >= 2,
)
.participants([triage, code_specialist])
.with_start_agent(triage)
.build()
)
user_inputs = [
"Please create a text file called hello.txt with 'Hello from handoff workflow!' inside it.",
"exit",
]
input_index = 0
all_file_ids: list[str] = []
print(f"User: {user_inputs[0]}")
events = await _drain(workflow.run(user_inputs[0], stream=True))
requests, file_ids = _handle_events(events)
all_file_ids.extend(file_ids)
input_index += 1
while requests:
request = requests[0]
if input_index >= len(user_inputs):
break
user_input = user_inputs[input_index]
print(f"\nUser: {user_input}")
responses = {request.request_id: HandoffAgentUserRequest.create_response(user_input)}
events = await _drain(workflow.run(stream=True, responses=responses))
requests, file_ids = _handle_events(events)
all_file_ids.extend(file_ids)
input_index += 1
print("\n" + "=" * 50)
if all_file_ids:
print(f"SUCCESS: Found {len(all_file_ids)} file ID(s) in handoff workflow:")
for fid in all_file_ids:
print(f" - {fid}")
else:
print("WARNING: No file IDs captured from the handoff workflow.")
print("=" * 50)
"""
Sample Output:
User: Please create a text file called hello.txt with 'Hello from handoff workflow!' inside it.
[Found HostedFileContent: file_id=assistant-JT1sA...]
=== Conversation So Far ===
- user: Please create a text file called hello.txt with 'Hello from handoff workflow!' inside it.
- triage_agent: I am handing off your request to create the text file "hello.txt" with the specified content to the code specialist. They will assist you shortly.
- code_specialist: The file "hello.txt" has been created with the content "Hello from handoff workflow!". You can download it using the link below:
[hello.txt](sandbox:/mnt/data/hello.txt)
===========================
[status] IDLE_WITH_PENDING_REQUESTS
User: exit
[status] IDLE
==================================================
SUCCESS: Found 1 file ID(s) in handoff workflow:
- assistant-JT1sA...
==================================================
""" # noqa: E501
if __name__ == "__main__":
asyncio.run(main())