mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Include reasoning messages in MESSAGES_SNAPSHOT events (#4844)
* Include reasoning messages in MESSAGES_SNAPSHOT (#4843) FlowState now tracks reasoning messages emitted during a run. _emit_text_reasoning() persists reasoning (including encrypted_value) into flow.reasoning_messages, and _build_messages_snapshot() appends them to the final MESSAGES_SNAPSHOT event. Changes: - Add reasoning_messages field to FlowState - Update _emit_text_reasoning() to accept optional flow parameter - Include reasoning_messages in _build_messages_snapshot() - Add 'reasoning' to ALLOWED_AGUI_ROLES so normalize_agui_role() preserves the role through snapshot round-trips - Skip reasoning messages in agui_messages_to_agent_framework() since they are UI-only state and should not be forwarded to LLM providers - Add regression tests for snapshot emission, encrypted value preservation, and multi-turn round-trip with reasoning Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: Include reasoning messages in MESSAGES_SNAPSHOT events Fixes #4843 * Fix PR review feedback for reasoning persistence (#4843) - Accumulate reasoning text per message_id (append deltas) instead of storing only the current chunk, matching flow.accumulated_text pattern - Use camelCase encryptedValue in snapshot JSON to match AG-UI protocol conventions (toolCallId, encryptedValue) - Normalize snake_case encrypted_value to encryptedValue in agui_messages_to_snapshot_format for input compatibility - Update normalize_agui_role docstring to include reasoning role - Add tests for incremental reasoning accumulation and key normalization Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #4843: Python: agent-framework-ag-ui: include reasoning messages in MESSAGES_SNAPSHOT --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
dc27740f1a
commit
dd3d085539
@@ -684,6 +684,10 @@ def _build_messages_snapshot(
|
||||
}
|
||||
)
|
||||
|
||||
# Add reasoning messages so frontends that reconcile state from
|
||||
# MESSAGES_SNAPSHOT retain reasoning content after streaming ends.
|
||||
all_messages.extend(flow.reasoning_messages)
|
||||
|
||||
return MessagesSnapshotEvent(messages=all_messages) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@@ -1061,7 +1065,9 @@ async def run_agent_stream(
|
||||
|
||||
# Emit MessagesSnapshotEvent if we have tool calls or results
|
||||
# Feature #5: Suppress intermediate snapshots for predictive tools without confirmation
|
||||
should_emit_snapshot = flow.pending_tool_calls or flow.tool_results or flow.accumulated_text
|
||||
should_emit_snapshot = (
|
||||
flow.pending_tool_calls or flow.tool_results or flow.accumulated_text or flow.reasoning_messages
|
||||
)
|
||||
if should_emit_snapshot:
|
||||
# Check if we should suppress for predictive tool
|
||||
last_tool_name = None
|
||||
|
||||
@@ -604,6 +604,10 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Mes
|
||||
# Handle standard tool result messages early (role="tool") to preserve provider invariants
|
||||
# This path maps AG‑UI tool messages to function_result content with the correct tool_call_id
|
||||
role_str = normalize_agui_role(msg.get("role", "user"))
|
||||
if role_str == "reasoning":
|
||||
# Reasoning messages are UI-only state carried in MESSAGES_SNAPSHOT.
|
||||
# They should not be forwarded to the LLM provider.
|
||||
continue
|
||||
if role_str == "tool":
|
||||
# Prefer explicit tool_call_id fields; fall back to backend fields only if necessary
|
||||
tool_call_id = msg.get("tool_call_id") or msg.get("toolCallId")
|
||||
@@ -1020,6 +1024,11 @@ def agui_messages_to_snapshot_format(messages: list[dict[str, Any]]) -> list[dic
|
||||
elif "toolCallId" not in normalized_msg:
|
||||
normalized_msg["toolCallId"] = ""
|
||||
|
||||
# Normalize encrypted_value to encryptedValue for reasoning messages
|
||||
if normalized_msg.get("role") == "reasoning" and "encrypted_value" in normalized_msg:
|
||||
normalized_msg["encryptedValue"] = normalized_msg["encrypted_value"]
|
||||
del normalized_msg["encrypted_value"]
|
||||
|
||||
result.append(normalized_msg)
|
||||
|
||||
return result
|
||||
|
||||
@@ -126,6 +126,8 @@ class FlowState:
|
||||
tool_results: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
|
||||
tool_calls_ended: set[str] = field(default_factory=set) # pyright: ignore[reportUnknownVariableType]
|
||||
interrupts: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
|
||||
reasoning_messages: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
|
||||
accumulated_reasoning: dict[str, str] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType]
|
||||
|
||||
def get_tool_name(self, call_id: str | None) -> str | None:
|
||||
"""Get tool name by call ID."""
|
||||
@@ -460,7 +462,7 @@ def _emit_mcp_tool_result(
|
||||
return _emit_tool_result_common(content.call_id, raw_output, flow, predictive_handler)
|
||||
|
||||
|
||||
def _emit_text_reasoning(content: Content) -> list[BaseEvent]:
|
||||
def _emit_text_reasoning(content: Content, flow: FlowState | None = None) -> list[BaseEvent]:
|
||||
"""Emit AG-UI reasoning events for text_reasoning content.
|
||||
|
||||
Uses the protocol-defined reasoning event types so that AG-UI consumers
|
||||
@@ -470,6 +472,10 @@ def _emit_text_reasoning(content: Content) -> list[BaseEvent]:
|
||||
``content.protected_data`` is present it is emitted as a
|
||||
``ReasoningEncryptedValueEvent`` so that consumers can persist encrypted
|
||||
reasoning for state continuity without conflating it with display text.
|
||||
|
||||
When *flow* is provided the reasoning message is persisted into
|
||||
``flow.reasoning_messages`` so that ``_build_messages_snapshot`` can
|
||||
include it in the final ``MESSAGES_SNAPSHOT``.
|
||||
"""
|
||||
text = content.text or ""
|
||||
if not text and content.protected_data is None:
|
||||
@@ -498,6 +504,36 @@ def _emit_text_reasoning(content: Content) -> list[BaseEvent]:
|
||||
|
||||
events.append(ReasoningEndEvent(message_id=message_id))
|
||||
|
||||
# Persist reasoning into flow state for MESSAGES_SNAPSHOT.
|
||||
# Accumulate reasoning text per message_id, similar to flow.accumulated_text,
|
||||
# so that incremental deltas build the full reasoning string.
|
||||
if flow is not None:
|
||||
if text:
|
||||
previous_text = flow.accumulated_reasoning.get(message_id, "")
|
||||
flow.accumulated_reasoning[message_id] = previous_text + text
|
||||
full_text = flow.accumulated_reasoning.get(message_id, text or "")
|
||||
|
||||
# Update existing reasoning entry for this message_id if present; otherwise append a new one.
|
||||
existing_entry: dict[str, Any] | None = None
|
||||
for entry in flow.reasoning_messages:
|
||||
if isinstance(entry, dict) and entry.get("id") == message_id:
|
||||
existing_entry = entry
|
||||
break
|
||||
|
||||
if existing_entry is None:
|
||||
reasoning_entry: dict[str, Any] = {
|
||||
"id": message_id,
|
||||
"role": "reasoning",
|
||||
"content": full_text,
|
||||
}
|
||||
if content.protected_data is not None:
|
||||
reasoning_entry["encryptedValue"] = content.protected_data
|
||||
flow.reasoning_messages.append(reasoning_entry)
|
||||
else:
|
||||
existing_entry["content"] = full_text
|
||||
if content.protected_data is not None:
|
||||
existing_entry["encryptedValue"] = content.protected_data
|
||||
|
||||
return events
|
||||
|
||||
|
||||
@@ -527,6 +563,6 @@ def _emit_content(
|
||||
if content_type == "mcp_server_tool_result":
|
||||
return _emit_mcp_tool_result(content, flow, predictive_handler)
|
||||
if content_type == "text_reasoning":
|
||||
return _emit_text_reasoning(content)
|
||||
return _emit_text_reasoning(content, flow)
|
||||
logger.debug("Skipping unsupported content type in AG-UI emitter: %s", content_type)
|
||||
return []
|
||||
|
||||
@@ -27,7 +27,7 @@ FRAMEWORK_TO_AGUI_ROLE: dict[str, str] = {
|
||||
"system": "system",
|
||||
}
|
||||
|
||||
ALLOWED_AGUI_ROLES: set[str] = {"user", "assistant", "system", "tool"}
|
||||
ALLOWED_AGUI_ROLES: set[str] = {"user", "assistant", "system", "tool", "reasoning"}
|
||||
|
||||
|
||||
def generate_event_id() -> str:
|
||||
@@ -82,7 +82,7 @@ def normalize_agui_role(raw_role: Any) -> str:
|
||||
raw_role: Raw role value from AG-UI message
|
||||
|
||||
Returns:
|
||||
Normalized role string (user, assistant, system, or tool)
|
||||
Normalized role string (user, assistant, system, tool, or reasoning)
|
||||
"""
|
||||
if not isinstance(raw_role, str):
|
||||
return "user"
|
||||
|
||||
@@ -1669,3 +1669,94 @@ def test_agui_fresh_approval_is_still_processed():
|
||||
assert len(approval_contents) == 1, "Fresh approval should produce function_approval_response"
|
||||
assert approval_contents[0].approved is True
|
||||
assert approval_contents[0].function_call.name == "get_datetime"
|
||||
|
||||
|
||||
class TestReasoningRoundTrip:
|
||||
"""Tests for reasoning message handling in inbound/outbound adapters."""
|
||||
|
||||
def test_reasoning_skipped_on_inbound(self):
|
||||
"""Reasoning messages from prior snapshot are not forwarded to the LLM."""
|
||||
messages_input = [
|
||||
{"id": "u1", "role": "user", "content": "Hello"},
|
||||
{"id": "r1", "role": "reasoning", "content": "Thinking..."},
|
||||
{"id": "a1", "role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
|
||||
result = agui_messages_to_agent_framework(messages_input)
|
||||
|
||||
roles = [m.role if hasattr(m.role, "value") else str(m.role) for m in result]
|
||||
assert "reasoning" not in roles
|
||||
assert len(result) == 2
|
||||
|
||||
def test_reasoning_preserved_in_snapshot_format(self):
|
||||
"""Reasoning messages retain their role through snapshot normalization."""
|
||||
messages_input = [
|
||||
{"id": "u1", "role": "user", "content": "Hello"},
|
||||
{"id": "r1", "role": "reasoning", "content": "Thinking about this..."},
|
||||
{"id": "a1", "role": "assistant", "content": "Answer"},
|
||||
]
|
||||
|
||||
result = agui_messages_to_snapshot_format(messages_input)
|
||||
|
||||
reasoning_msgs = [m for m in result if m.get("role") == "reasoning"]
|
||||
assert len(reasoning_msgs) == 1
|
||||
assert reasoning_msgs[0]["content"] == "Thinking about this..."
|
||||
|
||||
def test_reasoning_with_encrypted_value_in_snapshot_format(self):
|
||||
"""Reasoning with encryptedValue passes through snapshot normalization."""
|
||||
messages_input = [
|
||||
{
|
||||
"id": "r1",
|
||||
"role": "reasoning",
|
||||
"content": "visible",
|
||||
"encryptedValue": "secret-data",
|
||||
},
|
||||
]
|
||||
|
||||
result = agui_messages_to_snapshot_format(messages_input)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "reasoning"
|
||||
assert result[0]["encryptedValue"] == "secret-data"
|
||||
|
||||
def test_reasoning_encrypted_value_snake_case_normalized(self):
|
||||
"""Snake-case encrypted_value is normalized to encryptedValue in snapshot format."""
|
||||
messages_input = [
|
||||
{
|
||||
"id": "r1",
|
||||
"role": "reasoning",
|
||||
"content": "visible",
|
||||
"encrypted_value": "snake-case-data",
|
||||
},
|
||||
]
|
||||
|
||||
result = agui_messages_to_snapshot_format(messages_input)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["encryptedValue"] == "snake-case-data"
|
||||
assert "encrypted_value" not in result[0]
|
||||
|
||||
def test_multi_turn_with_reasoning_in_prior_snapshot(self):
|
||||
"""Second turn with reasoning from prior snapshot does not corrupt messages."""
|
||||
messages_input = [
|
||||
{"id": "u1", "role": "user", "content": "First question"},
|
||||
{"id": "r1", "role": "reasoning", "content": "Prior reasoning"},
|
||||
{"id": "a1", "role": "assistant", "content": "First answer"},
|
||||
{"id": "u2", "role": "user", "content": "Follow-up question"},
|
||||
]
|
||||
|
||||
result = agui_messages_to_agent_framework(messages_input)
|
||||
|
||||
roles = [m.role if hasattr(m.role, "value") else str(m.role) for m in result]
|
||||
# Reasoning is filtered out, other messages preserved in order
|
||||
assert roles == ["user", "assistant", "user"]
|
||||
# Content not corrupted
|
||||
texts = []
|
||||
for m in result:
|
||||
for c in m.contents or []:
|
||||
if hasattr(c, "text") and c.text:
|
||||
texts.append(c.text)
|
||||
assert "First question" in texts
|
||||
assert "First answer" in texts
|
||||
assert "Follow-up question" in texts
|
||||
assert "Prior reasoning" not in texts
|
||||
|
||||
@@ -1346,3 +1346,158 @@ class TestEmitContentMcpRouting:
|
||||
|
||||
assert len(events) == 5
|
||||
assert isinstance(events[0], ReasoningStartEvent)
|
||||
|
||||
|
||||
class TestReasoningInSnapshot:
|
||||
"""Tests for reasoning message inclusion in MESSAGES_SNAPSHOT."""
|
||||
|
||||
def test_reasoning_persisted_to_flow_state(self):
|
||||
"""_emit_text_reasoning with flow persists reasoning into flow.reasoning_messages."""
|
||||
flow = FlowState()
|
||||
content = Content.from_text_reasoning(
|
||||
id="reason_persist",
|
||||
text="Let me think step by step.",
|
||||
)
|
||||
|
||||
_emit_text_reasoning(content, flow)
|
||||
|
||||
assert len(flow.reasoning_messages) == 1
|
||||
assert flow.reasoning_messages[0]["id"] == "reason_persist"
|
||||
assert flow.reasoning_messages[0]["role"] == "reasoning"
|
||||
assert flow.reasoning_messages[0]["content"] == "Let me think step by step."
|
||||
assert "encryptedValue" not in flow.reasoning_messages[0]
|
||||
|
||||
def test_reasoning_with_encrypted_value_persisted(self):
|
||||
"""Reasoning with protected_data preserves encryptedValue in flow state."""
|
||||
flow = FlowState()
|
||||
content = Content.from_text_reasoning(
|
||||
id="reason_enc",
|
||||
text="visible reasoning",
|
||||
protected_data="encrypted-data-123",
|
||||
)
|
||||
|
||||
_emit_text_reasoning(content, flow)
|
||||
|
||||
assert len(flow.reasoning_messages) == 1
|
||||
assert flow.reasoning_messages[0]["encryptedValue"] == "encrypted-data-123"
|
||||
|
||||
def test_snapshot_includes_reasoning(self):
|
||||
"""_build_messages_snapshot includes reasoning messages from flow state."""
|
||||
from agent_framework_ag_ui._agent_run import _build_messages_snapshot
|
||||
|
||||
flow = FlowState()
|
||||
flow.accumulated_text = "Here is my answer."
|
||||
flow.reasoning_messages = [
|
||||
{"id": "r1", "role": "reasoning", "content": "Thinking..."},
|
||||
]
|
||||
|
||||
snapshot = _build_messages_snapshot(flow, [])
|
||||
|
||||
roles = [m.get("role") if isinstance(m, dict) else getattr(m, "role", None) for m in snapshot.messages]
|
||||
assert "reasoning" in roles
|
||||
|
||||
def test_snapshot_preserves_reasoning_encrypted_value(self):
|
||||
"""Snapshot reasoning with encryptedValue is preserved end-to-end."""
|
||||
from agent_framework_ag_ui._agent_run import _build_messages_snapshot
|
||||
|
||||
flow = FlowState()
|
||||
content = Content.from_text_reasoning(
|
||||
id="reason_e2e",
|
||||
text="visible",
|
||||
protected_data="secret-data",
|
||||
)
|
||||
_emit_text_reasoning(content, flow)
|
||||
|
||||
text_content = Content.from_text("Final answer.")
|
||||
_emit_text(text_content, flow)
|
||||
|
||||
snapshot = _build_messages_snapshot(flow, [])
|
||||
|
||||
reasoning_msgs = [
|
||||
m
|
||||
for m in snapshot.messages
|
||||
if (m.get("role") if isinstance(m, dict) else getattr(m, "role", None)) == "reasoning"
|
||||
]
|
||||
assert len(reasoning_msgs) == 1
|
||||
msg = reasoning_msgs[0]
|
||||
if isinstance(msg, dict):
|
||||
assert msg["content"] == "visible"
|
||||
assert msg["encryptedValue"] == "secret-data"
|
||||
|
||||
def test_emit_content_routes_reasoning_with_flow(self):
|
||||
"""_emit_content passes flow to _emit_text_reasoning for persistence."""
|
||||
flow = FlowState()
|
||||
content = Content.from_text_reasoning(text="routed reasoning")
|
||||
|
||||
_emit_content(content, flow)
|
||||
|
||||
assert len(flow.reasoning_messages) == 1
|
||||
assert flow.reasoning_messages[0]["content"] == "routed reasoning"
|
||||
|
||||
def test_reasoning_without_flow_does_not_error(self):
|
||||
"""Calling _emit_text_reasoning without flow still works (backward compat)."""
|
||||
content = Content.from_text_reasoning(text="no flow")
|
||||
|
||||
events = _emit_text_reasoning(content)
|
||||
|
||||
assert len(events) == 5
|
||||
assert isinstance(events[0], ReasoningStartEvent)
|
||||
|
||||
def test_snapshot_reasoning_ordering(self):
|
||||
"""Reasoning messages appear after assistant text in snapshot."""
|
||||
from agent_framework_ag_ui._agent_run import _build_messages_snapshot
|
||||
|
||||
flow = FlowState()
|
||||
reasoning_content = Content.from_text_reasoning(id="r1", text="Thinking...")
|
||||
_emit_text_reasoning(reasoning_content, flow)
|
||||
|
||||
text_content = Content.from_text("Answer")
|
||||
_emit_text(text_content, flow)
|
||||
|
||||
snapshot = _build_messages_snapshot(flow, [{"id": "u1", "role": "user", "content": "Hi"}])
|
||||
|
||||
# user -> assistant text -> reasoning
|
||||
assert len(snapshot.messages) == 3
|
||||
roles = [m.get("role") if isinstance(m, dict) else getattr(m, "role", None) for m in snapshot.messages]
|
||||
assert roles == ["user", "assistant", "reasoning"]
|
||||
|
||||
def test_reasoning_accumulates_incremental_deltas(self):
|
||||
"""Multiple reasoning deltas with the same id accumulate into one entry."""
|
||||
flow = FlowState()
|
||||
content1 = Content.from_text_reasoning(id="reason_inc", text="First ")
|
||||
content2 = Content.from_text_reasoning(id="reason_inc", text="second ")
|
||||
content3 = Content.from_text_reasoning(id="reason_inc", text="third.")
|
||||
|
||||
_emit_text_reasoning(content1, flow)
|
||||
_emit_text_reasoning(content2, flow)
|
||||
_emit_text_reasoning(content3, flow)
|
||||
|
||||
assert len(flow.reasoning_messages) == 1
|
||||
assert flow.reasoning_messages[0]["id"] == "reason_inc"
|
||||
assert flow.reasoning_messages[0]["content"] == "First second third."
|
||||
|
||||
def test_reasoning_accumulates_distinct_message_ids(self):
|
||||
"""Reasoning entries with different ids are stored separately."""
|
||||
flow = FlowState()
|
||||
content_a = Content.from_text_reasoning(id="a", text="alpha")
|
||||
content_b = Content.from_text_reasoning(id="b", text="beta")
|
||||
|
||||
_emit_text_reasoning(content_a, flow)
|
||||
_emit_text_reasoning(content_b, flow)
|
||||
|
||||
assert len(flow.reasoning_messages) == 2
|
||||
assert flow.reasoning_messages[0]["content"] == "alpha"
|
||||
assert flow.reasoning_messages[1]["content"] == "beta"
|
||||
|
||||
def test_reasoning_encrypted_value_updated_on_later_delta(self):
|
||||
"""encryptedValue is set even when it arrives with a later delta."""
|
||||
flow = FlowState()
|
||||
content1 = Content.from_text_reasoning(id="enc_late", text="part1 ")
|
||||
content2 = Content.from_text_reasoning(id="enc_late", text="part2", protected_data="encrypted-payload")
|
||||
|
||||
_emit_text_reasoning(content1, flow)
|
||||
_emit_text_reasoning(content2, flow)
|
||||
|
||||
assert len(flow.reasoning_messages) == 1
|
||||
assert flow.reasoning_messages[0]["content"] == "part1 part2"
|
||||
assert flow.reasoning_messages[0]["encryptedValue"] == "encrypted-payload"
|
||||
|
||||
@@ -450,6 +450,7 @@ def test_normalize_agui_role_valid():
|
||||
assert normalize_agui_role("assistant") == "assistant"
|
||||
assert normalize_agui_role("system") == "system"
|
||||
assert normalize_agui_role("tool") == "tool"
|
||||
assert normalize_agui_role("reasoning") == "reasoning"
|
||||
|
||||
|
||||
def test_normalize_agui_role_invalid():
|
||||
|
||||
Reference in New Issue
Block a user