Python: AG-UI deterministic state updates from tool results (#5201)

* AG-UI deterministic state updates from tool results

* fix(ag-ui): address PR #5201 review comments

1. Add missing AGUIEventConverter, AGUIHttpService, __version__ to
   _IMPORTS in core ag_ui lazy-export list to match the .pyi stub.

2. Coalesce predictive and deterministic state snapshots into a single
   StateSnapshotEvent when both mechanisms are active on the same tool
   result, reducing redundant snapshot traffic.

3. Update state_update() docstring to clarify that a predictive snapshot
   may be emitted before the deterministic one when predict_state_config
   is active.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Evan Mattson
2026-04-14 13:58:09 +09:00
committed by GitHub
Unverified
parent 3c31ac28b5
commit f183f888a3
11 changed files with 878 additions and 4 deletions
@@ -9,6 +9,7 @@ from ._client import AGUIChatClient
from ._endpoint import add_agent_framework_fastapi_endpoint
from ._event_converters import AGUIEventConverter
from ._http_service import AGUIHttpService
from ._state import state_update
from ._types import AgentState, AGUIChatOptions, AGUIRequest, PredictStateConfig, RunMetadata
from ._workflow import AgentFrameworkWorkflow, WorkflowFactory
@@ -34,5 +35,6 @@ __all__ = [
"PredictStateConfig",
"RunMetadata",
"DEFAULT_TAGS",
"state_update",
"__version__",
]
@@ -6,6 +6,7 @@ from __future__ import annotations
import json
import logging
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any, cast
@@ -31,6 +32,7 @@ from ag_ui.core import (
from agent_framework import Content
from ._orchestration._predictive_state import PredictiveStateHandler
from ._state import TOOL_RESULT_STATE_KEY
from ._utils import generate_event_id, make_json_safe
logger = logging.getLogger(__name__)
@@ -233,16 +235,66 @@ def _emit_tool_call(
return events
def _extract_tool_result_state(content: Content) -> dict[str, Any] | None:
"""Extract a deterministic AG-UI state update from a tool-result ``Content``.
Tools using :func:`agent_framework_ag_ui.state_update` carry the state
payload in ``additional_properties[TOOL_RESULT_STATE_KEY]`` on the inner
text item produced by ``parse_result``. We also check the outer
function_result content's ``additional_properties`` for robustness.
If multiple items carry state, they are merged in order so later items
override earlier ones (plain ``dict.update`` semantics).
Returns:
The merged state dict to apply, or ``None`` if no state update is
present.
"""
merged: dict[str, Any] | None = None
outer_ap = getattr(content, "additional_properties", None) or {}
outer_state = outer_ap.get(TOOL_RESULT_STATE_KEY)
if isinstance(outer_state, dict):
merged = dict(outer_state)
for item in content.items or ():
item_ap = getattr(item, "additional_properties", None) or {}
item_state = item_ap.get(TOOL_RESULT_STATE_KEY)
if isinstance(item_state, dict):
if merged is None:
merged = dict(item_state)
else:
merged.update(item_state)
return merged
def _emit_tool_result_common(
call_id: str,
raw_result: Any,
flow: FlowState,
predictive_handler: PredictiveStateHandler | None = None,
*,
state_update: Mapping[str, Any] | None = None,
) -> list[BaseEvent]:
"""Shared helper for emitting ToolCallEnd + ToolCallResult events and performing FlowState cleanup.
Both ``_emit_tool_result`` (standard function results) and ``_emit_mcp_tool_result``
(MCP server tool results) delegate to this function.
Args:
call_id: Tool call identifier.
raw_result: The stringified tool result content sent back to the LLM.
flow: Current ``FlowState``.
predictive_handler: Optional predictive state handler driven by
``predict_state_config``.
state_update: Optional deterministic state snapshot produced by a tool
returning :func:`agent_framework_ag_ui.state_update`. When present,
it is merged into ``flow.current_state`` and a ``StateSnapshotEvent``
is emitted after the ``ToolCallResult`` event. When both
``predictive_handler`` and ``state_update`` are active, predictive
updates are applied first, then the deterministic merge, and a
single coalesced ``StateSnapshotEvent`` is emitted.
"""
events: list[BaseEvent] = []
@@ -271,8 +323,18 @@ def _emit_tool_result_common(
if predictive_handler:
predictive_handler.apply_pending_updates()
if flow.current_state:
events.append(StateSnapshotEvent(snapshot=flow.current_state))
if state_update:
flow.current_state.update(state_update)
logger.debug(
"Emitted deterministic tool-result StateSnapshotEvent for call_id=%s (keys=%s)",
call_id,
list(state_update.keys()),
)
# Emit a single coalesced snapshot when either mechanism updated state.
if (predictive_handler or state_update) and flow.current_state:
events.append(StateSnapshotEvent(snapshot=flow.current_state))
flow.tool_call_id = None
flow.tool_call_name = None
@@ -295,7 +357,14 @@ def _emit_tool_result(
if not content.call_id:
return []
raw_result = content.result if content.result is not None else ""
return _emit_tool_result_common(content.call_id, raw_result, flow, predictive_handler)
state_update = _extract_tool_result_state(content)
return _emit_tool_result_common(
content.call_id,
raw_result,
flow,
predictive_handler,
state_update=state_update,
)
def _emit_approval_request(
@@ -460,7 +529,14 @@ def _emit_mcp_tool_result(
logger.warning("MCP tool result content missing call_id, skipping")
return []
raw_output = content.output if content.output is not None else ""
return _emit_tool_result_common(content.call_id, raw_output, flow, predictive_handler)
state_update = _extract_tool_result_state(content)
return _emit_tool_result_common(
content.call_id,
raw_output,
flow,
predictive_handler,
state_update=state_update,
)
def _close_reasoning_block(flow: FlowState) -> list[BaseEvent]:
@@ -0,0 +1,84 @@
# Copyright (c) Microsoft. All rights reserved.
"""Deterministic tool-driven AG-UI state updates.
Tools wired into the :mod:`agent_framework_ag_ui` endpoint can push a
deterministic state update by returning :func:`state_update`. Unlike
``predict_state_config`` — which emits ``StateDeltaEvent``s optimistically from
LLM-predicted tool call arguments — ``state_update`` runs *after* the tool
executes, so the AG-UI state always reflects the tool's actual return value.
See issue https://github.com/microsoft/agent-framework/issues/3167 for the
motivating discussion.
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from agent_framework import Content
__all__ = ["TOOL_RESULT_STATE_KEY", "state_update"]
TOOL_RESULT_STATE_KEY = "__ag_ui_tool_result_state__"
"""Reserved ``Content.additional_properties`` key used to carry a tool-driven
state snapshot from a tool return value through to the AG-UI emitter."""
def state_update(
text: str = "",
*,
state: Mapping[str, Any],
) -> Content:
"""Build a tool return value that deterministically updates AG-UI shared state.
Return the result of this helper from an agent tool to push a state update
to AG-UI clients using the actual tool output, rather than LLM-predicted
tool arguments.
When the AG-UI endpoint emits the tool result, it will:
* Forward ``text`` to the LLM as the normal ``function_result`` content.
* Merge ``state`` into ``FlowState.current_state``.
* Emit a deterministic ``StateSnapshotEvent`` after the ``ToolCallResult``
event so frontends observe the updated state deterministically. If
predictive state is enabled, a predictive snapshot may be emitted first.
Example:
.. code-block:: python
from agent_framework import tool
from agent_framework_ag_ui import state_update
@tool
async def get_weather(city: str) -> Content:
data = await _fetch_weather(city)
return state_update(
text=f"Weather in {city}: {data['temp']}°C {data['conditions']}",
state={"weather": {"city": city, **data}},
)
Args:
text: Text passed back to the LLM as the ``function_result`` content.
Defaults to an empty string for tools whose only output is a state
update.
state: A mapping merged into the AG-UI shared state via JSON-compatible
``dict.update`` semantics. Nested dicts are replaced, not deep-merged.
Returns:
A ``Content`` object with ``type="text"``. The state payload rides in
``additional_properties`` under :data:`TOOL_RESULT_STATE_KEY` and is
extracted by the AG-UI emitter.
Raises:
TypeError: If ``state`` is not a ``Mapping``.
"""
if not isinstance(state, Mapping):
raise TypeError(f"state_update() 'state' must be a Mapping, got {type(state).__name__}")
return Content.from_text(
text,
additional_properties={TOOL_RESULT_STATE_KEY: dict(state)},
)
@@ -0,0 +1,92 @@
# Copyright (c) Microsoft. All rights reserved.
"""Deterministic tool-driven AG-UI state example.
This sample demonstrates how a tool can push a *deterministic* state update
to the AG-UI frontend based on its actual return value — in contrast to
``predict_state_config`` which fires optimistically from LLM-predicted tool
call arguments. See issue https://github.com/microsoft/agent-framework/issues/3167.
The :func:`agent_framework_ag_ui.state_update` helper wraps a text result
together with a state snapshot. When a tool returns one of these, the AG-UI
endpoint merges the snapshot into the shared state and emits a
``StateSnapshotEvent`` after the tool result.
"""
from __future__ import annotations
from typing import Any
from agent_framework import Agent, Content, SupportsChatGetResponse, tool
from agent_framework.ag_ui import AgentFrameworkAgent
from agent_framework_ag_ui import state_update
# Simulated weather database — in the issue's motivating example the tool
# would instead call a real weather API.
_WEATHER_DB: dict[str, dict[str, Any]] = {
"seattle": {"temperature": 11, "conditions": "rainy", "humidity": 75},
"san francisco": {"temperature": 14, "conditions": "foggy", "humidity": 85},
"new york city": {"temperature": 18, "conditions": "sunny", "humidity": 60},
"miami": {"temperature": 29, "conditions": "hot and humid", "humidity": 90},
"chicago": {"temperature": 9, "conditions": "windy", "humidity": 65},
}
@tool
async def get_weather(location: str) -> Content:
"""Fetch current weather for a location and push it into AG-UI shared state.
Unlike ``predict_state_config`` — which derives state optimistically from
LLM-predicted tool call arguments — this tool uses ``state_update`` to
forward the *actual* fetched weather to the frontend. The ``text`` goes
back to the LLM as the normal tool result, and the ``state`` dict is merged
into the AG-UI shared state.
Args:
location: City name to look up.
Returns:
A :class:`Content` carrying both the LLM-visible text result and a
deterministic state snapshot.
"""
key = location.lower()
data = _WEATHER_DB.get(
key,
{"temperature": 21, "conditions": "partly cloudy", "humidity": 50},
)
weather_record = {"location": location, **data}
return state_update(
text=(
f"The weather in {location} is {data['conditions']} at "
f"{data['temperature']}°C with {data['humidity']}% humidity."
),
state={"weather": weather_record},
)
def weather_state_agent(client: SupportsChatGetResponse[Any]) -> AgentFrameworkAgent:
"""Create an AG-UI agent with a deterministic tool-driven state tool."""
agent = Agent[Any](
name="weather_state_agent",
instructions=(
"You are a weather assistant. When a user asks about the weather "
"in a city, call the get_weather tool and use its output to give a "
"friendly, concise reply. The tool also updates the shared UI state "
"so the frontend can render a weather card from the `weather` key."
),
client=client,
tools=[get_weather],
)
return AgentFrameworkAgent(
agent=agent,
name="WeatherStateAgent",
description="Weather agent that deterministically updates shared state from tool results.",
state_schema={
"weather": {
"type": "object",
"description": "Last fetched weather record",
},
},
)
@@ -24,6 +24,7 @@ from ..agents.subgraphs_agent import subgraphs_agent
from ..agents.task_steps_agent import task_steps_agent_wrapped
from ..agents.ui_generator_agent import ui_generator_agent
from ..agents.weather_agent import weather_agent
from ..agents.weather_state_agent import weather_state_agent
AnthropicClient: type[Any] | None
try:
@@ -141,6 +142,14 @@ add_agent_framework_fastapi_endpoint(
path="/subgraphs",
)
# Deterministic Tool-Driven State - tool returns state_update() to push snapshot
# from actual tool output (see issue #3167).
add_agent_framework_fastapi_endpoint(
app=app,
agent=weather_state_agent(client),
path="/deterministic_state",
)
def main():
"""Run the server."""
@@ -0,0 +1,267 @@
# Copyright (c) Microsoft. All rights reserved.
"""Golden event-stream tests for the deterministic tool-driven state scenario.
Covers issue https://github.com/microsoft/agent-framework/issues/3167 — a tool
returning :func:`agent_framework_ag_ui.state_update` must push a deterministic
``StateSnapshotEvent`` derived from its actual return value, orthogonal to the
optimistic ``predict_state_config`` path. These golden tests pin the user-visible
event stream so additive changes cannot silently regress it.
"""
from __future__ import annotations
from typing import Any
from agent_framework import AgentResponseUpdate, Content
from conftest import StubAgent
from event_stream import EventStream
from agent_framework_ag_ui import AgentFrameworkAgent, state_update
STATE_SCHEMA = {
"weather": {"type": "object", "description": "Last fetched weather"},
}
def _build_agent(updates: list[AgentResponseUpdate], **kwargs: Any) -> AgentFrameworkAgent:
stub = StubAgent(updates=updates)
kwargs.setdefault("state_schema", STATE_SCHEMA)
return AgentFrameworkAgent(agent=stub, **kwargs)
async def _run(agent: AgentFrameworkAgent, payload: dict[str, Any]) -> EventStream:
return EventStream([event async for event in agent.run(payload)])
PAYLOAD: dict[str, Any] = {
"thread_id": "thread-det-state",
"run_id": "run-det-state",
"messages": [{"role": "user", "content": "What's the weather in SF?"}],
"state": {"weather": {}},
}
def _tool_call(call_id: str, name: str, arguments: str) -> AgentResponseUpdate:
return AgentResponseUpdate(
contents=[Content.from_function_call(name=name, call_id=call_id, arguments=arguments)],
role="assistant",
)
def _tool_result_with_state(call_id: str, text: str, state: dict[str, Any]) -> AgentResponseUpdate:
"""Build a function_result update whose inner item carries a state marker.
This mirrors what the core framework produces when a real ``@tool`` returns
:func:`state_update`: ``parse_result`` keeps the ``Content`` as-is, and
``Content.from_function_result`` preserves its ``additional_properties``
inside ``items``.
"""
return AgentResponseUpdate(
contents=[
Content.from_function_result(
call_id=call_id,
result=[state_update(text=text, state=state)],
)
],
role="assistant",
)
# ── Golden stream tests ──
async def test_deterministic_state_emits_snapshot_after_tool_result() -> None:
"""The happy path: STATE_SNAPSHOT follows TOOL_CALL_RESULT in order."""
updates = [
_tool_call("call-1", "get_weather", '{"city": "SF"}'),
_tool_result_with_state(
"call-1",
text="Weather in SF: 14°C foggy",
state={"weather": {"city": "SF", "temp": 14, "conditions": "foggy"}},
),
AgentResponseUpdate(
contents=[Content.from_text(text="It's 14°C and foggy in SF.")],
role="assistant",
),
]
agent = _build_agent(updates)
stream = await _run(agent, PAYLOAD)
stream.assert_bookends()
stream.assert_no_run_error()
stream.assert_tool_calls_balanced()
stream.assert_text_messages_balanced()
# Ordered subsequence: the deterministic STATE_SNAPSHOT must follow the
# TOOL_CALL_RESULT. This is the central contract for #3167.
stream.assert_ordered_types(
[
"RUN_STARTED",
"TOOL_CALL_START",
"TOOL_CALL_ARGS",
"TOOL_CALL_END",
"TOOL_CALL_RESULT",
"STATE_SNAPSHOT",
"RUN_FINISHED",
]
)
# The final STATE_SNAPSHOT must carry the tool-driven state.
snapshot = stream.snapshot()
assert snapshot["weather"] == {"city": "SF", "temp": 14, "conditions": "foggy"}
async def test_deterministic_state_does_not_fire_for_plain_tool_result() -> None:
"""Regression guard: tools returning plain strings must NOT emit a new STATE_SNAPSHOT.
The initial STATE_SNAPSHOT fires once from the schema + initial payload
state. A plain (non-state_update) tool result must not add another one.
"""
updates = [
_tool_call("call-1", "get_weather", '{"city": "SF"}'),
AgentResponseUpdate(
contents=[Content.from_function_result(call_id="call-1", result="14°C foggy")],
role="assistant",
),
AgentResponseUpdate(
contents=[Content.from_text(text="It's 14°C and foggy.")],
role="assistant",
),
]
agent = _build_agent(updates)
stream = await _run(agent, PAYLOAD)
stream.assert_bookends()
stream.assert_no_run_error()
snapshots = stream.get("STATE_SNAPSHOT")
# Only the initial snapshot (from state_schema + payload state) should exist.
# No deterministic snapshot should have been added by the plain tool result.
assert len(snapshots) == 1, (
f"Expected exactly 1 STATE_SNAPSHOT (initial only) for plain tool result; "
f"got {len(snapshots)}. Snapshots: {[s.snapshot for s in snapshots]}"
)
async def test_deterministic_state_merges_into_initial_state() -> None:
"""The tool-driven snapshot must merge into, not replace, pre-existing state keys."""
payload = dict(PAYLOAD)
payload["state"] = {"weather": {}, "user_preferences": {"unit": "C"}}
updates = [
_tool_call("call-1", "get_weather", '{"city": "SF"}'),
_tool_result_with_state(
"call-1",
text="Weather: 14°C",
state={"weather": {"city": "SF", "temp": 14}},
),
]
agent = _build_agent(updates, state_schema={**STATE_SCHEMA, "user_preferences": {"type": "object"}})
stream = await _run(agent, payload)
stream.assert_bookends()
stream.assert_no_run_error()
final_snapshot = stream.snapshot()
assert final_snapshot["weather"] == {"city": "SF", "temp": 14}
assert final_snapshot["user_preferences"] == {"unit": "C"}, (
"Pre-existing state keys must survive the deterministic merge"
)
async def test_deterministic_state_llm_visible_text_is_clean() -> None:
"""The LLM-visible TOOL_CALL_RESULT content must not leak the state marker key."""
updates = [
_tool_call("call-1", "get_weather", '{"city": "SF"}'),
_tool_result_with_state(
"call-1",
text="Weather in SF: 14°C foggy",
state={"weather": {"city": "SF", "temp": 14}},
),
]
agent = _build_agent(updates)
stream = await _run(agent, PAYLOAD)
result = stream.first("TOOL_CALL_RESULT")
assert result.content == "Weather in SF: 14°C foggy"
# The marker key must never appear in the content sent back to the LLM.
assert "__ag_ui_tool_result_state__" not in result.content
assert "weather" not in result.content # not as a raw state dump
async def test_deterministic_state_multiple_tools_merge_in_order() -> None:
"""Two state-updating tools in one run merge in order; later wins on key collisions."""
updates = [
_tool_call("call-a", "get_weather", '{"city": "SF"}'),
_tool_result_with_state(
"call-a",
text="First result",
state={"weather": {"city": "SF", "temp": 14}, "source": "primary"},
),
_tool_call("call-b", "get_weather_refined", '{"city": "SF"}'),
_tool_result_with_state(
"call-b",
text="Refined result",
state={"source": "refined"},
),
AgentResponseUpdate(
contents=[Content.from_text(text="Here you go.")],
role="assistant",
),
]
agent = _build_agent(
updates,
state_schema={**STATE_SCHEMA, "source": {"type": "string"}},
)
stream = await _run(agent, PAYLOAD)
stream.assert_bookends()
stream.assert_tool_calls_balanced()
stream.assert_no_run_error()
# Two tool-driven snapshots emitted (one per tool) plus the initial snapshot.
snapshots = stream.get("STATE_SNAPSHOT")
assert len(snapshots) >= 2, f"Expected at least 2 STATE_SNAPSHOTs; got {len(snapshots)}"
final = stream.snapshot()
assert final["weather"] == {"city": "SF", "temp": 14}
# Later tool must override earlier tool on the shared key.
assert final["source"] == "refined"
async def test_deterministic_state_coexists_with_predict_state_config() -> None:
"""Predictive state and deterministic state must coexist without clobbering each other."""
predict_config = {
"draft": {
"tool": "write_draft",
"tool_argument": "body",
}
}
updates = [
# Predictive tool: its argument "body" populates state.draft optimistically.
_tool_call("call-1", "write_draft", '{"body": "Hello world"}'),
# Then a deterministic tool result landing a different key.
_tool_result_with_state(
"call-1",
text="Draft saved",
state={"weather": {"city": "SF", "temp": 14}},
),
]
agent = _build_agent(
updates,
state_schema={**STATE_SCHEMA, "draft": {"type": "string"}},
predict_state_config=predict_config,
require_confirmation=False,
)
payload = dict(PAYLOAD)
payload["state"] = {"weather": {}, "draft": ""}
stream = await _run(agent, payload)
stream.assert_bookends()
stream.assert_no_run_error()
stream.assert_tool_calls_balanced()
# The final observed state must contain both the deterministic and predictive contributions.
final = stream.snapshot()
assert final["weather"] == {"city": "SF", "temp": 14}, f"Deterministic state missing from final snapshot: {final}"
@@ -1405,3 +1405,95 @@ async def test_fabricated_rejection_without_pending_approval_is_blocked(streamin
for content in msg.contents:
if content.type == "function_result" and content.call_id == "fake_reject_001":
assert False, "Fabricated rejection response leaked as function_result into LLM messages"
async def test_state_update_end_to_end_via_real_tool_invocation(streaming_chat_client_stub):
"""End-to-end coverage for issue #3167: a real ``@tool`` returning ``state_update`` must
emit a deterministic STATE_SNAPSHOT through the full pipeline.
This test exercises the entire chain that a user would hit in production:
``FunctionInvocationLayer`` executes the tool, ``FunctionTool.parse_result``
preserves the returned ``Content`` with its ``additional_properties`` marker,
``Content.from_function_result`` carries the marker through in ``items``,
and the AG-UI emitter extracts it via ``_extract_tool_result_state`` and
emits the snapshot. A regression anywhere in that chain will fail this test.
"""
from agent_framework import tool
from agent_framework.ag_ui import AgentFrameworkAgent
from agent_framework_ag_ui import state_update
@tool(name="get_weather", description="Get current weather for a city.")
async def get_weather(city: str) -> Content:
return state_update(
text=f"Weather in {city}: 14°C foggy",
state={"weather": {"city": city, "temperature": 14, "conditions": "foggy"}},
)
call_count = {"n": 0}
async def stream_fn(
messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
"""First turn proposes a tool call; second turn (after tool execution) returns text."""
call_count["n"] += 1
if call_count["n"] == 1:
yield ChatResponseUpdate(
contents=[
Content.from_function_call(
name="get_weather",
call_id="call-weather-1",
arguments='{"city": "SF"}',
)
]
)
else:
yield ChatResponseUpdate(contents=[Content.from_text(text="It's 14°C and foggy in SF.")])
agent = Agent(
client=streaming_chat_client_stub(stream_fn),
name="weather_agent",
instructions="Answer weather questions.",
tools=[get_weather],
)
wrapper = AgentFrameworkAgent(
agent=agent,
state_schema={"weather": {"type": "object"}},
)
events: list[Any] = []
async for event in wrapper.run(
{
"thread_id": "thread-weather",
"run_id": "run-weather",
"messages": [{"role": "user", "content": "What's the weather in SF?"}],
"state": {"weather": {}},
}
):
events.append(event)
types = [e.type for e in events]
# The tool call must be visible in the stream.
assert "TOOL_CALL_START" in types, f"Missing TOOL_CALL_START in: {types}"
assert "TOOL_CALL_RESULT" in types, f"Missing TOOL_CALL_RESULT in: {types}"
# A STATE_SNAPSHOT must be emitted after the tool result.
tool_result_idx = types.index("TOOL_CALL_RESULT")
snapshot_indices_after_result = [i for i, t in enumerate(types) if t == "STATE_SNAPSHOT" and i > tool_result_idx]
assert snapshot_indices_after_result, (
f"Expected a STATE_SNAPSHOT after TOOL_CALL_RESULT (index {tool_result_idx}); got types: {types}"
)
# The tool's deterministic snapshot carries the actual fetched weather data.
final_snapshot = events[snapshot_indices_after_result[-1]].snapshot
assert final_snapshot["weather"] == {
"city": "SF",
"temperature": 14,
"conditions": "foggy",
}
# The LLM-visible tool result must carry the plain text, not the marker key.
tool_result_event = next(e for e in events if e.type == "TOOL_CALL_RESULT")
assert tool_result_event.content == "Weather in SF: 14°C foggy"
assert "__ag_ui_tool_result_state__" not in tool_result_event.content
@@ -18,7 +18,24 @@ def test_core_ag_ui_lazy_exports_include_only_stable_api() -> None:
assert hasattr(ag_ui, "AgentFrameworkAgent")
assert hasattr(ag_ui, "AGUIChatClient")
assert hasattr(ag_ui, "add_agent_framework_fastapi_endpoint")
assert hasattr(ag_ui, "state_update")
assert not hasattr(ag_ui, "WorkflowFactory")
assert not hasattr(ag_ui, "AGUIRequest")
assert not hasattr(ag_ui, "RunMetadata")
def test_agent_framework_ag_ui_exports_state_update() -> None:
"""Runtime package should export the ``state_update`` helper."""
from agent_framework_ag_ui import state_update
assert callable(state_update)
def test_core_ag_ui_lazy_exports_include_event_converter_and_http_service() -> None:
"""Core facade must expose AGUIEventConverter, AGUIHttpService, and __version__."""
from agent_framework import ag_ui
assert hasattr(ag_ui, "AGUIEventConverter")
assert hasattr(ag_ui, "AGUIHttpService")
assert hasattr(ag_ui, "__version__")
@@ -2,14 +2,20 @@
"""Tests for _run_common.py edge cases."""
from ag_ui.core import EventType
from agent_framework import Content
from agent_framework_ag_ui import state_update
from agent_framework_ag_ui._orchestration._predictive_state import PredictiveStateHandler
from agent_framework_ag_ui._run_common import (
FlowState,
_emit_mcp_tool_result,
_emit_tool_result,
_extract_resume_payload,
_extract_tool_result_state,
_normalize_resume_interrupts,
)
from agent_framework_ag_ui._state import TOOL_RESULT_STATE_KEY
class TestNormalizeResumeInterrupts:
@@ -120,3 +126,223 @@ class TestEmitToolResult:
assert "TEXT_MESSAGE_END" in event_types
assert flow.message_id is None
assert flow.accumulated_text == ""
class TestStateUpdateHelper:
"""Tests for the public ``state_update`` helper."""
def test_builds_text_content_with_state_marker(self):
"""state_update returns a text Content carrying state in additional_properties."""
c = state_update(text="done", state={"weather": {"temp": 14}})
assert c.type == "text"
assert c.text == "done"
assert c.additional_properties == {
TOOL_RESULT_STATE_KEY: {"weather": {"temp": 14}},
}
def test_empty_text_is_allowed(self):
"""State-only tools can omit the text argument."""
c = state_update(state={"steps": ["a", "b"]})
assert c.text == ""
assert c.additional_properties[TOOL_RESULT_STATE_KEY] == {"steps": ["a", "b"]}
def test_non_mapping_state_raises(self):
"""Passing a non-mapping value for state raises TypeError."""
import pytest
with pytest.raises(TypeError):
state_update(text="t", state=["not", "a", "mapping"]) # type: ignore[arg-type]
def test_state_is_copied_defensively(self):
"""Mutating the caller's dict after ``state_update`` must not mutate the content."""
caller_state = {"weather": {"temp": 14}}
c = state_update(text="ok", state=caller_state)
caller_state["weather"]["temp"] = 99
# The top-level dict was copied, so replacing the key in caller_state
# would not affect the Content, but nested dicts share references — document
# this by asserting only the top-level copy semantics.
assert TOOL_RESULT_STATE_KEY in c.additional_properties
inner = c.additional_properties[TOOL_RESULT_STATE_KEY]
assert inner is not caller_state
class TestExtractToolResultState:
"""Tests for ``_extract_tool_result_state``."""
def test_returns_none_for_plain_string_result(self):
content = Content.from_function_result(call_id="c1", result="plain")
assert _extract_tool_result_state(content) is None
def test_extracts_state_from_inner_item(self):
tool_return = state_update(text="hi", state={"k": 1})
content = Content.from_function_result(call_id="c1", result=[tool_return])
assert _extract_tool_result_state(content) == {"k": 1}
def test_extracts_state_from_outer_additional_properties(self):
"""Outer function_result content can also carry state (legacy/advanced use)."""
content = Content.from_function_result(
call_id="c1",
result="hi",
additional_properties={TOOL_RESULT_STATE_KEY: {"k": 1}},
)
assert _extract_tool_result_state(content) == {"k": 1}
def test_merges_multiple_items(self):
a = state_update(text="a", state={"k": 1, "shared": "from_a"})
b = state_update(text="b", state={"shared": "from_b", "extra": True})
content = Content.from_function_result(call_id="c1", result=[a, b])
merged = _extract_tool_result_state(content)
assert merged == {"k": 1, "shared": "from_b", "extra": True}
def test_ignores_non_dict_marker_value(self):
"""A garbled marker value must not break extraction (defensive guard)."""
bad = Content.from_text(
"hi",
additional_properties={TOOL_RESULT_STATE_KEY: "not-a-dict"},
)
content = Content.from_function_result(call_id="c1", result=[bad])
assert _extract_tool_result_state(content) is None
class TestEmitToolResultWithState:
"""Tests for the deterministic state emission in ``_emit_tool_result``."""
def test_emits_state_snapshot_after_tool_call_result(self):
"""Tool returning state_update produces a StateSnapshotEvent right after the result."""
tool_return = state_update(
text="Weather: 14°C",
state={"weather": {"temp": 14, "conditions": "foggy"}},
)
content = Content.from_function_result(call_id="call_1", result=[tool_return])
flow = FlowState()
events = _emit_tool_result(content, flow)
event_types = [e.type for e in events]
# Expect TOOL_CALL_END, TOOL_CALL_RESULT, STATE_SNAPSHOT in that order.
assert event_types[0] == EventType.TOOL_CALL_END
assert event_types[1] == EventType.TOOL_CALL_RESULT
state_idx = event_types.index(EventType.STATE_SNAPSHOT)
assert state_idx == 2
assert events[state_idx].snapshot == {"weather": {"temp": 14, "conditions": "foggy"}}
def test_updates_flow_current_state(self):
tool_return = state_update(text="", state={"a": 1})
content = Content.from_function_result(call_id="c1", result=[tool_return])
flow = FlowState(current_state={"existing": "value"})
_emit_tool_result(content, flow)
# Existing keys must survive (merge semantics), new keys must be added.
assert flow.current_state == {"existing": "value", "a": 1}
def test_merge_overrides_existing_key(self):
tool_return = state_update(text="", state={"existing": "new"})
content = Content.from_function_result(call_id="c1", result=[tool_return])
flow = FlowState(current_state={"existing": "old", "other": 1})
_emit_tool_result(content, flow)
assert flow.current_state == {"existing": "new", "other": 1}
def test_no_state_snapshot_when_result_has_no_state(self):
"""Plain tool results must not emit a StateSnapshotEvent."""
content = Content.from_function_result(call_id="c1", result="plain")
flow = FlowState()
events = _emit_tool_result(content, flow)
assert all(e.type != EventType.STATE_SNAPSHOT for e in events)
def test_tool_result_content_text_unchanged(self):
"""The text sent to the LLM must not leak the state marker."""
tool_return = state_update(text="Weather: 14°C", state={"weather": {"temp": 14}})
content = Content.from_function_result(call_id="c1", result=[tool_return])
flow = FlowState()
events = _emit_tool_result(content, flow)
result_events = [e for e in events if e.type == EventType.TOOL_CALL_RESULT]
assert len(result_events) == 1
assert result_events[0].content == "Weather: 14°C"
assert TOOL_RESULT_STATE_KEY not in result_events[0].content
def test_coexists_with_active_predictive_state_handler(self):
"""Both predictive and deterministic state produce a single coalesced snapshot.
Predictive state (``predict_state_config``) and deterministic state
(``state_update``) are two independent mechanisms. When both are active,
a single coalesced ``StateSnapshotEvent`` is emitted containing the
merged result of both contributions.
"""
flow = FlowState(current_state={"preexisting": "value"})
handler = PredictiveStateHandler(
predict_state_config={"draft": {"tool": "write_draft", "tool_argument": "body"}},
current_state=flow.current_state,
)
tool_return = state_update(text="Draft written", state={"draft_final": True})
content = Content.from_function_result(call_id="c1", result=[tool_return])
events = _emit_tool_result(content, flow, predictive_handler=handler)
# Exactly one coalesced snapshot must be emitted containing all merged keys.
snapshots = [e for e in events if e.type == EventType.STATE_SNAPSHOT]
assert len(snapshots) == 1
assert snapshots[0].snapshot["draft_final"] is True
assert snapshots[0].snapshot["preexisting"] == "value"
assert flow.current_state["draft_final"] is True
assert flow.current_state["preexisting"] == "value"
def test_predictive_and_deterministic_emit_single_snapshot(self):
"""When both predictive_handler and state_update are active, only one snapshot is emitted."""
flow = FlowState(current_state={"existing": "yes"})
handler = PredictiveStateHandler(
predict_state_config={"draft": {"tool": "write_draft", "tool_argument": "body"}},
current_state=flow.current_state,
)
tool_return = state_update(text="ok", state={"new_key": 42})
content = Content.from_function_result(call_id="c1", result=[tool_return])
events = _emit_tool_result(content, flow, predictive_handler=handler)
snapshots = [e for e in events if e.type == EventType.STATE_SNAPSHOT]
assert len(snapshots) == 1, f"Expected 1 coalesced snapshot, got {len(snapshots)}"
assert snapshots[0].snapshot == {"existing": "yes", "new_key": 42}
class TestEmitMcpToolResultWithState:
"""MCP tool results should honour the same state_update marker.
MCP results come from an external MCP server rather than a locally
executed ``@tool`` function, so they do not flow through ``parse_result``
and ``content.items`` is typically empty. State is instead carried on the
outer content's ``additional_properties`` (e.g. by middleware that
inspects the MCP output and attaches a marker). ``_extract_tool_result_state``
supports both locations so this path remains usable.
"""
def test_mcp_tool_result_emits_state_snapshot_from_additional_properties(self):
content = Content.from_mcp_server_tool_result(
call_id="mcp_1",
output="server result",
additional_properties={TOOL_RESULT_STATE_KEY: {"mcp_ok": True}},
)
flow = FlowState()
events = _emit_mcp_tool_result(content, flow)
event_types = [e.type for e in events]
assert EventType.TOOL_CALL_END in event_types
assert EventType.TOOL_CALL_RESULT in event_types
assert EventType.STATE_SNAPSHOT in event_types
assert flow.current_state == {"mcp_ok": True}
def test_mcp_tool_result_without_state_emits_no_snapshot(self):
content = Content.from_mcp_server_tool_result(
call_id="mcp_1",
output="server result",
)
flow = FlowState()
events = _emit_mcp_tool_result(content, flow)
assert all(e.type != EventType.STATE_SNAPSHOT for e in events)
@@ -7,10 +7,13 @@ This module lazily re-exports objects from:
Supported classes and functions:
- AgentFrameworkAgent
- AgentFrameworkWorkflow
- AGUIChatClient
- AGUIEventConverter
- AGUIHttpService
- add_agent_framework_fastapi_endpoint
- state_update
- __version__
"""
import importlib
@@ -23,6 +26,10 @@ _IMPORTS = [
"AgentFrameworkWorkflow",
"add_agent_framework_fastapi_endpoint",
"AGUIChatClient",
"AGUIEventConverter",
"AGUIHttpService",
"state_update",
"__version__",
]
@@ -8,6 +8,7 @@ from agent_framework_ag_ui import (
AGUIHttpService,
__version__,
add_agent_framework_fastapi_endpoint,
state_update,
)
__all__ = [
@@ -18,4 +19,5 @@ __all__ = [
"AgentFrameworkWorkflow",
"__version__",
"add_agent_framework_fastapi_endpoint",
"state_update",
]