mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix streamed workflow agent continuation context by finalizing AgentExecutor streams (#3882)
* Fix streamed workflow agent continuation context by finalizing AgentExecutor streams
* Fix stream handling
* Fixes
* Fix DevUI and tests
This commit is contained in:
committed by
GitHub
Unverified
parent
2203fa0f8b
commit
a276c1295a
@@ -911,19 +911,21 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
if ctx is None:
|
||||
return # No context available (shouldn't happen in normal flow)
|
||||
|
||||
# Update thread with conversation_id derived from streaming raw updates.
|
||||
# Using response_id here can break function-call continuation for APIs
|
||||
# where response IDs are not valid conversation handles.
|
||||
conversation_id = self._extract_conversation_id_from_streaming_response(response)
|
||||
# Ensure author names are set for all messages
|
||||
for message in response.messages:
|
||||
if message.author_name is None:
|
||||
message.author_name = ctx["agent_name"]
|
||||
|
||||
# Propagate conversation_id back to session from streaming updates
|
||||
# Propagate conversation_id back to session from streaming updates.
|
||||
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
|
||||
# so refresh when a newer value is returned.
|
||||
sess = ctx["session"]
|
||||
if sess and not sess.service_session_id and response.raw_representation:
|
||||
raw_items = response.raw_representation if isinstance(response.raw_representation, list) else []
|
||||
for item in raw_items:
|
||||
if hasattr(item, "conversation_id") and item.conversation_id:
|
||||
sess.service_session_id = item.conversation_id
|
||||
break
|
||||
if sess and conversation_id and sess.service_session_id != conversation_id:
|
||||
sess.service_session_id = conversation_id
|
||||
|
||||
# Run after_run providers (reverse order)
|
||||
session_context = ctx["session_context"]
|
||||
@@ -974,6 +976,27 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
output_format_type = response_format if isinstance(response_format, type) else None
|
||||
return AgentResponse.from_updates(updates, output_format_type=output_format_type)
|
||||
|
||||
@staticmethod
|
||||
def _extract_conversation_id_from_streaming_response(response: AgentResponse[Any]) -> str | None:
|
||||
"""Extract conversation_id from streaming raw updates, if present."""
|
||||
raw = response.raw_representation
|
||||
if raw is None:
|
||||
return None
|
||||
|
||||
raw_items: list[Any] = raw if isinstance(raw, list) else [raw]
|
||||
for item in reversed(raw_items):
|
||||
if isinstance(item, Mapping):
|
||||
value = item.get("conversation_id")
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
continue
|
||||
|
||||
value = getattr(item, "conversation_id", None)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
async def _prepare_run_context(
|
||||
self,
|
||||
*,
|
||||
@@ -1100,8 +1123,10 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
if message.author_name is None:
|
||||
message.author_name = agent_name
|
||||
|
||||
# Propagate conversation_id back to session (e.g. thread ID from Assistants API)
|
||||
if session and response.conversation_id and not session.service_session_id:
|
||||
# Propagate conversation_id back to session (e.g. thread ID from Assistants API).
|
||||
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
|
||||
# so refresh when a newer value is returned.
|
||||
if session and response.conversation_id and session.service_session_id != response.conversation_id:
|
||||
session.service_session_id = response.conversation_id
|
||||
|
||||
# Set the response on the context for after_run providers
|
||||
|
||||
@@ -872,7 +872,16 @@ class MCPTool:
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k
|
||||
not in {"chat_options", "tools", "tool_choice", "session", "thread", "conversation_id", "options", "response_format"}
|
||||
not in {
|
||||
"chat_options",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"session",
|
||||
"thread",
|
||||
"conversation_id",
|
||||
"options",
|
||||
"response_format",
|
||||
}
|
||||
}
|
||||
|
||||
parser = self.parse_tool_results or _parse_tool_result_from_mcp
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -358,22 +358,31 @@ class AgentExecutor(Executor):
|
||||
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {})
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
user_input_requests: list[Content] = []
|
||||
async for update in self._agent.run(
|
||||
streamed_user_input_requests: list[Content] = []
|
||||
stream = self._agent.run(
|
||||
self._cache,
|
||||
stream=True,
|
||||
session=self._session,
|
||||
options=options,
|
||||
**run_kwargs,
|
||||
):
|
||||
)
|
||||
async for update in stream:
|
||||
updates.append(update)
|
||||
await ctx.yield_output(update)
|
||||
|
||||
if update.user_input_requests:
|
||||
user_input_requests.extend(update.user_input_requests)
|
||||
streamed_user_input_requests.extend(update.user_input_requests)
|
||||
|
||||
# Build the final AgentResponse from the collected updates
|
||||
if is_chat_agent(self._agent):
|
||||
# Prefer stream finalization when available so result hooks run
|
||||
# (e.g., thread conversation updates). Fall back to reconstructing from updates
|
||||
# for legacy/custom agents that return a plain async iterable.
|
||||
# TODO(evmattso): Integrate workflow agent run handling around ResponseStream so
|
||||
# AgentExecutor does not need this conditional stream-finalization branch.
|
||||
maybe_get_final_response = getattr(stream, "get_final_response", None)
|
||||
get_final_response = maybe_get_final_response if callable(maybe_get_final_response) else None
|
||||
response: AgentResponse[Any]
|
||||
if get_final_response is not None:
|
||||
response = await cast(Callable[[], Awaitable[AgentResponse[Any]]], get_final_response)()
|
||||
elif is_chat_agent(self._agent):
|
||||
response_format = self._agent.default_options.get("response_format")
|
||||
response = AgentResponse.from_updates(
|
||||
updates,
|
||||
@@ -383,6 +392,16 @@ class AgentExecutor(Executor):
|
||||
response = AgentResponse.from_updates(updates)
|
||||
|
||||
# Handle any user input requests after the streaming completes
|
||||
user_input_requests: list[Content] = []
|
||||
seen_request_ids: set[str] = set()
|
||||
for user_input_request in [*streamed_user_input_requests, *response.user_input_requests]:
|
||||
request_id = getattr(user_input_request, "id", None)
|
||||
if isinstance(request_id, str) and request_id:
|
||||
if request_id in seen_request_ids:
|
||||
continue
|
||||
seen_request_ids.add(request_id)
|
||||
user_input_requests.append(user_input_request)
|
||||
|
||||
if user_input_requests:
|
||||
for user_input_request in user_input_requests:
|
||||
self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index]
|
||||
|
||||
@@ -17,6 +17,7 @@ from agent_framework import (
|
||||
BaseContextProvider,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
FunctionTool,
|
||||
Message,
|
||||
@@ -154,6 +155,111 @@ async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChat
|
||||
assert session.service_session_id == "123"
|
||||
|
||||
|
||||
async def test_chat_client_agent_updates_existing_session_id_non_streaming(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """A non-streaming run overwrites an existing service session id with the returned conversation id."""
    assistant_reply = Message(role="assistant", contents=[Content.from_text("test response")])
    chat_client_base.run_responses = [
        ChatResponse(messages=[assistant_reply], conversation_id="resp_new_123"),
    ]

    agent = Agent(client=chat_client_base)
    # Start from a session that already holds an older service-side id.
    existing_session = agent.get_session(service_session_id="resp_old_123")

    await agent.run("Hello", session=existing_session)

    assert existing_session.service_session_id == "resp_new_123"
|
||||
|
||||
|
||||
async def test_chat_client_agent_update_session_id_streaming_uses_conversation_id(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """A streamed run stores the updates' conversation_id (not the response_id) on a fresh session."""
    first_chunk = ChatResponseUpdate(
        contents=[Content.from_text("stream part 1")],
        role="assistant",
        response_id="resp_stream_123",
        conversation_id="conv_stream_456",
    )
    final_chunk = ChatResponseUpdate(
        contents=[Content.from_text(" stream part 2")],
        role="assistant",
        response_id="resp_stream_123",
        conversation_id="conv_stream_456",
        finish_reason="stop",
    )
    chat_client_base.streaming_responses = [[first_chunk, final_chunk]]

    agent = Agent(client=chat_client_base)
    session = agent.create_session()

    stream = agent.run("Hello", session=session, stream=True)
    # Drain the stream; the individual updates are not inspected here.
    async for _update in stream:
        pass
    final_response = await stream.get_final_response()

    assert final_response.text == "stream part 1 stream part 2"
    assert session.service_session_id == "conv_stream_456"
|
||||
|
||||
|
||||
async def test_chat_client_agent_updates_existing_session_id_streaming(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """A streamed run overwrites an existing service session id with the streamed conversation id."""
    first_chunk = ChatResponseUpdate(
        contents=[Content.from_text("stream part 1")],
        role="assistant",
        response_id="resp_stream_123",
        conversation_id="resp_new_456",
    )
    final_chunk = ChatResponseUpdate(
        contents=[Content.from_text(" stream part 2")],
        role="assistant",
        response_id="resp_stream_123",
        conversation_id="resp_new_456",
        finish_reason="stop",
    )
    chat_client_base.streaming_responses = [[first_chunk, final_chunk]]

    agent = Agent(client=chat_client_base)
    # Start from a session that already holds an older service-side id.
    existing_session = agent.get_session(service_session_id="resp_old_456")

    stream = agent.run("Hello", session=existing_session, stream=True)
    # Drain the stream, then finalize it before checking the session.
    async for _update in stream:
        pass
    await stream.get_final_response()

    assert existing_session.service_session_id == "resp_new_456"
|
||||
|
||||
|
||||
async def test_chat_client_agent_update_session_id_streaming_does_not_use_response_id(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """When streamed updates carry only a response_id, the session's service id stays unset."""
    only_chunk = ChatResponseUpdate(
        contents=[Content.from_text("stream response without conversation id")],
        role="assistant",
        response_id="resp_only_123",
        finish_reason="stop",
    )
    chat_client_base.streaming_responses = [[only_chunk]]

    agent = Agent(client=chat_client_base)
    session = agent.create_session()

    stream = agent.run("Hello", session=session, stream=True)
    # Drain the stream; the individual updates are not inspected here.
    async for _update in stream:
        pass
    final_response = await stream.get_final_response()

    assert final_response.text == "stream response without conversation id"
    assert session.service_session_id is None
|
||||
|
||||
|
||||
async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None:
|
||||
agent = Agent(client=client)
|
||||
session = agent.create_session()
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
@@ -50,6 +50,57 @@ class _CountingAgent(BaseAgent):
|
||||
return _run()
|
||||
|
||||
|
||||
class _StreamingHookAgent(BaseAgent):
    """Test agent that records whether the result hook attached to its stream has run."""

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        # Set to True by the result hook registered on the streaming response.
        self.result_hook_called = False

    def run(
        self,
        messages: str | Message | list[str] | list[Message] | None = None,
        *,
        stream: bool = False,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
        """Return a one-update ResponseStream when ``stream`` is true, else an awaitable response."""
        if not stream:

            async def _respond() -> AgentResponse:
                return AgentResponse(messages=[Message("assistant", ["hook test"])])

            return _respond()

        async def _emit_updates() -> AsyncIterable[AgentResponseUpdate]:
            yield AgentResponseUpdate(
                contents=[Content.from_text(text="hook test")],
                role="assistant",
            )

        async def _record_hook_run(response: AgentResponse) -> AgentResponse:
            # Flag the hook as executed and pass the response through unchanged.
            self.result_hook_called = True
            return response

        response_stream = ResponseStream(_emit_updates(), finalizer=AgentResponse.from_updates)
        return response_stream.with_result_hook(_record_hook_run)
|
||||
|
||||
|
||||
async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None:
    """Streaming a workflow through AgentExecutor yields output and triggers the agent's result hook."""
    hook_agent = _StreamingHookAgent(id="hook_agent", name="HookAgent")
    workflow = SequentialBuilder(participants=[AgentExecutor(hook_agent, id="hook_exec")]).build()

    # Keep only the events the workflow reports as outputs.
    output_events: list[Any] = [
        event async for event in workflow.run("run hook test", stream=True) if event.type == "output"
    ]

    assert output_events
    assert hook_agent.result_hook_called
|
||||
|
||||
|
||||
async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||
"""Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly."""
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
@@ -12,7 +12,7 @@ from agent_framework import (
|
||||
WorkflowRunState,
|
||||
)
|
||||
from agent_framework._workflows._checkpoint_encoding import (
|
||||
_PICKLE_MARKER,
|
||||
_PICKLE_MARKER, # type: ignore
|
||||
encode_checkpoint_value,
|
||||
)
|
||||
from agent_framework._workflows._events import WorkflowEvent
|
||||
|
||||
Reference in New Issue
Block a user