Python: Include reasoning messages in MESSAGES_SNAPSHOT events (#4844)

* Include reasoning messages in MESSAGES_SNAPSHOT (#4843)

FlowState now tracks reasoning messages emitted during a run.
_emit_text_reasoning() persists reasoning (including encrypted_value)
into flow.reasoning_messages, and _build_messages_snapshot() appends
them to the final MESSAGES_SNAPSHOT event.

Changes:
- Add reasoning_messages field to FlowState
- Update _emit_text_reasoning() to accept optional flow parameter
- Include reasoning_messages in _build_messages_snapshot()
- Add 'reasoning' to ALLOWED_AGUI_ROLES so normalize_agui_role()
  preserves the role through snapshot round-trips
- Skip reasoning messages in agui_messages_to_agent_framework() since
  they are UI-only state and should not be forwarded to LLM providers
- Add regression tests for snapshot emission, encrypted value
  preservation, and multi-turn round-trip with reasoning

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Python: Include reasoning messages in MESSAGES_SNAPSHOT events

Fixes #4843

* Fix PR review feedback for reasoning persistence (#4843)

- Accumulate reasoning text per message_id (append deltas) instead of
  storing only the current chunk, matching flow.accumulated_text pattern
- Use camelCase encryptedValue in snapshot JSON to match AG-UI protocol
  conventions (toolCallId, encryptedValue)
- Normalize snake_case encrypted_value to encryptedValue in
  agui_messages_to_snapshot_format for input compatibility
- Update normalize_agui_role docstring to include reasoning role
- Add tests for incremental reasoning accumulation and key normalization

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #4843: Python: agent-framework-ag-ui: include reasoning messages in MESSAGES_SNAPSHOT

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Evan Mattson
2026-03-26 14:56:10 +09:00
committed by GitHub
Unverified
parent dc27740f1a
commit dd3d085539
7 changed files with 303 additions and 5 deletions
@@ -684,6 +684,10 @@ def _build_messages_snapshot(
}
)
# Add reasoning messages so frontends that reconcile state from
# MESSAGES_SNAPSHOT retain reasoning content after streaming ends.
all_messages.extend(flow.reasoning_messages)
return MessagesSnapshotEvent(messages=all_messages) # type: ignore[arg-type]
@@ -1061,7 +1065,9 @@ async def run_agent_stream(
# Emit MessagesSnapshotEvent if we have tool calls, tool results, accumulated text, or reasoning messages
# Feature #5: Suppress intermediate snapshots for predictive tools without confirmation
should_emit_snapshot = flow.pending_tool_calls or flow.tool_results or flow.accumulated_text
should_emit_snapshot = (
flow.pending_tool_calls or flow.tool_results or flow.accumulated_text or flow.reasoning_messages
)
if should_emit_snapshot:
# Check if we should suppress for predictive tool
last_tool_name = None
@@ -604,6 +604,10 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Mes
# Handle standard tool result messages early (role="tool") to preserve provider invariants
# This path maps AGUI tool messages to function_result content with the correct tool_call_id
role_str = normalize_agui_role(msg.get("role", "user"))
if role_str == "reasoning":
# Reasoning messages are UI-only state carried in MESSAGES_SNAPSHOT.
# They should not be forwarded to the LLM provider.
continue
if role_str == "tool":
# Prefer explicit tool_call_id fields; fall back to backend fields only if necessary
tool_call_id = msg.get("tool_call_id") or msg.get("toolCallId")
@@ -1020,6 +1024,11 @@ def agui_messages_to_snapshot_format(messages: list[dict[str, Any]]) -> list[dic
elif "toolCallId" not in normalized_msg:
normalized_msg["toolCallId"] = ""
# Normalize encrypted_value to encryptedValue for reasoning messages
if normalized_msg.get("role") == "reasoning" and "encrypted_value" in normalized_msg:
normalized_msg["encryptedValue"] = normalized_msg["encrypted_value"]
del normalized_msg["encrypted_value"]
result.append(normalized_msg)
return result
@@ -126,6 +126,8 @@ class FlowState:
tool_results: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
tool_calls_ended: set[str] = field(default_factory=set) # pyright: ignore[reportUnknownVariableType]
interrupts: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
reasoning_messages: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
accumulated_reasoning: dict[str, str] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType]
def get_tool_name(self, call_id: str | None) -> str | None:
"""Get tool name by call ID."""
@@ -460,7 +462,7 @@ def _emit_mcp_tool_result(
return _emit_tool_result_common(content.call_id, raw_output, flow, predictive_handler)
def _emit_text_reasoning(content: Content) -> list[BaseEvent]:
def _emit_text_reasoning(content: Content, flow: FlowState | None = None) -> list[BaseEvent]:
"""Emit AG-UI reasoning events for text_reasoning content.
Uses the protocol-defined reasoning event types so that AG-UI consumers
@@ -470,6 +472,10 @@ def _emit_text_reasoning(content: Content) -> list[BaseEvent]:
``content.protected_data`` is present it is emitted as a
``ReasoningEncryptedValueEvent`` so that consumers can persist encrypted
reasoning for state continuity without conflating it with display text.
When *flow* is provided the reasoning message is persisted into
``flow.reasoning_messages`` so that ``_build_messages_snapshot`` can
include it in the final ``MESSAGES_SNAPSHOT``.
"""
text = content.text or ""
if not text and content.protected_data is None:
@@ -498,6 +504,36 @@ def _emit_text_reasoning(content: Content) -> list[BaseEvent]:
events.append(ReasoningEndEvent(message_id=message_id))
# Persist reasoning into flow state for MESSAGES_SNAPSHOT.
# Accumulate reasoning text per message_id, similar to flow.accumulated_text,
# so that incremental deltas build the full reasoning string.
if flow is not None:
if text:
previous_text = flow.accumulated_reasoning.get(message_id, "")
flow.accumulated_reasoning[message_id] = previous_text + text
full_text = flow.accumulated_reasoning.get(message_id, text or "")
# Update existing reasoning entry for this message_id if present; otherwise append a new one.
existing_entry: dict[str, Any] | None = None
for entry in flow.reasoning_messages:
if isinstance(entry, dict) and entry.get("id") == message_id:
existing_entry = entry
break
if existing_entry is None:
reasoning_entry: dict[str, Any] = {
"id": message_id,
"role": "reasoning",
"content": full_text,
}
if content.protected_data is not None:
reasoning_entry["encryptedValue"] = content.protected_data
flow.reasoning_messages.append(reasoning_entry)
else:
existing_entry["content"] = full_text
if content.protected_data is not None:
existing_entry["encryptedValue"] = content.protected_data
return events
@@ -527,6 +563,6 @@ def _emit_content(
if content_type == "mcp_server_tool_result":
return _emit_mcp_tool_result(content, flow, predictive_handler)
if content_type == "text_reasoning":
return _emit_text_reasoning(content)
return _emit_text_reasoning(content, flow)
logger.debug("Skipping unsupported content type in AG-UI emitter: %s", content_type)
return []
@@ -27,7 +27,7 @@ FRAMEWORK_TO_AGUI_ROLE: dict[str, str] = {
"system": "system",
}
ALLOWED_AGUI_ROLES: set[str] = {"user", "assistant", "system", "tool"}
ALLOWED_AGUI_ROLES: set[str] = {"user", "assistant", "system", "tool", "reasoning"}
def generate_event_id() -> str:
@@ -82,7 +82,7 @@ def normalize_agui_role(raw_role: Any) -> str:
raw_role: Raw role value from AG-UI message
Returns:
Normalized role string (user, assistant, system, or tool)
Normalized role string (user, assistant, system, tool, or reasoning)
"""
if not isinstance(raw_role, str):
return "user"
@@ -1669,3 +1669,94 @@ def test_agui_fresh_approval_is_still_processed():
assert len(approval_contents) == 1, "Fresh approval should produce function_approval_response"
assert approval_contents[0].approved is True
assert approval_contents[0].function_call.name == "get_datetime"
class TestReasoningRoundTrip:
    """Checks how ``role="reasoning"`` messages move through the message adapters.

    Inbound conversion is exercised via ``agui_messages_to_agent_framework``
    and outbound normalization via ``agui_messages_to_snapshot_format``.
    """

    def test_reasoning_skipped_on_inbound(self):
        """A reasoning entry in the inbound history is absent from the converted result."""
        history = [
            {"id": "u1", "role": "user", "content": "Hello"},
            {"id": "r1", "role": "reasoning", "content": "Thinking..."},
            {"id": "a1", "role": "assistant", "content": "Hi there"},
        ]

        converted = agui_messages_to_agent_framework(history)

        # Keep the role object itself when it exposes ``.value``; stringify otherwise.
        seen_roles = []
        for message in converted:
            seen_roles.append(message.role if hasattr(message.role, "value") else str(message.role))

        assert "reasoning" not in seen_roles
        assert len(converted) == 2

    def test_reasoning_preserved_in_snapshot_format(self):
        """Snapshot normalization keeps the reasoning role and its content."""
        history = [
            {"id": "u1", "role": "user", "content": "Hello"},
            {"id": "r1", "role": "reasoning", "content": "Thinking about this..."},
            {"id": "a1", "role": "assistant", "content": "Answer"},
        ]

        normalized = agui_messages_to_snapshot_format(history)

        only_reasoning = list(filter(lambda entry: entry.get("role") == "reasoning", normalized))
        assert len(only_reasoning) == 1
        assert only_reasoning[0]["content"] == "Thinking about this..."

    def test_reasoning_with_encrypted_value_in_snapshot_format(self):
        """A camelCase ``encryptedValue`` survives snapshot normalization unchanged."""
        reasoning_entry = {
            "id": "r1",
            "role": "reasoning",
            "content": "visible",
            "encryptedValue": "secret-data",
        }

        normalized = agui_messages_to_snapshot_format([reasoning_entry])

        assert len(normalized) == 1
        assert normalized[0]["role"] == "reasoning"
        assert normalized[0]["encryptedValue"] == "secret-data"

    def test_reasoning_encrypted_value_snake_case_normalized(self):
        """A snake_case ``encrypted_value`` key is rewritten to ``encryptedValue``."""
        reasoning_entry = {
            "id": "r1",
            "role": "reasoning",
            "content": "visible",
            "encrypted_value": "snake-case-data",
        }

        normalized = agui_messages_to_snapshot_format([reasoning_entry])

        assert len(normalized) == 1
        assert normalized[0]["encryptedValue"] == "snake-case-data"
        # The snake_case key must be gone, not merely duplicated.
        assert "encrypted_value" not in normalized[0]

    def test_multi_turn_with_reasoning_in_prior_snapshot(self):
        """A follow-up turn carrying prior reasoning converts cleanly and in order."""
        history = [
            {"id": "u1", "role": "user", "content": "First question"},
            {"id": "r1", "role": "reasoning", "content": "Prior reasoning"},
            {"id": "a1", "role": "assistant", "content": "First answer"},
            {"id": "u2", "role": "user", "content": "Follow-up question"},
        ]

        converted = agui_messages_to_agent_framework(history)

        seen_roles = []
        for message in converted:
            seen_roles.append(message.role if hasattr(message.role, "value") else str(message.role))

        # The reasoning entry is dropped; the remaining messages keep their order.
        assert seen_roles == ["user", "assistant", "user"]

        # Gather every non-empty text payload to confirm nothing leaked or was lost.
        collected_texts = [
            part.text
            for message in converted
            for part in message.contents or []
            if hasattr(part, "text") and part.text
        ]
        assert "First question" in collected_texts
        assert "First answer" in collected_texts
        assert "Follow-up question" in collected_texts
        assert "Prior reasoning" not in collected_texts
@@ -1346,3 +1346,158 @@ class TestEmitContentMcpRouting:
assert len(events) == 5
assert isinstance(events[0], ReasoningStartEvent)
class TestReasoningInSnapshot:
    """Checks that reasoning emitted during a run ends up in ``MESSAGES_SNAPSHOT``.

    The tests drive ``_emit_text_reasoning`` / ``_emit_content`` against a
    ``FlowState`` and inspect ``flow.reasoning_messages`` and the event built
    by ``_build_messages_snapshot``.
    """

    @staticmethod
    def _role_of(entry):
        """Return the role of a snapshot entry, whether it is a dict or an object."""
        if isinstance(entry, dict):
            return entry.get("role")
        return getattr(entry, "role", None)

    def test_reasoning_persisted_to_flow_state(self):
        """Passing a flow to ``_emit_text_reasoning`` records one reasoning entry."""
        state = FlowState()
        reasoning = Content.from_text_reasoning(
            id="reason_persist",
            text="Let me think step by step.",
        )

        _emit_text_reasoning(reasoning, state)

        assert len(state.reasoning_messages) == 1
        stored = state.reasoning_messages[0]
        assert stored["id"] == "reason_persist"
        assert stored["role"] == "reasoning"
        assert stored["content"] == "Let me think step by step."
        # No protected data was supplied, so no encrypted key should be present.
        assert "encryptedValue" not in stored

    def test_reasoning_with_encrypted_value_persisted(self):
        """``protected_data`` is stored on the entry under ``encryptedValue``."""
        state = FlowState()
        reasoning = Content.from_text_reasoning(
            id="reason_enc",
            text="visible reasoning",
            protected_data="encrypted-data-123",
        )

        _emit_text_reasoning(reasoning, state)

        assert len(state.reasoning_messages) == 1
        assert state.reasoning_messages[0]["encryptedValue"] == "encrypted-data-123"

    def test_snapshot_includes_reasoning(self):
        """A reasoning entry placed on the flow shows up in the built snapshot."""
        from agent_framework_ag_ui._agent_run import _build_messages_snapshot

        state = FlowState()
        state.accumulated_text = "Here is my answer."
        state.reasoning_messages = [
            {"id": "r1", "role": "reasoning", "content": "Thinking..."},
        ]

        event = _build_messages_snapshot(state, [])

        snapshot_roles = [self._role_of(entry) for entry in event.messages]
        assert "reasoning" in snapshot_roles

    def test_snapshot_preserves_reasoning_encrypted_value(self):
        """Reasoning emitted with ``protected_data`` reaches the snapshot intact."""
        from agent_framework_ag_ui._agent_run import _build_messages_snapshot

        state = FlowState()
        reasoning = Content.from_text_reasoning(
            id="reason_e2e",
            text="visible",
            protected_data="secret-data",
        )
        _emit_text_reasoning(reasoning, state)

        answer = Content.from_text("Final answer.")
        _emit_text(answer, state)

        event = _build_messages_snapshot(state, [])

        only_reasoning = [entry for entry in event.messages if self._role_of(entry) == "reasoning"]
        assert len(only_reasoning) == 1

        entry = only_reasoning[0]
        # NOTE(review): the content/encryptedValue checks below only run when the
        # snapshot entry is a plain dict; confirm whether non-dict entries need
        # equivalent coverage.
        if isinstance(entry, dict):
            assert entry["content"] == "visible"
            assert entry["encryptedValue"] == "secret-data"

    def test_emit_content_routes_reasoning_with_flow(self):
        """``_emit_content`` hands the flow through so reasoning is persisted."""
        state = FlowState()
        reasoning = Content.from_text_reasoning(text="routed reasoning")

        _emit_content(reasoning, state)

        assert len(state.reasoning_messages) == 1
        assert state.reasoning_messages[0]["content"] == "routed reasoning"

    def test_reasoning_without_flow_does_not_error(self):
        """Omitting the flow argument still yields the reasoning event sequence."""
        reasoning = Content.from_text_reasoning(text="no flow")

        emitted = _emit_text_reasoning(reasoning)

        assert len(emitted) == 5
        assert isinstance(emitted[0], ReasoningStartEvent)

    def test_snapshot_reasoning_ordering(self):
        """Reasoning entries are placed after the assistant text in the snapshot."""
        from agent_framework_ag_ui._agent_run import _build_messages_snapshot

        state = FlowState()
        _emit_text_reasoning(Content.from_text_reasoning(id="r1", text="Thinking..."), state)
        _emit_text(Content.from_text("Answer"), state)

        prior_messages = [{"id": "u1", "role": "user", "content": "Hi"}]
        event = _build_messages_snapshot(state, prior_messages)

        # Expected order: user -> assistant text -> reasoning
        assert len(event.messages) == 3
        snapshot_roles = [self._role_of(entry) for entry in event.messages]
        assert snapshot_roles == ["user", "assistant", "reasoning"]

    def test_reasoning_accumulates_incremental_deltas(self):
        """Deltas sharing one id are concatenated into a single stored entry."""
        state = FlowState()
        first = Content.from_text_reasoning(id="reason_inc", text="First ")
        second = Content.from_text_reasoning(id="reason_inc", text="second ")
        third = Content.from_text_reasoning(id="reason_inc", text="third.")

        _emit_text_reasoning(first, state)
        _emit_text_reasoning(second, state)
        _emit_text_reasoning(third, state)

        assert len(state.reasoning_messages) == 1
        stored = state.reasoning_messages[0]
        assert stored["id"] == "reason_inc"
        assert stored["content"] == "First second third."

    def test_reasoning_accumulates_distinct_message_ids(self):
        """Reasoning emitted under different ids produces separate entries."""
        state = FlowState()
        alpha = Content.from_text_reasoning(id="a", text="alpha")
        beta = Content.from_text_reasoning(id="b", text="beta")

        _emit_text_reasoning(alpha, state)
        _emit_text_reasoning(beta, state)

        assert len(state.reasoning_messages) == 2
        assert state.reasoning_messages[0]["content"] == "alpha"
        assert state.reasoning_messages[1]["content"] == "beta"

    def test_reasoning_encrypted_value_updated_on_later_delta(self):
        """``encryptedValue`` is recorded even if only a later delta carries it."""
        state = FlowState()
        plain_delta = Content.from_text_reasoning(id="enc_late", text="part1 ")
        protected_delta = Content.from_text_reasoning(id="enc_late", text="part2", protected_data="encrypted-payload")

        _emit_text_reasoning(plain_delta, state)
        _emit_text_reasoning(protected_delta, state)

        assert len(state.reasoning_messages) == 1
        stored = state.reasoning_messages[0]
        assert stored["content"] == "part1 part2"
        assert stored["encryptedValue"] == "encrypted-payload"
@@ -450,6 +450,7 @@ def test_normalize_agui_role_valid():
assert normalize_agui_role("assistant") == "assistant"
assert normalize_agui_role("system") == "system"
assert normalize_agui_role("tool") == "tool"
assert normalize_agui_role("reasoning") == "reasoning"
def test_normalize_agui_role_invalid():