mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix streamed workflow agent continuation context by finalizing AgentExecutor streams (#3882)
* Fix streamed workflow agent continuation context by finalizing AgentExecutor streams * Fix stream handling * Fixes * Fix DevUI and tests
This commit is contained in:
committed by
GitHub
Unverified
parent
2203fa0f8b
commit
a276c1295a
@@ -911,19 +911,21 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
|||||||
if ctx is None:
|
if ctx is None:
|
||||||
return # No context available (shouldn't happen in normal flow)
|
return # No context available (shouldn't happen in normal flow)
|
||||||
|
|
||||||
|
# Update thread with conversation_id derived from streaming raw updates.
|
||||||
|
# Using response_id here can break function-call continuation for APIs
|
||||||
|
# where response IDs are not valid conversation handles.
|
||||||
|
conversation_id = self._extract_conversation_id_from_streaming_response(response)
|
||||||
# Ensure author names are set for all messages
|
# Ensure author names are set for all messages
|
||||||
for message in response.messages:
|
for message in response.messages:
|
||||||
if message.author_name is None:
|
if message.author_name is None:
|
||||||
message.author_name = ctx["agent_name"]
|
message.author_name = ctx["agent_name"]
|
||||||
|
|
||||||
# Propagate conversation_id back to session from streaming updates
|
# Propagate conversation_id back to session from streaming updates.
|
||||||
|
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
|
||||||
|
# so refresh when a newer value is returned.
|
||||||
sess = ctx["session"]
|
sess = ctx["session"]
|
||||||
if sess and not sess.service_session_id and response.raw_representation:
|
if sess and conversation_id and sess.service_session_id != conversation_id:
|
||||||
raw_items = response.raw_representation if isinstance(response.raw_representation, list) else []
|
sess.service_session_id = conversation_id
|
||||||
for item in raw_items:
|
|
||||||
if hasattr(item, "conversation_id") and item.conversation_id:
|
|
||||||
sess.service_session_id = item.conversation_id
|
|
||||||
break
|
|
||||||
|
|
||||||
# Run after_run providers (reverse order)
|
# Run after_run providers (reverse order)
|
||||||
session_context = ctx["session_context"]
|
session_context = ctx["session_context"]
|
||||||
@@ -974,6 +976,27 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
|||||||
output_format_type = response_format if isinstance(response_format, type) else None
|
output_format_type = response_format if isinstance(response_format, type) else None
|
||||||
return AgentResponse.from_updates(updates, output_format_type=output_format_type)
|
return AgentResponse.from_updates(updates, output_format_type=output_format_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_conversation_id_from_streaming_response(response: AgentResponse[Any]) -> str | None:
|
||||||
|
"""Extract conversation_id from streaming raw updates, if present."""
|
||||||
|
raw = response.raw_representation
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
raw_items: list[Any] = raw if isinstance(raw, list) else [raw]
|
||||||
|
for item in reversed(raw_items):
|
||||||
|
if isinstance(item, Mapping):
|
||||||
|
value = item.get("conversation_id")
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
return value
|
||||||
|
continue
|
||||||
|
|
||||||
|
value = getattr(item, "conversation_id", None)
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
return value
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def _prepare_run_context(
|
async def _prepare_run_context(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -1100,8 +1123,10 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
|||||||
if message.author_name is None:
|
if message.author_name is None:
|
||||||
message.author_name = agent_name
|
message.author_name = agent_name
|
||||||
|
|
||||||
# Propagate conversation_id back to session (e.g. thread ID from Assistants API)
|
# Propagate conversation_id back to session (e.g. thread ID from Assistants API).
|
||||||
if session and response.conversation_id and not session.service_session_id:
|
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
|
||||||
|
# so refresh when a newer value is returned.
|
||||||
|
if session and response.conversation_id and session.service_session_id != response.conversation_id:
|
||||||
session.service_session_id = response.conversation_id
|
session.service_session_id = response.conversation_id
|
||||||
|
|
||||||
# Set the response on the context for after_run providers
|
# Set the response on the context for after_run providers
|
||||||
|
|||||||
@@ -872,7 +872,16 @@ class MCPTool:
|
|||||||
k: v
|
k: v
|
||||||
for k, v in kwargs.items()
|
for k, v in kwargs.items()
|
||||||
if k
|
if k
|
||||||
not in {"chat_options", "tools", "tool_choice", "session", "thread", "conversation_id", "options", "response_format"}
|
not in {
|
||||||
|
"chat_options",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
"session",
|
||||||
|
"thread",
|
||||||
|
"conversation_id",
|
||||||
|
"options",
|
||||||
|
"response_format",
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
parser = self.parse_tool_results or _parse_tool_result_from_mcp
|
parser = self.parse_tool_results or _parse_tool_result_from_mcp
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Mapping
|
from collections.abc import Awaitable, Callable, Mapping
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
@@ -358,22 +358,31 @@ class AgentExecutor(Executor):
|
|||||||
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {})
|
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {})
|
||||||
|
|
||||||
updates: list[AgentResponseUpdate] = []
|
updates: list[AgentResponseUpdate] = []
|
||||||
user_input_requests: list[Content] = []
|
streamed_user_input_requests: list[Content] = []
|
||||||
async for update in self._agent.run(
|
stream = self._agent.run(
|
||||||
self._cache,
|
self._cache,
|
||||||
stream=True,
|
stream=True,
|
||||||
session=self._session,
|
session=self._session,
|
||||||
options=options,
|
options=options,
|
||||||
**run_kwargs,
|
**run_kwargs,
|
||||||
):
|
)
|
||||||
|
async for update in stream:
|
||||||
updates.append(update)
|
updates.append(update)
|
||||||
await ctx.yield_output(update)
|
await ctx.yield_output(update)
|
||||||
|
|
||||||
if update.user_input_requests:
|
if update.user_input_requests:
|
||||||
user_input_requests.extend(update.user_input_requests)
|
streamed_user_input_requests.extend(update.user_input_requests)
|
||||||
|
|
||||||
# Build the final AgentResponse from the collected updates
|
# Prefer stream finalization when available so result hooks run
|
||||||
if is_chat_agent(self._agent):
|
# (e.g., thread conversation updates). Fall back to reconstructing from updates
|
||||||
|
# for legacy/custom agents that return a plain async iterable.
|
||||||
|
# TODO(evmattso): Integrate workflow agent run handling around ResponseStream so
|
||||||
|
# AgentExecutor does not need this conditional stream-finalization branch.
|
||||||
|
maybe_get_final_response = getattr(stream, "get_final_response", None)
|
||||||
|
get_final_response = maybe_get_final_response if callable(maybe_get_final_response) else None
|
||||||
|
response: AgentResponse[Any]
|
||||||
|
if get_final_response is not None:
|
||||||
|
response = await cast(Callable[[], Awaitable[AgentResponse[Any]]], get_final_response)()
|
||||||
|
elif is_chat_agent(self._agent):
|
||||||
response_format = self._agent.default_options.get("response_format")
|
response_format = self._agent.default_options.get("response_format")
|
||||||
response = AgentResponse.from_updates(
|
response = AgentResponse.from_updates(
|
||||||
updates,
|
updates,
|
||||||
@@ -383,6 +392,16 @@ class AgentExecutor(Executor):
|
|||||||
response = AgentResponse.from_updates(updates)
|
response = AgentResponse.from_updates(updates)
|
||||||
|
|
||||||
# Handle any user input requests after the streaming completes
|
# Handle any user input requests after the streaming completes
|
||||||
|
user_input_requests: list[Content] = []
|
||||||
|
seen_request_ids: set[str] = set()
|
||||||
|
for user_input_request in [*streamed_user_input_requests, *response.user_input_requests]:
|
||||||
|
request_id = getattr(user_input_request, "id", None)
|
||||||
|
if isinstance(request_id, str) and request_id:
|
||||||
|
if request_id in seen_request_ids:
|
||||||
|
continue
|
||||||
|
seen_request_ids.add(request_id)
|
||||||
|
user_input_requests.append(user_input_request)
|
||||||
|
|
||||||
if user_input_requests:
|
if user_input_requests:
|
||||||
for user_input_request in user_input_requests:
|
for user_input_request in user_input_requests:
|
||||||
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
|
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from agent_framework import (
|
|||||||
BaseContextProvider,
|
BaseContextProvider,
|
||||||
ChatOptions,
|
ChatOptions,
|
||||||
ChatResponse,
|
ChatResponse,
|
||||||
|
ChatResponseUpdate,
|
||||||
Content,
|
Content,
|
||||||
FunctionTool,
|
FunctionTool,
|
||||||
Message,
|
Message,
|
||||||
@@ -154,6 +155,111 @@ async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChat
|
|||||||
assert session.service_session_id == "123"
|
assert session.service_session_id == "123"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_client_agent_updates_existing_session_id_non_streaming(
|
||||||
|
chat_client_base: SupportsChatGetResponse,
|
||||||
|
) -> None:
|
||||||
|
chat_client_base.run_responses = [
|
||||||
|
ChatResponse(
|
||||||
|
messages=[Message(role="assistant", contents=[Content.from_text("test response")])],
|
||||||
|
conversation_id="resp_new_123",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = Agent(client=chat_client_base)
|
||||||
|
session = agent.get_session(service_session_id="resp_old_123")
|
||||||
|
|
||||||
|
await agent.run("Hello", session=session)
|
||||||
|
assert session.service_session_id == "resp_new_123"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_client_agent_update_session_id_streaming_uses_conversation_id(
|
||||||
|
chat_client_base: SupportsChatGetResponse,
|
||||||
|
) -> None:
|
||||||
|
chat_client_base.streaming_responses = [
|
||||||
|
[
|
||||||
|
ChatResponseUpdate(
|
||||||
|
contents=[Content.from_text("stream part 1")],
|
||||||
|
role="assistant",
|
||||||
|
response_id="resp_stream_123",
|
||||||
|
conversation_id="conv_stream_456",
|
||||||
|
),
|
||||||
|
ChatResponseUpdate(
|
||||||
|
contents=[Content.from_text(" stream part 2")],
|
||||||
|
role="assistant",
|
||||||
|
response_id="resp_stream_123",
|
||||||
|
conversation_id="conv_stream_456",
|
||||||
|
finish_reason="stop",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = Agent(client=chat_client_base)
|
||||||
|
session = agent.create_session()
|
||||||
|
|
||||||
|
stream = agent.run("Hello", session=session, stream=True)
|
||||||
|
async for _ in stream:
|
||||||
|
pass
|
||||||
|
result = await stream.get_final_response()
|
||||||
|
assert result.text == "stream part 1 stream part 2"
|
||||||
|
assert session.service_session_id == "conv_stream_456"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_client_agent_updates_existing_session_id_streaming(
|
||||||
|
chat_client_base: SupportsChatGetResponse,
|
||||||
|
) -> None:
|
||||||
|
chat_client_base.streaming_responses = [
|
||||||
|
[
|
||||||
|
ChatResponseUpdate(
|
||||||
|
contents=[Content.from_text("stream part 1")],
|
||||||
|
role="assistant",
|
||||||
|
response_id="resp_stream_123",
|
||||||
|
conversation_id="resp_new_456",
|
||||||
|
),
|
||||||
|
ChatResponseUpdate(
|
||||||
|
contents=[Content.from_text(" stream part 2")],
|
||||||
|
role="assistant",
|
||||||
|
response_id="resp_stream_123",
|
||||||
|
conversation_id="resp_new_456",
|
||||||
|
finish_reason="stop",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = Agent(client=chat_client_base)
|
||||||
|
session = agent.get_session(service_session_id="resp_old_456")
|
||||||
|
|
||||||
|
stream = agent.run("Hello", session=session, stream=True)
|
||||||
|
async for _ in stream:
|
||||||
|
pass
|
||||||
|
await stream.get_final_response()
|
||||||
|
assert session.service_session_id == "resp_new_456"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_client_agent_update_session_id_streaming_does_not_use_response_id(
|
||||||
|
chat_client_base: SupportsChatGetResponse,
|
||||||
|
) -> None:
|
||||||
|
chat_client_base.streaming_responses = [
|
||||||
|
[
|
||||||
|
ChatResponseUpdate(
|
||||||
|
contents=[Content.from_text("stream response without conversation id")],
|
||||||
|
role="assistant",
|
||||||
|
response_id="resp_only_123",
|
||||||
|
finish_reason="stop",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = Agent(client=chat_client_base)
|
||||||
|
session = agent.create_session()
|
||||||
|
|
||||||
|
stream = agent.run("Hello", session=session, stream=True)
|
||||||
|
async for _ in stream:
|
||||||
|
pass
|
||||||
|
result = await stream.get_final_response()
|
||||||
|
assert result.text == "stream response without conversation id"
|
||||||
|
assert session.service_session_id is None
|
||||||
|
|
||||||
|
|
||||||
async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None:
|
async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None:
|
||||||
agent = Agent(client=client)
|
agent = Agent(client=client)
|
||||||
session = agent.create_session()
|
session = agent.create_session()
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
# Copyright (c) Microsoft. All rights reserved.
|
||||||
|
|||||||
@@ -50,6 +50,57 @@ class _CountingAgent(BaseAgent):
|
|||||||
return _run()
|
return _run()
|
||||||
|
|
||||||
|
|
||||||
|
class _StreamingHookAgent(BaseAgent):
|
||||||
|
"""Agent that exposes whether its streaming result hook was executed."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.result_hook_called = False
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
messages: str | Message | list[str] | list[Message] | None = None,
|
||||||
|
*,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||||
|
if stream:
|
||||||
|
|
||||||
|
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
|
||||||
|
yield AgentResponseUpdate(
|
||||||
|
contents=[Content.from_text(text="hook test")],
|
||||||
|
role="assistant",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _mark_result_hook_called(response: AgentResponse) -> AgentResponse:
|
||||||
|
self.result_hook_called = True
|
||||||
|
return response
|
||||||
|
|
||||||
|
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook(
|
||||||
|
_mark_result_hook_called
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _run() -> AgentResponse:
|
||||||
|
return AgentResponse(messages=[Message("assistant", ["hook test"])])
|
||||||
|
|
||||||
|
return _run()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None:
|
||||||
|
"""AgentExecutor should call get_final_response() so stream result hooks execute."""
|
||||||
|
agent = _StreamingHookAgent(id="hook_agent", name="HookAgent")
|
||||||
|
executor = AgentExecutor(agent, id="hook_exec")
|
||||||
|
workflow = SequentialBuilder(participants=[executor]).build()
|
||||||
|
|
||||||
|
output_events: list[Any] = []
|
||||||
|
async for event in workflow.run("run hook test", stream=True):
|
||||||
|
if event.type == "output":
|
||||||
|
output_events.append(event)
|
||||||
|
|
||||||
|
assert output_events
|
||||||
|
assert agent.result_hook_called
|
||||||
|
|
||||||
|
|
||||||
async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||||
"""Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly."""
|
"""Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly."""
|
||||||
storage = InMemoryCheckpointStorage()
|
storage = InMemoryCheckpointStorage()
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from agent_framework import (
|
|||||||
WorkflowRunState,
|
WorkflowRunState,
|
||||||
)
|
)
|
||||||
from agent_framework._workflows._checkpoint_encoding import (
|
from agent_framework._workflows._checkpoint_encoding import (
|
||||||
_PICKLE_MARKER,
|
_PICKLE_MARKER, # type: ignore
|
||||||
encode_checkpoint_value,
|
encode_checkpoint_value,
|
||||||
)
|
)
|
||||||
from agent_framework._workflows._events import WorkflowEvent
|
from agent_framework._workflows._events import WorkflowEvent
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Any, Literal, cast
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
from agent_framework import AgentSession, Message
|
from agent_framework import AgentSession, Message
|
||||||
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
|
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage, WorkflowCheckpoint
|
||||||
from openai.types.conversations import Conversation, ConversationDeletedResource
|
from openai.types.conversations import Conversation, ConversationDeletedResource
|
||||||
from openai.types.conversations.conversation_item import ConversationItem
|
from openai.types.conversations.conversation_item import ConversationItem
|
||||||
from openai.types.conversations.message import Message as OpenAIMessage
|
from openai.types.conversations.message import Message as OpenAIMessage
|
||||||
@@ -480,7 +480,7 @@ class InMemoryConversationStore(ConversationStore):
|
|||||||
checkpoint_storage = conv_data.get("checkpoint_storage")
|
checkpoint_storage = conv_data.get("checkpoint_storage")
|
||||||
if checkpoint_storage:
|
if checkpoint_storage:
|
||||||
# Get all checkpoints for this conversation
|
# Get all checkpoints for this conversation
|
||||||
checkpoints = await checkpoint_storage.list_checkpoints()
|
checkpoints = self._list_all_checkpoints(checkpoint_storage)
|
||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
# Create a conversation item for each checkpoint with summary metadata
|
# Create a conversation item for each checkpoint with summary metadata
|
||||||
# Full checkpoint state is NOT included here (too large for list view)
|
# Full checkpoint state is NOT included here (too large for list view)
|
||||||
@@ -495,7 +495,9 @@ class InMemoryConversationStore(ConversationStore):
|
|||||||
"id": f"checkpoint_{checkpoint.checkpoint_id}",
|
"id": f"checkpoint_{checkpoint.checkpoint_id}",
|
||||||
"type": "checkpoint",
|
"type": "checkpoint",
|
||||||
"checkpoint_id": checkpoint.checkpoint_id,
|
"checkpoint_id": checkpoint.checkpoint_id,
|
||||||
"workflow_id": checkpoint.workflow_id,
|
# Keep workflow_id for backward compatibility with existing UI payloads.
|
||||||
|
"workflow_id": checkpoint.workflow_name,
|
||||||
|
"workflow_name": checkpoint.workflow_name,
|
||||||
"timestamp": checkpoint.timestamp,
|
"timestamp": checkpoint.timestamp,
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@@ -506,6 +508,7 @@ class InMemoryConversationStore(ConversationStore):
|
|||||||
"message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
|
"message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
|
||||||
"size_bytes": checkpoint_size,
|
"size_bytes": checkpoint_size,
|
||||||
"version": checkpoint.version,
|
"version": checkpoint.version,
|
||||||
|
"graph_signature_hash": checkpoint.graph_signature_hash,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
items.append(cast(ConversationItem, checkpoint_item))
|
items.append(cast(ConversationItem, checkpoint_item))
|
||||||
@@ -551,8 +554,9 @@ class InMemoryConversationStore(ConversationStore):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Load full checkpoint from storage
|
# Load full checkpoint from storage
|
||||||
checkpoint = await checkpoint_storage.load_checkpoint(checkpoint_id)
|
try:
|
||||||
if not checkpoint:
|
checkpoint = await checkpoint_storage.load(checkpoint_id)
|
||||||
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Calculate size of checkpoint
|
# Calculate size of checkpoint
|
||||||
@@ -566,7 +570,9 @@ class InMemoryConversationStore(ConversationStore):
|
|||||||
"id": item_id,
|
"id": item_id,
|
||||||
"type": "checkpoint",
|
"type": "checkpoint",
|
||||||
"checkpoint_id": checkpoint.checkpoint_id,
|
"checkpoint_id": checkpoint.checkpoint_id,
|
||||||
"workflow_id": checkpoint.workflow_id,
|
# Keep workflow_id for backward compatibility with existing UI payloads.
|
||||||
|
"workflow_id": checkpoint.workflow_name,
|
||||||
|
"workflow_name": checkpoint.workflow_name,
|
||||||
"timestamp": checkpoint.timestamp,
|
"timestamp": checkpoint.timestamp,
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@@ -577,6 +583,7 @@ class InMemoryConversationStore(ConversationStore):
|
|||||||
"message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
|
"message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
|
||||||
"size_bytes": checkpoint_size,
|
"size_bytes": checkpoint_size,
|
||||||
"version": checkpoint.version,
|
"version": checkpoint.version,
|
||||||
|
"graph_signature_hash": checkpoint.graph_signature_hash,
|
||||||
# 🔥 FULL checkpoint state (lazy loaded)
|
# 🔥 FULL checkpoint state (lazy loaded)
|
||||||
"full_checkpoint": checkpoint.to_dict(),
|
"full_checkpoint": checkpoint.to_dict(),
|
||||||
},
|
},
|
||||||
@@ -631,8 +638,8 @@ class InMemoryConversationStore(ConversationStore):
|
|||||||
if conv_meta.get("type") == "workflow_session":
|
if conv_meta.get("type") == "workflow_session":
|
||||||
checkpoint_storage = conv_data.get("checkpoint_storage")
|
checkpoint_storage = conv_data.get("checkpoint_storage")
|
||||||
if checkpoint_storage:
|
if checkpoint_storage:
|
||||||
checkpoints = await checkpoint_storage.list_checkpoints()
|
checkpoints = self._list_all_checkpoints(checkpoint_storage)
|
||||||
latest = checkpoints[0] if checkpoints else None
|
latest = max(checkpoints, key=lambda cp: cp.timestamp) if checkpoints else None
|
||||||
conv_meta["checkpoint_summary"] = {
|
conv_meta["checkpoint_summary"] = {
|
||||||
"count": len(checkpoints),
|
"count": len(checkpoints),
|
||||||
"latest_iteration": latest.iteration_count if latest else 0,
|
"latest_iteration": latest.iteration_count if latest else 0,
|
||||||
@@ -654,6 +661,19 @@ class InMemoryConversationStore(ConversationStore):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _list_all_checkpoints(checkpoint_storage: Any) -> list[WorkflowCheckpoint]:
|
||||||
|
"""Return all checkpoints from a conversation-scoped storage instance.
|
||||||
|
|
||||||
|
DevUI uses one checkpoint storage per conversation. Core storage APIs now
|
||||||
|
require workflow_name filters, so we gather directly from in-memory storage
|
||||||
|
internals to provide conversation-wide listing for UI views.
|
||||||
|
"""
|
||||||
|
checkpoint_map = getattr(checkpoint_storage, "_checkpoints", None)
|
||||||
|
if isinstance(checkpoint_map, dict):
|
||||||
|
return list(cast(dict[str, WorkflowCheckpoint], checkpoint_map).values())
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class CheckpointConversationManager:
|
class CheckpointConversationManager:
|
||||||
"""Manages checkpoint storage for workflow sessions - SESSION-SCOPED.
|
"""Manages checkpoint storage for workflow sessions - SESSION-SCOPED.
|
||||||
|
|||||||
@@ -104,17 +104,21 @@ class TestCheckpointConversationManager:
|
|||||||
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
|
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
|
||||||
|
|
||||||
checkpoint = WorkflowCheckpoint(
|
checkpoint = WorkflowCheckpoint(
|
||||||
checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"test": "data"}
|
checkpoint_id=str(uuid.uuid4()),
|
||||||
|
workflow_name=test_workflow.name,
|
||||||
|
graph_signature_hash=test_workflow.graph_signature_hash,
|
||||||
|
messages={},
|
||||||
|
state={"test": "data"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get checkpoint storage for this conversation and save
|
# Get checkpoint storage for this conversation and save
|
||||||
storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
|
storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
|
||||||
checkpoint_id = await storage.save_checkpoint(checkpoint)
|
checkpoint_id = await storage.save(checkpoint)
|
||||||
|
|
||||||
assert checkpoint_id == checkpoint.checkpoint_id
|
assert checkpoint_id == checkpoint.checkpoint_id
|
||||||
|
|
||||||
# Verify checkpoint stored in THIS conversation only
|
# Verify checkpoint stored in THIS conversation only
|
||||||
checkpoints = await storage.list_checkpoints()
|
checkpoints = await storage.list_checkpoints(workflow_name=test_workflow.name)
|
||||||
assert len(checkpoints) == 1
|
assert len(checkpoints) == 1
|
||||||
assert checkpoints[0].checkpoint_id == checkpoint.checkpoint_id
|
assert checkpoints[0].checkpoint_id == checkpoint.checkpoint_id
|
||||||
|
|
||||||
@@ -140,20 +144,21 @@ class TestCheckpointConversationManager:
|
|||||||
|
|
||||||
checkpoint_a = WorkflowCheckpoint(
|
checkpoint_a = WorkflowCheckpoint(
|
||||||
checkpoint_id=str(uuid.uuid4()),
|
checkpoint_id=str(uuid.uuid4()),
|
||||||
workflow_id=test_workflow.id,
|
workflow_name=test_workflow.name,
|
||||||
|
graph_signature_hash=test_workflow.graph_signature_hash,
|
||||||
messages={},
|
messages={},
|
||||||
state={"conversation": "A"},
|
state={"conversation": "A"},
|
||||||
)
|
)
|
||||||
storage_a = checkpoint_manager.get_checkpoint_storage(conv_a)
|
storage_a = checkpoint_manager.get_checkpoint_storage(conv_a)
|
||||||
await storage_a.save_checkpoint(checkpoint_a)
|
await storage_a.save(checkpoint_a)
|
||||||
|
|
||||||
# Verify conversation A has checkpoint
|
# Verify conversation A has checkpoint
|
||||||
checkpoints_a = await storage_a.list_checkpoints()
|
checkpoints_a = await storage_a.list_checkpoints(workflow_name=test_workflow.name)
|
||||||
assert len(checkpoints_a) == 1
|
assert len(checkpoints_a) == 1
|
||||||
|
|
||||||
# Verify conversation B has NO checkpoints (isolation)
|
# Verify conversation B has NO checkpoints (isolation)
|
||||||
storage_b = checkpoint_manager.get_checkpoint_storage(conv_b)
|
storage_b = checkpoint_manager.get_checkpoint_storage(conv_b)
|
||||||
checkpoints_b = await storage_b.list_checkpoints()
|
checkpoints_b = await storage_b.list_checkpoints(workflow_name=test_workflow.name)
|
||||||
assert len(checkpoints_b) == 0
|
assert len(checkpoints_b) == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -177,15 +182,16 @@ class TestCheckpointConversationManager:
|
|||||||
for i in range(3):
|
for i in range(3):
|
||||||
checkpoint = WorkflowCheckpoint(
|
checkpoint = WorkflowCheckpoint(
|
||||||
checkpoint_id=str(uuid.uuid4()),
|
checkpoint_id=str(uuid.uuid4()),
|
||||||
workflow_id=test_workflow.id,
|
workflow_name=test_workflow.name,
|
||||||
|
graph_signature_hash=test_workflow.graph_signature_hash,
|
||||||
messages={},
|
messages={},
|
||||||
state={"iteration": i},
|
state={"iteration": i},
|
||||||
)
|
)
|
||||||
saved_id = await storage.save_checkpoint(checkpoint)
|
saved_id = await storage.save(checkpoint)
|
||||||
checkpoint_ids.append(saved_id)
|
checkpoint_ids.append(saved_id)
|
||||||
|
|
||||||
# List checkpoints using the storage
|
# List checkpoints using the storage
|
||||||
checkpoints_list = await storage.list_checkpoints()
|
checkpoints_list = await storage.list_checkpoints(workflow_name=test_workflow.name)
|
||||||
assert len(checkpoints_list) == 3
|
assert len(checkpoints_list) == 3
|
||||||
|
|
||||||
# Verify all checkpoint IDs are present
|
# Verify all checkpoint IDs are present
|
||||||
@@ -213,11 +219,12 @@ class TestCheckpointConversationManager:
|
|||||||
for i in range(2):
|
for i in range(2):
|
||||||
checkpoint = WorkflowCheckpoint(
|
checkpoint = WorkflowCheckpoint(
|
||||||
checkpoint_id=f"checkpoint_{i}",
|
checkpoint_id=f"checkpoint_{i}",
|
||||||
workflow_id=test_workflow.id,
|
workflow_name=test_workflow.name,
|
||||||
|
graph_signature_hash=test_workflow.graph_signature_hash,
|
||||||
messages={},
|
messages={},
|
||||||
state={"iteration": i},
|
state={"iteration": i},
|
||||||
)
|
)
|
||||||
saved_id = await storage.save_checkpoint(checkpoint)
|
saved_id = await storage.save(checkpoint)
|
||||||
checkpoint_ids.append(saved_id)
|
checkpoint_ids.append(saved_id)
|
||||||
|
|
||||||
# List conversation items - should include checkpoints
|
# List conversation items - should include checkpoints
|
||||||
@@ -233,7 +240,7 @@ class TestCheckpointConversationManager:
|
|||||||
for item in checkpoint_items:
|
for item in checkpoint_items:
|
||||||
assert item.get("type") == "checkpoint"
|
assert item.get("type") == "checkpoint"
|
||||||
assert item.get("checkpoint_id") in checkpoint_ids
|
assert item.get("checkpoint_id") in checkpoint_ids
|
||||||
assert item.get("workflow_id") == test_workflow.id
|
assert item.get("workflow_name") == test_workflow.name
|
||||||
assert "timestamp" in item
|
assert "timestamp" in item
|
||||||
assert item.get("id").startswith("checkpoint_") # ID format: checkpoint_{checkpoint_id}
|
assert item.get("id").startswith("checkpoint_") # ID format: checkpoint_{checkpoint_id}
|
||||||
|
|
||||||
@@ -255,21 +262,22 @@ class TestCheckpointConversationManager:
|
|||||||
|
|
||||||
original_checkpoint = WorkflowCheckpoint(
|
original_checkpoint = WorkflowCheckpoint(
|
||||||
checkpoint_id=str(uuid.uuid4()),
|
checkpoint_id=str(uuid.uuid4()),
|
||||||
workflow_id=test_workflow.id,
|
workflow_name=test_workflow.name,
|
||||||
|
graph_signature_hash=test_workflow.graph_signature_hash,
|
||||||
messages={},
|
messages={},
|
||||||
state={"test_key": "test_value"},
|
state={"test_key": "test_value"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save to this session
|
# Save to this session
|
||||||
storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
|
storage = checkpoint_manager.get_checkpoint_storage(conversation_id)
|
||||||
await storage.save_checkpoint(original_checkpoint)
|
await storage.save(original_checkpoint)
|
||||||
|
|
||||||
# Load checkpoint from this session
|
# Load checkpoint from this session
|
||||||
loaded_checkpoint = await storage.load_checkpoint(original_checkpoint.checkpoint_id)
|
loaded_checkpoint = await storage.load(original_checkpoint.checkpoint_id)
|
||||||
|
|
||||||
assert loaded_checkpoint is not None
|
assert loaded_checkpoint is not None
|
||||||
assert loaded_checkpoint.checkpoint_id == original_checkpoint.checkpoint_id
|
assert loaded_checkpoint.checkpoint_id == original_checkpoint.checkpoint_id
|
||||||
assert loaded_checkpoint.workflow_id == original_checkpoint.workflow_id
|
assert loaded_checkpoint.workflow_name == original_checkpoint.workflow_name
|
||||||
assert loaded_checkpoint.state == {"test_key": "test_value"}
|
assert loaded_checkpoint.state == {"test_key": "test_value"}
|
||||||
|
|
||||||
|
|
||||||
@@ -296,24 +304,28 @@ class TestCheckpointStorage:
|
|||||||
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
|
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
|
||||||
|
|
||||||
checkpoint = WorkflowCheckpoint(
|
checkpoint = WorkflowCheckpoint(
|
||||||
checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"test": "data"}
|
checkpoint_id=str(uuid.uuid4()),
|
||||||
|
workflow_name=test_workflow.name,
|
||||||
|
graph_signature_hash=test_workflow.graph_signature_hash,
|
||||||
|
messages={},
|
||||||
|
state={"test": "data"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test save_checkpoint
|
# Test save
|
||||||
checkpoint_id = await storage.save_checkpoint(checkpoint)
|
checkpoint_id = await storage.save(checkpoint)
|
||||||
assert checkpoint_id == checkpoint.checkpoint_id
|
assert checkpoint_id == checkpoint.checkpoint_id
|
||||||
|
|
||||||
# Test load_checkpoint
|
# Test load
|
||||||
loaded = await storage.load_checkpoint(checkpoint_id)
|
loaded = await storage.load(checkpoint_id)
|
||||||
assert loaded is not None
|
assert loaded is not None
|
||||||
assert loaded.checkpoint_id == checkpoint_id
|
assert loaded.checkpoint_id == checkpoint_id
|
||||||
|
|
||||||
# Test list_checkpoint_ids
|
# Test list_checkpoint_ids
|
||||||
ids = await storage.list_checkpoint_ids(workflow_id=test_workflow.id)
|
ids = await storage.list_checkpoint_ids(workflow_name=test_workflow.name)
|
||||||
assert checkpoint_id in ids
|
assert checkpoint_id in ids
|
||||||
|
|
||||||
# Test list_checkpoints
|
# Test list_checkpoints
|
||||||
checkpoints_list = await storage.list_checkpoints(workflow_id=test_workflow.id)
|
checkpoints_list = await storage.list_checkpoints(workflow_name=test_workflow.name)
|
||||||
assert len(checkpoints_list) >= 1
|
assert len(checkpoints_list) >= 1
|
||||||
assert any(cp.checkpoint_id == checkpoint_id for cp in checkpoints_list)
|
assert any(cp.checkpoint_id == checkpoint_id for cp in checkpoints_list)
|
||||||
|
|
||||||
@@ -346,12 +358,16 @@ class TestIntegration:
|
|||||||
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
|
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
|
||||||
|
|
||||||
checkpoint = WorkflowCheckpoint(
|
checkpoint = WorkflowCheckpoint(
|
||||||
checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"injected": True}
|
checkpoint_id=str(uuid.uuid4()),
|
||||||
|
workflow_name=test_workflow.name,
|
||||||
|
graph_signature_hash=test_workflow.graph_signature_hash,
|
||||||
|
messages={},
|
||||||
|
state={"injected": True},
|
||||||
)
|
)
|
||||||
await checkpoint_storage.save_checkpoint(checkpoint)
|
await checkpoint_storage.save(checkpoint)
|
||||||
|
|
||||||
# Verify checkpoint is accessible via storage (in this session)
|
# Verify checkpoint is accessible via storage (in this session)
|
||||||
storage_checkpoints = await checkpoint_storage.list_checkpoints()
|
storage_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
|
||||||
assert len(storage_checkpoints) > 0
|
assert len(storage_checkpoints) > 0
|
||||||
assert storage_checkpoints[0].checkpoint_id == checkpoint.checkpoint_id
|
assert storage_checkpoints[0].checkpoint_id == checkpoint.checkpoint_id
|
||||||
|
|
||||||
@@ -377,20 +393,21 @@ class TestIntegration:
|
|||||||
|
|
||||||
checkpoint = WorkflowCheckpoint(
|
checkpoint = WorkflowCheckpoint(
|
||||||
checkpoint_id=str(uuid.uuid4()),
|
checkpoint_id=str(uuid.uuid4()),
|
||||||
workflow_id=test_workflow.id,
|
workflow_name=test_workflow.name,
|
||||||
|
graph_signature_hash=test_workflow.graph_signature_hash,
|
||||||
messages={},
|
messages={},
|
||||||
state={"ready_to_resume": True},
|
state={"ready_to_resume": True},
|
||||||
)
|
)
|
||||||
checkpoint_id = await checkpoint_storage.save_checkpoint(checkpoint)
|
checkpoint_id = await checkpoint_storage.save(checkpoint)
|
||||||
|
|
||||||
# Verify checkpoint can be loaded for resume
|
# Verify checkpoint can be loaded for resume
|
||||||
loaded = await checkpoint_storage.load_checkpoint(checkpoint_id)
|
loaded = await checkpoint_storage.load(checkpoint_id)
|
||||||
assert loaded is not None
|
assert loaded is not None
|
||||||
assert loaded.checkpoint_id == checkpoint_id
|
assert loaded.checkpoint_id == checkpoint_id
|
||||||
assert loaded.state == {"ready_to_resume": True}
|
assert loaded.state == {"ready_to_resume": True}
|
||||||
|
|
||||||
# Verify checkpoint is accessible via storage (for UI to list checkpoints)
|
# Verify checkpoint is accessible via storage (for UI to list checkpoints)
|
||||||
checkpoints = await checkpoint_storage.list_checkpoints()
|
checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
|
||||||
assert len(checkpoints) > 0
|
assert len(checkpoints) > 0
|
||||||
assert checkpoints[0].checkpoint_id == checkpoint_id
|
assert checkpoints[0].checkpoint_id == checkpoint_id
|
||||||
|
|
||||||
@@ -420,7 +437,7 @@ class TestIntegration:
|
|||||||
test_workflow._runner.context._checkpoint_storage = checkpoint_storage
|
test_workflow._runner.context._checkpoint_storage = checkpoint_storage
|
||||||
|
|
||||||
# Verify no checkpoints initially
|
# Verify no checkpoints initially
|
||||||
checkpoints_before = await checkpoint_storage.list_checkpoints()
|
checkpoints_before = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
|
||||||
assert len(checkpoints_before) == 0
|
assert len(checkpoints_before) == 0
|
||||||
|
|
||||||
# Run workflow until it reaches IDLE_WITH_PENDING_REQUESTS (after checkpoint is created)
|
# Run workflow until it reaches IDLE_WITH_PENDING_REQUESTS (after checkpoint is created)
|
||||||
@@ -435,9 +452,9 @@ class TestIntegration:
|
|||||||
assert saw_request_event, "Test workflow should have emitted request_info event (type='request_info')"
|
assert saw_request_event, "Test workflow should have emitted request_info event (type='request_info')"
|
||||||
|
|
||||||
# Verify checkpoint was AUTOMATICALLY saved to our storage by the framework
|
# Verify checkpoint was AUTOMATICALLY saved to our storage by the framework
|
||||||
checkpoints_after = await checkpoint_storage.list_checkpoints()
|
checkpoints_after = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name)
|
||||||
assert len(checkpoints_after) > 0, "Workflow should have auto-saved checkpoint at HIL pause"
|
assert len(checkpoints_after) > 0, "Workflow should have auto-saved checkpoint at HIL pause"
|
||||||
|
|
||||||
# Verify checkpoint has correct workflow_id
|
# Verify checkpoint has correct workflow identity
|
||||||
checkpoint = checkpoints_after[0]
|
checkpoint = checkpoints_after[0]
|
||||||
assert checkpoint.workflow_id == test_workflow.id
|
assert checkpoint.workflow_name == test_workflow.name
|
||||||
|
|||||||
@@ -379,26 +379,27 @@ async def test_checkpoint_api_endpoints(test_entities_dir):
|
|||||||
storage = executor.checkpoint_manager.get_checkpoint_storage(conv_id)
|
storage = executor.checkpoint_manager.get_checkpoint_storage(conv_id)
|
||||||
checkpoint = WorkflowCheckpoint(
|
checkpoint = WorkflowCheckpoint(
|
||||||
checkpoint_id="test_checkpoint_1",
|
checkpoint_id="test_checkpoint_1",
|
||||||
workflow_id="test_workflow",
|
workflow_name="test_workflow",
|
||||||
|
graph_signature_hash="test_graph_hash",
|
||||||
state={"key": "value"},
|
state={"key": "value"},
|
||||||
iteration_count=1,
|
iteration_count=1,
|
||||||
)
|
)
|
||||||
await storage.save_checkpoint(checkpoint)
|
await storage.save(checkpoint)
|
||||||
|
|
||||||
# Test list checkpoints endpoint
|
# Test list checkpoints endpoint
|
||||||
checkpoints = await storage.list_checkpoints()
|
checkpoints = await storage.list_checkpoints(workflow_name="test_workflow")
|
||||||
assert len(checkpoints) == 1
|
assert len(checkpoints) == 1
|
||||||
assert checkpoints[0].checkpoint_id == "test_checkpoint_1"
|
assert checkpoints[0].checkpoint_id == "test_checkpoint_1"
|
||||||
assert checkpoints[0].workflow_id == "test_workflow"
|
assert checkpoints[0].workflow_name == "test_workflow"
|
||||||
|
|
||||||
# Test delete checkpoint endpoint
|
# Test delete checkpoint endpoint
|
||||||
deleted = await storage.delete_checkpoint("test_checkpoint_1")
|
deleted = await storage.delete("test_checkpoint_1")
|
||||||
assert deleted is True
|
assert deleted is True
|
||||||
|
|
||||||
# Verify checkpoint was deleted
|
# Verify checkpoint was deleted
|
||||||
remaining = await storage.list_checkpoints()
|
remaining = await storage.list_checkpoints(workflow_name="test_workflow")
|
||||||
assert len(remaining) == 0
|
assert len(remaining) == 0
|
||||||
|
|
||||||
# Test delete non-existent checkpoint
|
# Test delete non-existent checkpoint
|
||||||
deleted = await storage.delete_checkpoint("nonexistent")
|
deleted = await storage.delete("nonexistent")
|
||||||
assert deleted is False
|
assert deleted is False
|
||||||
|
|||||||
@@ -189,19 +189,29 @@ class TestClientAgentExecutorPollingConfiguration:
|
|||||||
# Verify get_entity was called 2 times (max_poll_retries)
|
# Verify get_entity was called 2 times (max_poll_retries)
|
||||||
assert mock_client.get_entity.call_count == 2
|
assert mock_client.get_entity.call_count == 2
|
||||||
|
|
||||||
def test_executor_respects_custom_poll_interval(self, mock_client: Mock, sample_run_request: RunRequest) -> None:
|
def test_executor_respects_custom_poll_interval(
|
||||||
|
self,
|
||||||
|
mock_client: Mock,
|
||||||
|
sample_run_request: RunRequest,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
"""Verify executor respects custom poll_interval_seconds during polling."""
|
"""Verify executor respects custom poll_interval_seconds during polling."""
|
||||||
# Create executor with very short interval
|
# Create executor with very short interval
|
||||||
executor = ClientAgentExecutor(mock_client, max_poll_retries=3, poll_interval_seconds=0.01)
|
executor = ClientAgentExecutor(mock_client, max_poll_retries=3, poll_interval_seconds=0.01)
|
||||||
|
|
||||||
# Measure time taken
|
sleep_calls: list[float] = []
|
||||||
start = time.time()
|
|
||||||
result = executor.run_durable_agent("test_agent", sample_run_request)
|
|
||||||
elapsed = time.time() - start
|
|
||||||
|
|
||||||
# Should take roughly 3 * 0.01 = 0.03 seconds (plus overhead)
|
def fake_sleep(seconds: float) -> None:
|
||||||
# Be generous with timing to avoid flakiness
|
sleep_calls.append(seconds)
|
||||||
assert elapsed < 0.2 # Should be quick with 0.01 interval
|
|
||||||
|
# Use deterministic assertions instead of wall-clock timing to avoid CI flakiness.
|
||||||
|
monkeypatch.setattr("agent_framework_durabletask._executors.time.sleep", fake_sleep)
|
||||||
|
|
||||||
|
result = executor.run_durable_agent("test_agent", sample_run_request)
|
||||||
|
|
||||||
|
assert len(sleep_calls) == 3
|
||||||
|
assert sleep_calls == pytest.approx([0.01, 0.01, 0.01])
|
||||||
|
assert mock_client.get_entity.call_count == 3
|
||||||
assert isinstance(result, AgentResponse)
|
assert isinstance(result, AgentResponse)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+23
-4
@@ -8,9 +8,11 @@ from typing import Any
|
|||||||
|
|
||||||
from agent_framework import (
|
from agent_framework import (
|
||||||
Agent,
|
Agent,
|
||||||
|
AgentResponseUpdate,
|
||||||
Content,
|
Content,
|
||||||
FileCheckpointStorage,
|
FileCheckpointStorage,
|
||||||
Workflow,
|
Workflow,
|
||||||
|
WorkflowEvent,
|
||||||
tool,
|
tool,
|
||||||
)
|
)
|
||||||
from agent_framework.azure import AzureOpenAIResponsesClient
|
from agent_framework.azure import AzureOpenAIResponsesClient
|
||||||
@@ -183,8 +185,16 @@ async def main() -> None:
|
|||||||
initial_request = "Hi, my order 12345 arrived damaged. I need a refund."
|
initial_request = "Hi, my order 12345 arrived damaged. I need a refund."
|
||||||
|
|
||||||
# Phase 1: Initial run - workflow will pause when it needs user input
|
# Phase 1: Initial run - workflow will pause when it needs user input
|
||||||
results = await workflow.run(message=initial_request)
|
print("Running initial workflow...")
|
||||||
request_events = results.get_request_info_events()
|
results = await workflow.run(message=initial_request, stream=True)
|
||||||
|
|
||||||
|
# Iterate through streamed events and collect request_info events
|
||||||
|
request_events: list[WorkflowEvent] = []
|
||||||
|
async for event in results:
|
||||||
|
event: WorkflowEvent
|
||||||
|
if event.type == "request_info":
|
||||||
|
request_events.append(event)
|
||||||
|
|
||||||
if not request_events:
|
if not request_events:
|
||||||
print("Workflow completed without needing user input")
|
print("Workflow completed without needing user input")
|
||||||
return
|
return
|
||||||
@@ -224,8 +234,17 @@ async def main() -> None:
|
|||||||
raise RuntimeError("No checkpoints found.")
|
raise RuntimeError("No checkpoints found.")
|
||||||
checkpoint_id = checkpoint.checkpoint_id
|
checkpoint_id = checkpoint.checkpoint_id
|
||||||
|
|
||||||
results = await workflow.run(responses=responses, checkpoint_id=checkpoint_id)
|
print("Resuming workflow from checkpoint...")
|
||||||
request_events = results.get_request_info_events()
|
results = await workflow.run(responses=responses, checkpoint_id=checkpoint_id, stream=True)
|
||||||
|
|
||||||
|
# Iterate through streamed events and collect request_info events
|
||||||
|
request_events: list[WorkflowEvent] = []
|
||||||
|
async for event in results:
|
||||||
|
event: WorkflowEvent
|
||||||
|
if event.type == "request_info":
|
||||||
|
request_events.append(event)
|
||||||
|
elif event.type == "output" and isinstance(event.data, AgentResponseUpdate):
|
||||||
|
print(event.data.text, end="", flush=True)
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("DEMO COMPLETE")
|
print("DEMO COMPLETE")
|
||||||
|
|||||||
Binary file not shown.
|
After Width: | Height: | Size: 4.5 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 4.5 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 4.5 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 4.5 MiB |
-186
@@ -1,186 +0,0 @@
|
|||||||
# Copyright (c) Microsoft. All rights reserved.
|
|
||||||
|
|
||||||
"""
|
|
||||||
Handoff Workflow with Code Interpreter File Generation Sample
|
|
||||||
|
|
||||||
This sample demonstrates retrieving file IDs from code interpreter output
|
|
||||||
in a handoff workflow context. A triage agent routes to a code specialist
|
|
||||||
that generates a text file, and we verify the file_id is captured correctly
|
|
||||||
from the streaming workflow events.
|
|
||||||
|
|
||||||
Verifies GitHub issue #2718: files generated by code interpreter in
|
|
||||||
HandoffBuilder workflows can be properly retrieved.
|
|
||||||
|
|
||||||
Prerequisites:
|
|
||||||
- AZURE_AI_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint.
|
|
||||||
- `az login` (Azure CLI authentication)
|
|
||||||
- AZURE_AI_MODEL_DEPLOYMENT_NAME
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
from collections.abc import AsyncIterable
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from agent_framework import (
|
|
||||||
AgentResponseUpdate,
|
|
||||||
Message,
|
|
||||||
WorkflowEvent,
|
|
||||||
WorkflowRunState,
|
|
||||||
)
|
|
||||||
from agent_framework.azure import AzureOpenAIResponsesClient
|
|
||||||
from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder
|
|
||||||
from azure.identity import AzureCliCredential
|
|
||||||
|
|
||||||
|
|
||||||
async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]:
|
|
||||||
"""Collect all events from an async stream."""
|
|
||||||
return [event async for event in stream]
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_events(events: list[WorkflowEvent]) -> tuple[list[WorkflowEvent[HandoffAgentUserRequest]], list[str]]:
|
|
||||||
"""Process workflow events and extract file IDs and pending requests.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (pending_requests, file_ids_found)
|
|
||||||
"""
|
|
||||||
|
|
||||||
requests: list[WorkflowEvent[HandoffAgentUserRequest]] = []
|
|
||||||
file_ids: list[str] = []
|
|
||||||
|
|
||||||
for event in events:
|
|
||||||
if event.type == "handoff_sent":
|
|
||||||
print(f"\n[Handoff from {event.data.source} to {event.data.target} initiated.]")
|
|
||||||
elif event.type == "status" and event.state in {
|
|
||||||
WorkflowRunState.IDLE,
|
|
||||||
WorkflowRunState.IDLE_WITH_PENDING_REQUESTS,
|
|
||||||
}:
|
|
||||||
print(f"[status] {event.state}")
|
|
||||||
elif event.type == "request_info" and isinstance(event.data, HandoffAgentUserRequest):
|
|
||||||
requests.append(cast(WorkflowEvent[HandoffAgentUserRequest], event))
|
|
||||||
elif event.type == "output":
|
|
||||||
data = event.data
|
|
||||||
if isinstance(data, AgentResponseUpdate):
|
|
||||||
for content in data.contents:
|
|
||||||
if content.type == "hosted_file":
|
|
||||||
file_ids.append(content.file_id) # type: ignore
|
|
||||||
print(f"[Found HostedFileContent: file_id={content.file_id}]")
|
|
||||||
elif content.type == "text" and content.annotations:
|
|
||||||
for annotation in content.annotations:
|
|
||||||
file_id = annotation["file_id"] # type: ignore
|
|
||||||
file_ids.append(file_id)
|
|
||||||
print(f"[Found file annotation: file_id={file_id}]")
|
|
||||||
elif isinstance(data, list):
|
|
||||||
conversation = cast(list[Message], data)
|
|
||||||
if isinstance(conversation, list):
|
|
||||||
print("\n=== Final Conversation Snapshot ===")
|
|
||||||
for message in conversation:
|
|
||||||
speaker = message.author_name or message.role
|
|
||||||
print(f"- {speaker}: {message.text or [content.type for content in message.contents]}")
|
|
||||||
print("===================================")
|
|
||||||
|
|
||||||
return requests, file_ids
|
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
|
||||||
"""Run a simple handoff workflow with code interpreter file generation."""
|
|
||||||
print("=== Handoff Workflow with Code Interpreter File Generation ===\n")
|
|
||||||
|
|
||||||
client = AzureOpenAIResponsesClient(
|
|
||||||
project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"],
|
|
||||||
deployment_name=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
|
||||||
credential=AzureCliCredential(),
|
|
||||||
)
|
|
||||||
|
|
||||||
triage = client.as_agent(
|
|
||||||
name="triage_agent",
|
|
||||||
instructions=(
|
|
||||||
"You are a triage agent. Route code-related requests to the code_specialist. "
|
|
||||||
"When the user asks to create or generate files, hand off to code_specialist "
|
|
||||||
"by calling handoff_to_code_specialist."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
code_interpreter_tool = client.get_code_interpreter_tool()
|
|
||||||
|
|
||||||
code_specialist = client.as_agent(
|
|
||||||
name="code_specialist",
|
|
||||||
instructions=(
|
|
||||||
"You are a Python code specialist. Use the code interpreter to execute Python code "
|
|
||||||
"and create files when requested. Always save files to /mnt/data/ directory."
|
|
||||||
),
|
|
||||||
tools=[code_interpreter_tool],
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow = (
|
|
||||||
HandoffBuilder(
|
|
||||||
termination_condition=lambda conv: sum(1 for msg in conv if msg.role == "user") >= 2,
|
|
||||||
)
|
|
||||||
.participants([triage, code_specialist])
|
|
||||||
.with_start_agent(triage)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
|
|
||||||
user_inputs = [
|
|
||||||
"Please create a text file called hello.txt with 'Hello from handoff workflow!' inside it.",
|
|
||||||
"exit",
|
|
||||||
]
|
|
||||||
input_index = 0
|
|
||||||
all_file_ids: list[str] = []
|
|
||||||
|
|
||||||
print(f"User: {user_inputs[0]}")
|
|
||||||
events = await _drain(workflow.run(user_inputs[0], stream=True))
|
|
||||||
requests, file_ids = _handle_events(events)
|
|
||||||
all_file_ids.extend(file_ids)
|
|
||||||
input_index += 1
|
|
||||||
|
|
||||||
while requests:
|
|
||||||
request = requests[0]
|
|
||||||
if input_index >= len(user_inputs):
|
|
||||||
break
|
|
||||||
user_input = user_inputs[input_index]
|
|
||||||
print(f"\nUser: {user_input}")
|
|
||||||
|
|
||||||
responses = {request.request_id: HandoffAgentUserRequest.create_response(user_input)}
|
|
||||||
events = await _drain(workflow.run(stream=True, responses=responses))
|
|
||||||
requests, file_ids = _handle_events(events)
|
|
||||||
all_file_ids.extend(file_ids)
|
|
||||||
input_index += 1
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
if all_file_ids:
|
|
||||||
print(f"SUCCESS: Found {len(all_file_ids)} file ID(s) in handoff workflow:")
|
|
||||||
for fid in all_file_ids:
|
|
||||||
print(f" - {fid}")
|
|
||||||
else:
|
|
||||||
print("WARNING: No file IDs captured from the handoff workflow.")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
"""
|
|
||||||
Sample Output:
|
|
||||||
|
|
||||||
User: Please create a text file called hello.txt with 'Hello from handoff workflow!' inside it.
|
|
||||||
[Found HostedFileContent: file_id=assistant-JT1sA...]
|
|
||||||
|
|
||||||
=== Conversation So Far ===
|
|
||||||
- user: Please create a text file called hello.txt with 'Hello from handoff workflow!' inside it.
|
|
||||||
- triage_agent: I am handing off your request to create the text file "hello.txt" with the specified content to the code specialist. They will assist you shortly.
|
|
||||||
- code_specialist: The file "hello.txt" has been created with the content "Hello from handoff workflow!". You can download it using the link below:
|
|
||||||
|
|
||||||
[hello.txt](sandbox:/mnt/data/hello.txt)
|
|
||||||
===========================
|
|
||||||
|
|
||||||
[status] IDLE_WITH_PENDING_REQUESTS
|
|
||||||
|
|
||||||
User: exit
|
|
||||||
[status] IDLE
|
|
||||||
|
|
||||||
==================================================
|
|
||||||
SUCCESS: Found 1 file ID(s) in handoff workflow:
|
|
||||||
- assistant-JT1sA...
|
|
||||||
==================================================
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
Reference in New Issue
Block a user