Python: (ag-ui): fix Workflow.as_agent() streaming regression (#3875)

* fix Workflow.as_agent() streaming regression in ag-ui

* Address PR feedback

* PR feedback
This commit is contained in:
Evan Mattson
2026-02-13 07:43:44 +09:00
committed by GitHub
Unverified
parent 1e350ea22f
commit 2203fa0f8b
3 changed files with 185 additions and 66 deletions
@@ -7,7 +7,7 @@ from __future__ import annotations
import json
import logging
import uuid
from collections.abc import Awaitable
from collections.abc import AsyncIterable, Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, cast
@@ -172,12 +172,12 @@ class FlowState:
tool_call_id: str | None = None # Current tool call being streamed
tool_call_name: str | None = None # Name of current tool call
waiting_for_approval: bool = False # Stop after approval request
current_state: dict[str, Any] = field(default_factory=dict) # Shared state
current_state: dict[str, Any] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType]
accumulated_text: str = "" # For MessagesSnapshotEvent
pending_tool_calls: list[dict[str, Any]] = field(default_factory=list) # For MessagesSnapshotEvent
tool_calls_by_id: dict[str, dict[str, Any]] = field(default_factory=dict)
tool_results: list[dict[str, Any]] = field(default_factory=list)
tool_calls_ended: set[str] = field(default_factory=set) # Track which tool calls have been ended
pending_tool_calls: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
tool_calls_by_id: dict[str, dict[str, Any]] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType]
tool_results: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
tool_calls_ended: set[str] = field(default_factory=set) # pyright: ignore[reportUnknownVariableType]
def get_tool_name(self, call_id: str | None) -> str | None:
"""Get tool name by call ID."""
@@ -191,6 +191,40 @@ class FlowState:
return [tc for tc in self.pending_tool_calls if tc.get("id") not in self.tool_calls_ended]
async def _normalize_response_stream(response_stream: Any) -> AsyncIterable[Any]:
    """Coerce whatever an agent's streaming ``run`` call returned into an async iterable.

    Accepted inputs:
    - a ``ResponseStream``
    - any other ``AsyncIterable`` (e.g. a workflow-style async generator)
    - an ``Awaitable`` that resolves to either of the above

    Args:
        response_stream: The object returned by the agent's streaming run call.

    Returns:
        The (possibly awaited) stream, typed as ``AsyncIterable[Any]``.

    Raises:
        AgentExecutionException: If the value, or the value an awaitable
            resolves to, is neither a ``ResponseStream`` nor an ``AsyncIterable``.
    """
    # Awaitables are resolved first; remember that so the error message can
    # distinguish "resolved to a bad type" from "received a bad type".
    was_awaited = isinstance(response_stream, Awaitable)
    if was_awaited:
        candidate = await cast(Awaitable[Any], response_stream)
    else:
        candidate = response_stream

    # AG-UI consumes update iteration only; ResponseStream finalizers are not used here.
    if isinstance(candidate, (ResponseStream, AsyncIterable)):
        return cast(AsyncIterable[Any], candidate)

    candidate_type = type(candidate)
    type_name = f"{candidate_type.__module__}.{candidate_type.__name__}"
    if was_awaited:
        raise AgentExecutionException(
            "Agent did not return a streaming AsyncIterable response. "
            f"Awaitable resolved to unsupported type: {type_name}."
        )
    raise AgentExecutionException(
        f"Agent did not return a streaming AsyncIterable response. Received unsupported type: {type_name}."
    )
