Python: Fix streamed workflow agent continuation context by finalizing AgentExecutor streams (#3882)

* Fix streamed workflow agent continuation context by finalizing AgentExecutor streams

* Fix stream handling

* Fixes

* Fix DevUI and tests
This commit is contained in:
Evan Mattson
2026-02-13 07:45:46 +09:00
committed by GitHub
Unverified
parent 2203fa0f8b
commit a276c1295a
17 changed files with 359 additions and 267 deletions
@@ -911,19 +911,21 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
if ctx is None:
return # No context available (shouldn't happen in normal flow)
# Update thread with conversation_id derived from streaming raw updates.
# Using response_id here can break function-call continuation for APIs
# where response IDs are not valid conversation handles.
conversation_id = self._extract_conversation_id_from_streaming_response(response)
# Ensure author names are set for all messages
for message in response.messages:
if message.author_name is None:
message.author_name = ctx["agent_name"]
# Propagate conversation_id back to session from streaming updates
# Propagate conversation_id back to session from streaming updates.
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
# so refresh when a newer value is returned.
sess = ctx["session"]
if sess and not sess.service_session_id and response.raw_representation:
raw_items = response.raw_representation if isinstance(response.raw_representation, list) else []
for item in raw_items:
if hasattr(item, "conversation_id") and item.conversation_id:
sess.service_session_id = item.conversation_id
break
if sess and conversation_id and sess.service_session_id != conversation_id:
sess.service_session_id = conversation_id
# Run after_run providers (reverse order)
session_context = ctx["session_context"]
@@ -974,6 +976,27 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
output_format_type = response_format if isinstance(response_format, type) else None
return AgentResponse.from_updates(updates, output_format_type=output_format_type)
@staticmethod
def _extract_conversation_id_from_streaming_response(response: AgentResponse[Any]) -> str | None:
    """Return the latest conversation id carried by the raw streaming updates.

    ``response.raw_representation`` is treated as a list of raw items; any
    non-list value is handled as a single item. Items are scanned from last to
    first. Mappings are read via the ``"conversation_id"`` key, anything else
    via a ``conversation_id`` attribute. The first non-empty string wins.

    Returns:
        The conversation id, or ``None`` when there is no raw representation
        or no item carries a non-empty string id.
    """
    raw = response.raw_representation
    if raw is None:
        return None

    candidates: list[Any] = raw if isinstance(raw, list) else [raw]
    for entry in reversed(candidates):
        if isinstance(entry, Mapping):
            found = entry.get("conversation_id")
        else:
            found = getattr(entry, "conversation_id", None)
        # Non-string or empty values are ignored so an earlier item can still match.
        if isinstance(found, str) and found:
            return found
    return None
async def _prepare_run_context(
self,
*,
@@ -1100,8 +1123,10 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
if message.author_name is None:
message.author_name = agent_name
# Propagate conversation_id back to session (e.g. thread ID from Assistants API)
if session and response.conversation_id and not session.service_session_id:
# Propagate conversation_id back to session (e.g. thread ID from Assistants API).
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
# so refresh when a newer value is returned.
if session and response.conversation_id and session.service_session_id != response.conversation_id:
session.service_session_id = response.conversation_id
# Set the response on the context for after_run providers
+10 -1
View File
@@ -872,7 +872,16 @@ class MCPTool:
k: v
for k, v in kwargs.items()
if k
not in {"chat_options", "tools", "tool_choice", "session", "thread", "conversation_id", "options", "response_format"}
not in {
"chat_options",
"tools",
"tool_choice",
"session",
"thread",
"conversation_id",
"options",
"response_format",
}
}
parser = self.parse_tool_results or _parse_tool_result_from_mcp
@@ -2,7 +2,7 @@
import logging
import sys
from collections.abc import Mapping
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from typing import Any, cast
@@ -358,22 +358,31 @@ class AgentExecutor(Executor):
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {})
updates: list[AgentResponseUpdate] = []
user_input_requests: list[Content] = []
async for update in self._agent.run(
streamed_user_input_requests: list[Content] = []
stream = self._agent.run(
self._cache,
stream=True,
session=self._session,
options=options,
**run_kwargs,
):
)
async for update in stream:
updates.append(update)
await ctx.yield_output(update)
if update.user_input_requests:
user_input_requests.extend(update.user_input_requests)
streamed_user_input_requests.extend(update.user_input_requests)
# Build the final AgentResponse from the collected updates
if is_chat_agent(self._agent):
# Prefer stream finalization when available so result hooks run
# (e.g., thread conversation updates). Fall back to reconstructing from updates
# for legacy/custom agents that return a plain async iterable.
# TODO(evmattso): Integrate workflow agent run handling around ResponseStream so
# AgentExecutor does not need this conditional stream-finalization branch.
maybe_get_final_response = getattr(stream, "get_final_response", None)
get_final_response = maybe_get_final_response if callable(maybe_get_final_response) else None
response: AgentResponse[Any]
if get_final_response is not None:
response = await cast(Callable[[], Awaitable[AgentResponse[Any]]], get_final_response)()
elif is_chat_agent(self._agent):
response_format = self._agent.default_options.get("response_format")
response = AgentResponse.from_updates(
updates,
@@ -383,6 +392,16 @@ class AgentExecutor(Executor):
response = AgentResponse.from_updates(updates)
# Handle any user input requests after the streaming completes
user_input_requests: list[Content] = []
seen_request_ids: set[str] = set()
for user_input_request in [*streamed_user_input_requests, *response.user_input_requests]:
request_id = getattr(user_input_request, "id", None)
if isinstance(request_id, str) and request_id:
if request_id in seen_request_ids:
continue
seen_request_ids.add(request_id)
user_input_requests.append(user_input_request)
if user_input_requests:
for user_input_request in user_input_requests:
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
@@ -17,6 +17,7 @@ from agent_framework import (
BaseContextProvider,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
Content,
FunctionTool,
Message,
@@ -154,6 +155,111 @@ async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChat
assert session.service_session_id == "123"
async def test_chat_client_agent_updates_existing_session_id_non_streaming(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """A non-streaming run overwrites an existing service session id with the returned conversation id."""
    reply = Message(role="assistant", contents=[Content.from_text("test response")])
    chat_client_base.run_responses = [ChatResponse(messages=[reply], conversation_id="resp_new_123")]

    chat_agent = Agent(client=chat_client_base)
    # Start from a session that already holds a different (older) id.
    existing_session = chat_agent.get_session(service_session_id="resp_old_123")
    await chat_agent.run("Hello", session=existing_session)

    assert existing_session.service_session_id == "resp_new_123"
async def test_chat_client_agent_update_session_id_streaming_uses_conversation_id(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """After a streamed run the session holds the updates' conversation_id rather than their response_id."""
    opening_chunk = ChatResponseUpdate(
        contents=[Content.from_text("stream part 1")],
        role="assistant",
        response_id="resp_stream_123",
        conversation_id="conv_stream_456",
    )
    closing_chunk = ChatResponseUpdate(
        contents=[Content.from_text(" stream part 2")],
        role="assistant",
        response_id="resp_stream_123",
        conversation_id="conv_stream_456",
        finish_reason="stop",
    )
    chat_client_base.streaming_responses = [[opening_chunk, closing_chunk]]

    chat_agent = Agent(client=chat_client_base)
    new_session = chat_agent.create_session()
    response_stream = chat_agent.run("Hello", session=new_session, stream=True)
    # Drain every update before asking the stream for its final response.
    async for _update in response_stream:
        pass
    final = await response_stream.get_final_response()

    assert final.text == "stream part 1 stream part 2"
    assert new_session.service_session_id == "conv_stream_456"
async def test_chat_client_agent_updates_existing_session_id_streaming(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """A streamed run overwrites an existing service session id with the conversation id from the updates."""
    opening_chunk = ChatResponseUpdate(
        contents=[Content.from_text("stream part 1")],
        role="assistant",
        response_id="resp_stream_123",
        conversation_id="resp_new_456",
    )
    closing_chunk = ChatResponseUpdate(
        contents=[Content.from_text(" stream part 2")],
        role="assistant",
        response_id="resp_stream_123",
        conversation_id="resp_new_456",
        finish_reason="stop",
    )
    chat_client_base.streaming_responses = [[opening_chunk, closing_chunk]]

    chat_agent = Agent(client=chat_client_base)
    # Start from a session that already holds a different (older) id.
    existing_session = chat_agent.get_session(service_session_id="resp_old_456")
    response_stream = chat_agent.run("Hello", session=existing_session, stream=True)
    async for _update in response_stream:
        pass
    # The final response itself is not inspected; only the session is checked afterwards.
    await response_stream.get_final_response()

    assert existing_session.service_session_id == "resp_new_456"
async def test_chat_client_agent_update_session_id_streaming_does_not_use_response_id(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """When updates carry only a response_id, the session's service id stays None after streaming."""
    only_chunk = ChatResponseUpdate(
        contents=[Content.from_text("stream response without conversation id")],
        role="assistant",
        response_id="resp_only_123",
        finish_reason="stop",
    )
    chat_client_base.streaming_responses = [[only_chunk]]

    chat_agent = Agent(client=chat_client_base)
    new_session = chat_agent.create_session()
    response_stream = chat_agent.run("Hello", session=new_session, stream=True)
    # Drain every update before asking the stream for its final response.
    async for _update in response_stream:
        pass
    final = await response_stream.get_final_response()

    assert final.text == "stream response without conversation id"
    assert new_session.service_session_id is None
async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None:
agent = Agent(client=client)
session = agent.create_session()
@@ -0,0 +1 @@
# Copyright (c) Microsoft. All rights reserved.
@@ -50,6 +50,57 @@ class _CountingAgent(BaseAgent):
return _run()
class _StreamingHookAgent(BaseAgent):
    """Test agent that records whether the result hook on its response stream ran."""

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        # Set to True by the result hook attached to the stream in run(stream=True).
        self.result_hook_called = False

    def run(
        self,
        messages: str | Message | list[str] | list[Message] | None = None,
        *,
        stream: bool = False,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
        """Return a fixed "hook test" reply, either as a coroutine or as a hooked stream.

        Args:
            messages: Accepted for interface compatibility; not read here.
            stream: When true, return a ``ResponseStream`` with a result hook attached.
            **kwargs: Accepted for interface compatibility; not read here.

        Returns:
            A coroutine producing an ``AgentResponse`` when ``stream`` is false,
            otherwise a ``ResponseStream`` yielding one update.
        """
        if not stream:

            async def _respond() -> AgentResponse:
                return AgentResponse(messages=[Message("assistant", ["hook test"])])

            return _respond()

        async def _emit_updates() -> AsyncIterable[AgentResponseUpdate]:
            yield AgentResponseUpdate(
                contents=[Content.from_text(text="hook test")],
                role="assistant",
            )

        async def _record_hook(response: AgentResponse) -> AgentResponse:
            # Flip the flag the test inspects, then pass the response through unchanged.
            self.result_hook_called = True
            return response

        hooked_stream = ResponseStream(_emit_updates(), finalizer=AgentResponse.from_updates)
        return hooked_stream.with_result_hook(_record_hook)
async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None:
    """AgentExecutor should call get_final_response() so stream result hooks execute."""
    hook_agent = _StreamingHookAgent(id="hook_agent", name="HookAgent")
    hook_executor = AgentExecutor(hook_agent, id="hook_exec")
    workflow = SequentialBuilder(participants=[hook_executor]).build()

    # Keep only the "output" events emitted while the workflow streams.
    output_events: list[Any] = [
        event async for event in workflow.run("run hook test", stream=True) if event.type == "output"
    ]

    assert output_events
    assert hook_agent.result_hook_called
async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
"""Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly."""
storage = InMemoryCheckpointStorage()
@@ -12,7 +12,7 @@ from agent_framework import (
WorkflowRunState,
)
from agent_framework._workflows._checkpoint_encoding import (
_PICKLE_MARKER,
_PICKLE_MARKER, # type: ignore
encode_checkpoint_value,
)
from agent_framework._workflows._events import WorkflowEvent