mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: (ag-ui): fix Workflow.as_agent() streaming regression (#3875)
* fix Workflow.as_agent() streaming regression in ag-ui * Address PR feedback * PR feedback
This commit is contained in:
committed by
GitHub
Unverified
parent
1e350ea22f
commit
2203fa0f8b
@@ -7,7 +7,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import AsyncIterable, Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
@@ -172,12 +172,12 @@ class FlowState:
|
||||
tool_call_id: str | None = None # Current tool call being streamed
|
||||
tool_call_name: str | None = None # Name of current tool call
|
||||
waiting_for_approval: bool = False # Stop after approval request
|
||||
current_state: dict[str, Any] = field(default_factory=dict) # Shared state
|
||||
current_state: dict[str, Any] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType]
|
||||
accumulated_text: str = "" # For MessagesSnapshotEvent
|
||||
pending_tool_calls: list[dict[str, Any]] = field(default_factory=list) # For MessagesSnapshotEvent
|
||||
tool_calls_by_id: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
tool_results: list[dict[str, Any]] = field(default_factory=list)
|
||||
tool_calls_ended: set[str] = field(default_factory=set) # Track which tool calls have been ended
|
||||
pending_tool_calls: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
|
||||
tool_calls_by_id: dict[str, dict[str, Any]] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType]
|
||||
tool_results: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
|
||||
tool_calls_ended: set[str] = field(default_factory=set) # pyright: ignore[reportUnknownVariableType]
|
||||
|
||||
def get_tool_name(self, call_id: str | None) -> str | None:
|
||||
"""Get tool name by call ID."""
|
||||
@@ -191,6 +191,40 @@ class FlowState:
|
||||
return [tc for tc in self.pending_tool_calls if tc.get("id") not in self.tool_calls_ended]
|
||||
|
||||
|
||||
async def _normalize_response_stream(response_stream: Any) -> AsyncIterable[Any]:
|
||||
"""Normalize agent streaming return types to an async iterable.
|
||||
|
||||
Supports:
|
||||
- ResponseStream (standard agent stream type)
|
||||
- AsyncIterable[AgentResponseUpdate] (workflow-style stream)
|
||||
- Awaitable that resolves to either of the above
|
||||
"""
|
||||
if isinstance(response_stream, Awaitable):
|
||||
resolved_stream = await cast(Awaitable[Any], response_stream)
|
||||
if isinstance(resolved_stream, ResponseStream):
|
||||
# AG-UI consumes update iteration only; ResponseStream finalizers are not used here.
|
||||
return cast(AsyncIterable[Any], resolved_stream)
|
||||
if isinstance(resolved_stream, AsyncIterable):
|
||||
return cast(AsyncIterable[Any], resolved_stream)
|
||||
resolved_type = f"{type(resolved_stream).__module__}.{type(resolved_stream).__name__}"
|
||||
raise AgentExecutionException(
|
||||
"Agent did not return a streaming AsyncIterable response. "
|
||||
f"Awaitable resolved to unsupported type: {resolved_type}."
|
||||
)
|
||||
|
||||
if isinstance(response_stream, ResponseStream):
|
||||
# AG-UI consumes update iteration only; ResponseStream finalizers are not used here.
|
||||
return cast(AsyncIterable[Any], response_stream)
|
||||
|
||||
if isinstance(response_stream, AsyncIterable):
|
||||
return cast(AsyncIterable[Any], response_stream)
|
||||
|
||||
stream_type = f"{type(response_stream).__module__}.{type(response_stream).__name__}"
|
||||
raise AgentExecutionException(
|
||||
f"Agent did not return a streaming AsyncIterable response. Received unsupported type: {stream_type}."
|
||||
)
|
||||
|
||||
|
||||
def _create_state_context_message(
|
||||
current_state: dict[str, Any],
|
||||
state_schema: dict[str, Any],
|
||||
@@ -460,7 +494,7 @@ def _emit_approval_request(
|
||||
parent_message_id=flow.message_id,
|
||||
)
|
||||
)
|
||||
args = {
|
||||
args: dict[str, Any] = {
|
||||
"function_name": func_name,
|
||||
"function_call_id": func_call_id,
|
||||
"function_arguments": make_json_safe(func_call.parse_arguments()) or {},
|
||||
@@ -515,7 +549,8 @@ def _is_confirm_changes_response(messages: list[Any]) -> bool:
|
||||
if not messages:
|
||||
return False
|
||||
last = messages[-1]
|
||||
if not last.additional_properties.get("is_tool_result", False):
|
||||
additional_properties = cast(dict[str, Any], getattr(last, "additional_properties", {}) or {})
|
||||
if not additional_properties.get("is_tool_result", False):
|
||||
return False
|
||||
|
||||
# Parse the content to check if it has the confirm_changes structure
|
||||
@@ -523,6 +558,8 @@ def _is_confirm_changes_response(messages: list[Any]) -> bool:
|
||||
if getattr(content, "type", None) == "text" and content.text:
|
||||
try:
|
||||
result = json.loads(content.text)
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
# confirm_changes results have 'accepted' and 'steps' keys
|
||||
if "accepted" in result and "steps" in result:
|
||||
return True
|
||||
@@ -548,13 +585,19 @@ def _handle_step_based_approval(messages: list[Any]) -> list[BaseEvent]:
|
||||
message = "Acknowledged."
|
||||
else:
|
||||
try:
|
||||
result = json.loads(approval_text)
|
||||
accepted = result.get("accepted", False)
|
||||
steps = result.get("steps", [])
|
||||
parsed_result = json.loads(approval_text)
|
||||
result: dict[str, Any] = cast(dict[str, Any], parsed_result) if isinstance(parsed_result, dict) else {}
|
||||
accepted = bool(result.get("accepted", False))
|
||||
steps_raw = result.get("steps", [])
|
||||
steps: list[dict[str, Any]] = []
|
||||
if isinstance(steps_raw, list):
|
||||
for step_raw in cast(list[Any], steps_raw):
|
||||
if isinstance(step_raw, dict):
|
||||
steps.append(cast(dict[str, Any], step_raw))
|
||||
|
||||
if accepted:
|
||||
# Generate acceptance message with step descriptions
|
||||
enabled_steps = [s for s in steps if s.get("status") == "enabled"]
|
||||
enabled_steps: list[dict[str, Any]] = [step for step in steps if step.get("status") == "enabled"]
|
||||
if enabled_steps:
|
||||
message_parts = [f"Executing {len(enabled_steps)} approved steps:\n\n"]
|
||||
for i, step in enumerate(enabled_steps, 1):
|
||||
@@ -678,8 +721,9 @@ def _convert_approval_results_to_tool_messages(messages: list[Any]) -> None:
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
function_results = [c for c in (msg.contents or []) if getattr(c, "type", None) == "function_result"]
|
||||
other_contents = [c for c in (msg.contents or []) if getattr(c, "type", None) != "function_result"]
|
||||
msg_contents = cast(list[Content], getattr(msg, "contents", None) or [])
|
||||
function_results: list[Content] = [content for content in msg_contents if content.type == "function_result"]
|
||||
other_contents: list[Content] = [content for content in msg_contents if content.type != "function_result"]
|
||||
|
||||
if not function_results:
|
||||
result.append(msg)
|
||||
@@ -695,7 +739,7 @@ def _convert_approval_results_to_tool_messages(messages: list[Any]) -> None:
|
||||
|
||||
# Then user message with remaining content (if any)
|
||||
if other_contents:
|
||||
result.append(Message(role=msg.role, contents=other_contents))
|
||||
result.append(Message(role="user", contents=other_contents))
|
||||
|
||||
messages[:] = result
|
||||
|
||||
@@ -765,21 +809,24 @@ async def run_agent_stream(
|
||||
if input_data.get("state"):
|
||||
flow.current_state = dict(input_data["state"])
|
||||
|
||||
state_schema = cast(dict[str, Any], getattr(config, "state_schema", {}) or {})
|
||||
predict_state_config = cast(dict[str, dict[str, str]], getattr(config, "predict_state_config", {}) or {})
|
||||
|
||||
# Apply schema defaults for missing state keys
|
||||
if config.state_schema:
|
||||
for key, schema in config.state_schema.items():
|
||||
if state_schema:
|
||||
for key, schema in state_schema.items():
|
||||
if key in flow.current_state:
|
||||
continue
|
||||
if isinstance(schema, dict) and schema.get("type") == "array":
|
||||
if isinstance(schema, dict) and cast(dict[str, Any], schema).get("type") == "array":
|
||||
flow.current_state[key] = []
|
||||
else:
|
||||
flow.current_state[key] = {}
|
||||
|
||||
# Initialize predictive state handler if configured
|
||||
predictive_handler: PredictiveStateHandler | None = None
|
||||
if config.predict_state_config:
|
||||
if predict_state_config:
|
||||
predictive_handler = PredictiveStateHandler(
|
||||
predict_state_config=config.predict_state_config,
|
||||
predict_state_config=predict_state_config,
|
||||
current_state=flow.current_state,
|
||||
)
|
||||
|
||||
@@ -789,11 +836,11 @@ async def run_agent_stream(
|
||||
|
||||
# Check for structured output mode (skip text content)
|
||||
skip_text = False
|
||||
response_format = None
|
||||
from agent_framework import Agent
|
||||
|
||||
if isinstance(agent, Agent):
|
||||
response_format = agent.default_options.get("response_format")
|
||||
response_format: type[Any] | None = None
|
||||
default_options = getattr(agent, "default_options", None)
|
||||
if isinstance(default_options, dict):
|
||||
typed_default_options = cast(dict[str, Any], default_options)
|
||||
response_format = cast(type[Any] | None, typed_default_options.get("response_format"))
|
||||
skip_text = response_format is not None
|
||||
|
||||
# Handle empty messages (emit RunStarted immediately since no agent response)
|
||||
@@ -831,8 +878,9 @@ async def run_agent_stream(
|
||||
run_kwargs["tools"] = tools
|
||||
# Filter out AG-UI internal metadata keys before passing to chat client
|
||||
# These are used internally for orchestration and should not be sent to the LLM provider
|
||||
client_metadata = {
|
||||
k: v for k, v in (getattr(session, "metadata", None) or {}).items() if k not in AG_UI_INTERNAL_METADATA_KEYS
|
||||
session_metadata = cast(dict[str, Any], getattr(session, "metadata", None) or {})
|
||||
client_metadata: dict[str, Any] = {
|
||||
k: v for k, v in session_metadata.items() if k not in AG_UI_INTERNAL_METADATA_KEYS
|
||||
}
|
||||
safe_metadata = _build_safe_metadata(client_metadata) if client_metadata else {}
|
||||
if safe_metadata:
|
||||
@@ -863,19 +911,14 @@ async def run_agent_stream(
|
||||
|
||||
# Inject state context message so the model knows current application state
|
||||
# This is critical for shared state scenarios where the UI state needs to be visible
|
||||
if config.state_schema and flow.current_state:
|
||||
messages = _inject_state_context(messages, flow.current_state, config.state_schema)
|
||||
if state_schema and flow.current_state:
|
||||
messages = _inject_state_context(messages, flow.current_state, state_schema)
|
||||
|
||||
# Stream from agent - emit RunStarted after first update to get service IDs
|
||||
run_started_emitted = False
|
||||
all_updates: list[Any] = [] # Collect for structured output processing
|
||||
response_stream = agent.run(messages, stream=True, **run_kwargs)
|
||||
if isinstance(response_stream, ResponseStream):
|
||||
stream = response_stream
|
||||
else:
|
||||
stream = await cast(Awaitable[ResponseStream[Any, Any]], response_stream)
|
||||
if not isinstance(stream, ResponseStream):
|
||||
raise AgentExecutionException("Chat client did not return a ResponseStream.")
|
||||
stream = await _normalize_response_stream(response_stream)
|
||||
async for update in stream:
|
||||
# Collect updates for structured output processing
|
||||
if response_format is not None:
|
||||
@@ -891,18 +934,18 @@ async def run_agent_stream(
|
||||
# NOW emit RunStarted with proper IDs
|
||||
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
|
||||
# Emit PredictState custom event if configured
|
||||
if config.predict_state_config:
|
||||
if predict_state_config:
|
||||
predict_state_value = [
|
||||
{
|
||||
"state_key": state_key,
|
||||
"tool": cfg["tool"],
|
||||
"tool_argument": cfg["tool_argument"],
|
||||
}
|
||||
for state_key, cfg in config.predict_state_config.items()
|
||||
for state_key, cfg in predict_state_config.items()
|
||||
]
|
||||
yield CustomEvent(name="PredictState", value=predict_state_value)
|
||||
# Emit initial state snapshot only if we have both state_schema and state
|
||||
if config.state_schema and flow.current_state:
|
||||
if state_schema and flow.current_state:
|
||||
yield StateSnapshotEvent(snapshot=flow.current_state)
|
||||
run_started_emitted = True
|
||||
|
||||
@@ -933,17 +976,17 @@ async def run_agent_stream(
|
||||
# If no updates at all, still emit RunStarted
|
||||
if not run_started_emitted:
|
||||
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
|
||||
if config.predict_state_config:
|
||||
if predict_state_config:
|
||||
predict_state_value = [
|
||||
{
|
||||
"state_key": state_key,
|
||||
"tool": cfg["tool"],
|
||||
"tool_argument": cfg["tool_argument"],
|
||||
}
|
||||
for state_key, cfg in config.predict_state_config.items()
|
||||
for state_key, cfg in predict_state_config.items()
|
||||
]
|
||||
yield CustomEvent(name="PredictState", value=predict_state_value)
|
||||
if config.state_schema and flow.current_state:
|
||||
if state_schema and flow.current_state:
|
||||
yield StateSnapshotEvent(snapshot=flow.current_state)
|
||||
|
||||
# Process structured output if response_format is set
|
||||
@@ -951,31 +994,33 @@ async def run_agent_stream(
|
||||
from agent_framework import AgentResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger.info(f"Processing structured output, update count: {len(all_updates)}")
|
||||
final_response = AgentResponse.from_updates(all_updates, output_format_type=response_format)
|
||||
if not (isinstance(response_format, type) and issubclass(response_format, BaseModel)):
|
||||
logger.warning("Skipping structured output parsing: response_format is not a Pydantic model type.")
|
||||
else:
|
||||
logger.info(f"Processing structured output, update count: {len(all_updates)}")
|
||||
final_response = AgentResponse.from_updates(all_updates, output_format_type=response_format)
|
||||
|
||||
if final_response.value and isinstance(final_response.value, BaseModel):
|
||||
response_dict = final_response.value.model_dump(mode="json", exclude_none=True)
|
||||
logger.info(f"Received structured output keys: {list(response_dict.keys())}")
|
||||
if final_response.value and isinstance(final_response.value, BaseModel):
|
||||
response_dict = final_response.value.model_dump(mode="json", exclude_none=True)
|
||||
logger.info(f"Received structured output keys: {list(response_dict.keys())}")
|
||||
|
||||
# Extract state updates - if no state_schema, all non-message fields are state
|
||||
state_keys = (
|
||||
set(config.state_schema.keys()) if config.state_schema else set(response_dict.keys()) - {"message"}
|
||||
)
|
||||
state_updates = {k: v for k, v in response_dict.items() if k in state_keys}
|
||||
# Extract state updates - if no state_schema, all non-message fields are state
|
||||
state_keys = set(state_schema.keys()) if state_schema else set(response_dict.keys()) - {"message"}
|
||||
state_updates = {k: v for k, v in response_dict.items() if k in state_keys}
|
||||
|
||||
if state_updates:
|
||||
flow.current_state.update(state_updates)
|
||||
yield StateSnapshotEvent(snapshot=flow.current_state)
|
||||
logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}")
|
||||
if state_updates:
|
||||
flow.current_state.update(state_updates)
|
||||
yield StateSnapshotEvent(snapshot=flow.current_state)
|
||||
logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}")
|
||||
|
||||
# Emit message field as text if present
|
||||
if "message" in response_dict and response_dict["message"]:
|
||||
message_id = generate_event_id()
|
||||
yield TextMessageStartEvent(message_id=message_id, role="assistant")
|
||||
yield TextMessageContentEvent(message_id=message_id, delta=response_dict["message"])
|
||||
yield TextMessageEndEvent(message_id=message_id)
|
||||
logger.info(f"Emitted conversational message with length={len(response_dict['message'])}")
|
||||
# Emit message field as text if present
|
||||
message_text = response_dict.get("message")
|
||||
if isinstance(message_text, str) and message_text:
|
||||
message_id = generate_event_id()
|
||||
yield TextMessageStartEvent(message_id=message_id, role="assistant")
|
||||
yield TextMessageContentEvent(message_id=message_id, delta=message_text)
|
||||
yield TextMessageEndEvent(message_id=message_id)
|
||||
logger.info(f"Emitted conversational message with length={len(message_text)}")
|
||||
|
||||
# Feature #1: Emit ToolCallEndEvent for declaration-only tools (tools without results)
|
||||
pending_without_end = flow.get_pending_without_end()
|
||||
@@ -989,8 +1034,8 @@ async def run_agent_stream(
|
||||
yield ToolCallEndEvent(tool_call_id=tool_call_id)
|
||||
|
||||
# For predictive tools with require_confirmation, emit confirm_changes
|
||||
if config.require_confirmation and config.predict_state_config and tool_name:
|
||||
is_predictive_tool = any(cfg["tool"] == tool_name for cfg in config.predict_state_config.values())
|
||||
if config.require_confirmation and predict_state_config and tool_name:
|
||||
is_predictive_tool = any(cfg["tool"] == tool_name for cfg in predict_state_config.values())
|
||||
if is_predictive_tool:
|
||||
logger.info(f"Emitting confirm_changes for predictive tool '{tool_name}'")
|
||||
# Extract state value from tool arguments for StateSnapshot
|
||||
@@ -1071,7 +1116,7 @@ async def run_agent_stream(
|
||||
last_call_id = last_result.get("toolCallId")
|
||||
last_tool_name = flow.get_tool_name(last_call_id)
|
||||
if not _should_suppress_intermediate_snapshot(
|
||||
last_tool_name, config.predict_state_config, config.require_confirmation
|
||||
last_tool_name, predict_state_config, config.require_confirmation
|
||||
):
|
||||
yield _build_messages_snapshot(flow, snapshot_messages)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import json
|
||||
|
||||
import pytest
|
||||
from agent_framework import Agent, ChatResponseUpdate, Content
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
from fastapi import FastAPI, Header, HTTPException
|
||||
from fastapi.params import Depends
|
||||
from fastapi.testclient import TestClient
|
||||
@@ -165,6 +166,28 @@ async def test_endpoint_event_streaming(build_chat_client):
|
||||
assert found_run_finished
|
||||
|
||||
|
||||
async def test_endpoint_with_workflow_as_agent_stream_output(build_chat_client):
|
||||
"""Test endpoint handles workflow-as-agent stream outputs."""
|
||||
app = FastAPI()
|
||||
brainstorm_agent = Agent(name="brainstorm", instructions="Brainstorm ideas", client=build_chat_client("Idea"))
|
||||
reviewer_agent = Agent(name="reviewer", instructions="Review ideas", client=build_chat_client("Review"))
|
||||
agent = SequentialBuilder(participants=[brainstorm_agent, reviewer_agent]).build().as_agent()
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent, path="/workflow-like")
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.post("/workflow-like", json={"messages": [{"role": "user", "content": "Hello"}]})
|
||||
|
||||
assert response.status_code == 200
|
||||
content = response.content.decode("utf-8")
|
||||
lines = [line for line in content.split("\n") if line.startswith("data: ")]
|
||||
event_types = [json.loads(line[6:]).get("type") for line in lines]
|
||||
|
||||
assert "RUN_STARTED" in event_types
|
||||
assert "TEXT_MESSAGE_CONTENT" in event_types
|
||||
assert "RUN_FINISHED" in event_types
|
||||
|
||||
|
||||
async def test_endpoint_error_handling(build_chat_client):
|
||||
"""Test endpoint error handling during request parsing."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
|
||||
"""Tests for _run.py helper functions and FlowState."""
|
||||
|
||||
import pytest
|
||||
from ag_ui.core import (
|
||||
TextMessageEndEvent,
|
||||
TextMessageStartEvent,
|
||||
)
|
||||
from agent_framework import Content, Message
|
||||
from agent_framework import AgentResponseUpdate, Content, Message, ResponseStream
|
||||
from agent_framework.exceptions import AgentExecutionException
|
||||
|
||||
from agent_framework_ag_ui._run import (
|
||||
FlowState,
|
||||
@@ -16,6 +18,7 @@ from agent_framework_ag_ui._run import (
|
||||
_emit_tool_result,
|
||||
_has_only_tool_calls,
|
||||
_inject_state_context,
|
||||
_normalize_response_stream,
|
||||
_should_suppress_intermediate_snapshot,
|
||||
)
|
||||
|
||||
@@ -179,6 +182,54 @@ class TestFlowState:
|
||||
assert result[0]["id"] == "call_2"
|
||||
|
||||
|
||||
class TestNormalizeResponseStream:
|
||||
"""Tests for _normalize_response_stream helper."""
|
||||
|
||||
async def test_accepts_response_stream(self):
|
||||
"""Accept standard ResponseStream values."""
|
||||
|
||||
async def _stream():
|
||||
yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")
|
||||
|
||||
stream = await _normalize_response_stream(ResponseStream(_stream()))
|
||||
updates = [update async for update in stream]
|
||||
|
||||
assert len(updates) == 1
|
||||
assert updates[0].contents[0].text == "hello"
|
||||
|
||||
async def test_accepts_async_iterable(self):
|
||||
"""Accept workflow-style async generator streams."""
|
||||
|
||||
async def _stream():
|
||||
yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")
|
||||
|
||||
stream = await _normalize_response_stream(_stream())
|
||||
updates = [update async for update in stream]
|
||||
|
||||
assert len(updates) == 1
|
||||
assert updates[0].contents[0].text == "hello"
|
||||
|
||||
async def test_accepts_awaitable_resolving_to_async_iterable(self):
|
||||
"""Accept awaitables that resolve to async iterable streams."""
|
||||
|
||||
async def _stream():
|
||||
yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")
|
||||
|
||||
async def _resolve():
|
||||
return _stream()
|
||||
|
||||
stream = await _normalize_response_stream(_resolve())
|
||||
updates = [update async for update in stream]
|
||||
|
||||
assert len(updates) == 1
|
||||
assert updates[0].contents[0].text == "hello"
|
||||
|
||||
async def test_rejects_non_stream_values(self):
|
||||
"""Reject unsupported stream return values."""
|
||||
with pytest.raises(AgentExecutionException):
|
||||
await _normalize_response_stream("not-a-stream")
|
||||
|
||||
|
||||
class TestCreateStateContextMessage:
|
||||
"""Tests for _create_state_context_message function."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user