def _create_state_context_message(
current_state: dict[str, Any],
state_schema: dict[str, Any],
@@ -460,7 +494,7 @@ def _emit_approval_request(
parent_message_id=flow.message_id,
)
)
args = {
args: dict[str, Any] = {
"function_name": func_name,
"function_call_id": func_call_id,
"function_arguments": make_json_safe(func_call.parse_arguments()) or {},
@@ -515,7 +549,8 @@ def _is_confirm_changes_response(messages: list[Any]) -> bool:
if not messages:
return False
last = messages[-1]
if not last.additional_properties.get("is_tool_result", False):
additional_properties = cast(dict[str, Any], getattr(last, "additional_properties", {}) or {})
if not additional_properties.get("is_tool_result", False):
return False
# Parse the content to check if it has the confirm_changes structure
@@ -523,6 +558,8 @@ def _is_confirm_changes_response(messages: list[Any]) -> bool:
if getattr(content, "type", None) == "text" and content.text:
try:
result = json.loads(content.text)
if not isinstance(result, dict):
continue
# confirm_changes results have 'accepted' and 'steps' keys
if "accepted" in result and "steps" in result:
return True
@@ -548,13 +585,19 @@ def _handle_step_based_approval(messages: list[Any]) -> list[BaseEvent]:
message = "Acknowledged."
else:
try:
result = json.loads(approval_text)
accepted = result.get("accepted", False)
steps = result.get("steps", [])
parsed_result = json.loads(approval_text)
result: dict[str, Any] = cast(dict[str, Any], parsed_result) if isinstance(parsed_result, dict) else {}
accepted = bool(result.get("accepted", False))
steps_raw = result.get("steps", [])
steps: list[dict[str, Any]] = []
if isinstance(steps_raw, list):
for step_raw in cast(list[Any], steps_raw):
if isinstance(step_raw, dict):
steps.append(cast(dict[str, Any], step_raw))
if accepted:
# Generate acceptance message with step descriptions
enabled_steps = [s for s in steps if s.get("status") == "enabled"]
enabled_steps: list[dict[str, Any]] = [step for step in steps if step.get("status") == "enabled"]
if enabled_steps:
message_parts = [f"Executing {len(enabled_steps)} approved steps:\n\n"]
for i, step in enumerate(enabled_steps, 1):
@@ -678,8 +721,9 @@ def _convert_approval_results_to_tool_messages(messages: list[Any]) -> None:
result.append(msg)
continue
function_results = [c for c in (msg.contents or []) if getattr(c, "type", None) == "function_result"]
other_contents = [c for c in (msg.contents or []) if getattr(c, "type", None) != "function_result"]
msg_contents = cast(list[Content], getattr(msg, "contents", None) or [])
function_results: list[Content] = [content for content in msg_contents if content.type == "function_result"]
other_contents: list[Content] = [content for content in msg_contents if content.type != "function_result"]
if not function_results:
result.append(msg)
@@ -695,7 +739,7 @@ def _convert_approval_results_to_tool_messages(messages: list[Any]) -> None:
# Then user message with remaining content (if any)
if other_contents:
result.append(Message(role=msg.role, contents=other_contents))
result.append(Message(role="user", contents=other_contents))
messages[:] = result
@@ -765,21 +809,24 @@ async def run_agent_stream(
if input_data.get("state"):
flow.current_state = dict(input_data["state"])
state_schema = cast(dict[str, Any], getattr(config, "state_schema", {}) or {})
predict_state_config = cast(dict[str, dict[str, str]], getattr(config, "predict_state_config", {}) or {})
# Apply schema defaults for missing state keys
if config.state_schema:
for key, schema in config.state_schema.items():
if state_schema:
for key, schema in state_schema.items():
if key in flow.current_state:
continue
if isinstance(schema, dict) and schema.get("type") == "array":
if isinstance(schema, dict) and cast(dict[str, Any], schema).get("type") == "array":
flow.current_state[key] = []
else:
flow.current_state[key] = {}
# Initialize predictive state handler if configured
predictive_handler: PredictiveStateHandler | None = None
if config.predict_state_config:
if predict_state_config:
predictive_handler = PredictiveStateHandler(
predict_state_config=config.predict_state_config,
predict_state_config=predict_state_config,
current_state=flow.current_state,
)
@@ -789,11 +836,11 @@ async def run_agent_stream(
# Check for structured output mode (skip text content)
skip_text = False
response_format = None
from agent_framework import Agent
if isinstance(agent, Agent):
response_format = agent.default_options.get("response_format")
response_format: type[Any] | None = None
default_options = getattr(agent, "default_options", None)
if isinstance(default_options, dict):
typed_default_options = cast(dict[str, Any], default_options)
response_format = cast(type[Any] | None, typed_default_options.get("response_format"))
skip_text = response_format is not None
# Handle empty messages (emit RunStarted immediately since no agent response)
@@ -831,8 +878,9 @@ async def run_agent_stream(
run_kwargs["tools"] = tools
# Filter out AG-UI internal metadata keys before passing to chat client
# These are used internally for orchestration and should not be sent to the LLM provider
client_metadata = {
k: v for k, v in (getattr(session, "metadata", None) or {}).items() if k not in AG_UI_INTERNAL_METADATA_KEYS
session_metadata = cast(dict[str, Any], getattr(session, "metadata", None) or {})
client_metadata: dict[str, Any] = {
k: v for k, v in session_metadata.items() if k not in AG_UI_INTERNAL_METADATA_KEYS
}
safe_metadata = _build_safe_metadata(client_metadata) if client_metadata else {}
if safe_metadata:
@@ -863,19 +911,14 @@ async def run_agent_stream(
# Inject state context message so the model knows current application state
# This is critical for shared state scenarios where the UI state needs to be visible
if config.state_schema and flow.current_state:
messages = _inject_state_context(messages, flow.current_state, config.state_schema)
if state_schema and flow.current_state:
messages = _inject_state_context(messages, flow.current_state, state_schema)
# Stream from agent - emit RunStarted after first update to get service IDs
run_started_emitted = False
all_updates: list[Any] = [] # Collect for structured output processing
response_stream = agent.run(messages, stream=True, **run_kwargs)
if isinstance(response_stream, ResponseStream):
stream = response_stream
else:
stream = await cast(Awaitable[ResponseStream[Any, Any]], response_stream)
if not isinstance(stream, ResponseStream):
raise AgentExecutionException("Chat client did not return a ResponseStream.")
stream = await _normalize_response_stream(response_stream)
async for update in stream:
# Collect updates for structured output processing
if response_format is not None:
@@ -891,18 +934,18 @@ async def run_agent_stream(
# NOW emit RunStarted with proper IDs
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
# Emit PredictState custom event if configured
if config.predict_state_config:
if predict_state_config:
predict_state_value = [
{
"state_key": state_key,
"tool": cfg["tool"],
"tool_argument": cfg["tool_argument"],
}
for state_key, cfg in config.predict_state_config.items()
for state_key, cfg in predict_state_config.items()
]
yield CustomEvent(name="PredictState", value=predict_state_value)
# Emit initial state snapshot only if we have both state_schema and state
if config.state_schema and flow.current_state:
if state_schema and flow.current_state:
yield StateSnapshotEvent(snapshot=flow.current_state)
run_started_emitted = True
@@ -933,17 +976,17 @@ async def run_agent_stream(
# If no updates at all, still emit RunStarted
if not run_started_emitted:
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
if config.predict_state_config:
if predict_state_config:
predict_state_value = [
{
"state_key": state_key,
"tool": cfg["tool"],
"tool_argument": cfg["tool_argument"],
}
for state_key, cfg in config.predict_state_config.items()
for state_key, cfg in predict_state_config.items()
]
yield CustomEvent(name="PredictState", value=predict_state_value)
if config.state_schema and flow.current_state:
if state_schema and flow.current_state:
yield StateSnapshotEvent(snapshot=flow.current_state)
# Process structured output if response_format is set
@@ -951,31 +994,33 @@ async def run_agent_stream(
from agent_framework import AgentResponse
from pydantic import BaseModel
logger.info(f"Processing structured output, update count: {len(all_updates)}")
final_response = AgentResponse.from_updates(all_updates, output_format_type=response_format)
if not (isinstance(response_format, type) and issubclass(response_format, BaseModel)):
logger.warning("Skipping structured output parsing: response_format is not a Pydantic model type.")
else:
logger.info(f"Processing structured output, update count: {len(all_updates)}")
final_response = AgentResponse.from_updates(all_updates, output_format_type=response_format)
if final_response.value and isinstance(final_response.value, BaseModel):
response_dict = final_response.value.model_dump(mode="json", exclude_none=True)
logger.info(f"Received structured output keys: {list(response_dict.keys())}")
if final_response.value and isinstance(final_response.value, BaseModel):
response_dict = final_response.value.model_dump(mode="json", exclude_none=True)
logger.info(f"Received structured output keys: {list(response_dict.keys())}")
# Extract state updates - if no state_schema, all non-message fields are state
state_keys = (
set(config.state_schema.keys()) if config.state_schema else set(response_dict.keys()) - {"message"}
)
state_updates = {k: v for k, v in response_dict.items() if k in state_keys}
# Extract state updates - if no state_schema, all non-message fields are state
state_keys = set(state_schema.keys()) if state_schema else set(response_dict.keys()) - {"message"}
state_updates = {k: v for k, v in response_dict.items() if k in state_keys}
if state_updates:
flow.current_state.update(state_updates)
yield StateSnapshotEvent(snapshot=flow.current_state)
logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}")
if state_updates:
flow.current_state.update(state_updates)
yield StateSnapshotEvent(snapshot=flow.current_state)
logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}")
# Emit message field as text if present
if "message" in response_dict and response_dict["message"]:
message_id = generate_event_id()
yield TextMessageStartEvent(message_id=message_id, role="assistant")
yield TextMessageContentEvent(message_id=message_id, delta=response_dict["message"])
yield TextMessageEndEvent(message_id=message_id)
logger.info(f"Emitted conversational message with length={len(response_dict['message'])}")
# Emit message field as text if present
message_text = response_dict.get("message")
if isinstance(message_text, str) and message_text:
message_id = generate_event_id()
yield TextMessageStartEvent(message_id=message_id, role="assistant")
yield TextMessageContentEvent(message_id=message_id, delta=message_text)
yield TextMessageEndEvent(message_id=message_id)
logger.info(f"Emitted conversational message with length={len(message_text)}")
# Feature #1: Emit ToolCallEndEvent for declaration-only tools (tools without results)
pending_without_end = flow.get_pending_without_end()
@@ -989,8 +1034,8 @@ async def run_agent_stream(
yield ToolCallEndEvent(tool_call_id=tool_call_id)
# For predictive tools with require_confirmation, emit confirm_changes
if config.require_confirmation and config.predict_state_config and tool_name:
is_predictive_tool = any(cfg["tool"] == tool_name for cfg in config.predict_state_config.values())
if config.require_confirmation and predict_state_config and tool_name:
is_predictive_tool = any(cfg["tool"] == tool_name for cfg in predict_state_config.values())
if is_predictive_tool:
logger.info(f"Emitting confirm_changes for predictive tool '{tool_name}'")
# Extract state value from tool arguments for StateSnapshot
@@ -1071,7 +1116,7 @@ async def run_agent_stream(
last_call_id = last_result.get("toolCallId")
last_tool_name = flow.get_tool_name(last_call_id)
if not _should_suppress_intermediate_snapshot(
last_tool_name, config.predict_state_config, config.require_confirmation
last_tool_name, predict_state_config, config.require_confirmation
):
yield _build_messages_snapshot(flow, snapshot_messages)
@@ -6,6 +6,7 @@ import json
import pytest
from agent_framework import Agent, ChatResponseUpdate, Content
from agent_framework.orchestrations import SequentialBuilder
from fastapi import FastAPI, Header, HTTPException
from fastapi.params import Depends
from fastapi.testclient import TestClient
@@ -165,6 +166,28 @@ async def test_endpoint_event_streaming(build_chat_client):
assert found_run_finished
async def test_endpoint_with_workflow_as_agent_stream_output(build_chat_client):
    """A workflow wrapped via ``as_agent()`` streams the expected AG-UI events through the endpoint."""
    app = FastAPI()

    # Two-step sequential workflow exposed through the agent interface.
    participants = [
        Agent(name="brainstorm", instructions="Brainstorm ideas", client=build_chat_client("Idea")),
        Agent(name="reviewer", instructions="Review ideas", client=build_chat_client("Review")),
    ]
    workflow_agent = SequentialBuilder(participants=participants).build().as_agent()
    add_agent_framework_fastapi_endpoint(app, workflow_agent, path="/workflow-like")

    payload = {"messages": [{"role": "user", "content": "Hello"}]}
    response = TestClient(app).post("/workflow-like", json=payload)
    assert response.status_code == 200

    # Each server-sent event line carries a JSON document after the "data: " prefix.
    prefix = "data: "
    body = response.content.decode("utf-8")
    event_types = [
        json.loads(raw_line[len(prefix):]).get("type")
        for raw_line in body.split("\n")
        if raw_line.startswith(prefix)
    ]

    for expected_type in ("RUN_STARTED", "TEXT_MESSAGE_CONTENT", "RUN_FINISHED"):
        assert expected_type in event_types
async def test_endpoint_error_handling(build_chat_client):
"""Test endpoint error handling during request parsing."""
app = FastAPI()
+52 -1
View File
@@ -2,11 +2,13 @@
"""Tests for _run.py helper functions and FlowState."""
import pytest
from ag_ui.core import (
TextMessageEndEvent,
TextMessageStartEvent,
)
from agent_framework import Content, Message
from agent_framework import AgentResponseUpdate, Content, Message, ResponseStream
from agent_framework.exceptions import AgentExecutionException
from agent_framework_ag_ui._run import (
FlowState,
@@ -16,6 +18,7 @@ from agent_framework_ag_ui._run import (
_emit_tool_result,
_has_only_tool_calls,
_inject_state_context,
_normalize_response_stream,
_should_suppress_intermediate_snapshot,
)
@@ -179,6 +182,54 @@ class TestFlowState:
assert result[0]["id"] == "call_2"
class TestNormalizeResponseStream:
    """Tests for _normalize_response_stream helper."""

    @staticmethod
    async def _hello_updates():
        """Async generator yielding a single assistant update with the text "hello"."""
        yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")

    @staticmethod
    async def _assert_single_hello(stream):
        """Drain ``stream`` and check it produced exactly the one "hello" update."""
        collected = []
        async for update in stream:
            collected.append(update)
        assert len(collected) == 1
        assert collected[0].contents[0].text == "hello"

    async def test_accepts_response_stream(self):
        """Accept standard ResponseStream values."""
        normalized = await _normalize_response_stream(ResponseStream(self._hello_updates()))
        await self._assert_single_hello(normalized)

    async def test_accepts_async_iterable(self):
        """Accept workflow-style async generator streams."""
        normalized = await _normalize_response_stream(self._hello_updates())
        await self._assert_single_hello(normalized)

    async def test_accepts_awaitable_resolving_to_async_iterable(self):
        """Accept awaitables that resolve to async iterable streams."""

        async def _resolve():
            # Coroutine whose result is the async generator, not the updates themselves.
            return self._hello_updates()

        normalized = await _normalize_response_stream(_resolve())
        await self._assert_single_hello(normalized)

    async def test_rejects_non_stream_values(self):
        """Reject unsupported stream return values."""
        with pytest.raises(AgentExecutionException):
            await _normalize_response_stream("not-a-stream")
class TestCreateStateContextMessage:
"""Tests for _create_state_context_message function."""