mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: AG-UI deterministic state updates from tool results (#5201)
* AG-UI deterministic state updates from tool results * fix(ag-ui): address PR #5201 review comments 1. Add missing AGUIEventConverter, AGUIHttpService, __version__ to _IMPORTS in core ag_ui lazy-export list to match the .pyi stub. 2. Coalesce predictive and deterministic state snapshots into a single StateSnapshotEvent when both mechanisms are active on the same tool result, reducing redundant snapshot traffic. 3. Update state_update() docstring to clarify that a predictive snapshot may be emitted before the deterministic one when predict_state_config is active. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
3c31ac28b5
commit
f183f888a3
@@ -9,6 +9,7 @@ from ._client import AGUIChatClient
|
||||
from ._endpoint import add_agent_framework_fastapi_endpoint
|
||||
from ._event_converters import AGUIEventConverter
|
||||
from ._http_service import AGUIHttpService
|
||||
from ._state import state_update
|
||||
from ._types import AgentState, AGUIChatOptions, AGUIRequest, PredictStateConfig, RunMetadata
|
||||
from ._workflow import AgentFrameworkWorkflow, WorkflowFactory
|
||||
|
||||
@@ -34,5 +35,6 @@ __all__ = [
|
||||
"PredictStateConfig",
|
||||
"RunMetadata",
|
||||
"DEFAULT_TAGS",
|
||||
"state_update",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
@@ -6,6 +6,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -31,6 +32,7 @@ from ag_ui.core import (
|
||||
from agent_framework import Content
|
||||
|
||||
from ._orchestration._predictive_state import PredictiveStateHandler
|
||||
from ._state import TOOL_RESULT_STATE_KEY
|
||||
from ._utils import generate_event_id, make_json_safe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -233,16 +235,66 @@ def _emit_tool_call(
|
||||
return events
|
||||
|
||||
|
||||
def _extract_tool_result_state(content: Content) -> dict[str, Any] | None:
|
||||
"""Extract a deterministic AG-UI state update from a tool-result ``Content``.
|
||||
|
||||
Tools using :func:`agent_framework_ag_ui.state_update` carry the state
|
||||
payload in ``additional_properties[TOOL_RESULT_STATE_KEY]`` on the inner
|
||||
text item produced by ``parse_result``. We also check the outer
|
||||
function_result content's ``additional_properties`` for robustness.
|
||||
|
||||
If multiple items carry state, they are merged in order so later items
|
||||
override earlier ones (plain ``dict.update`` semantics).
|
||||
|
||||
Returns:
|
||||
The merged state dict to apply, or ``None`` if no state update is
|
||||
present.
|
||||
"""
|
||||
merged: dict[str, Any] | None = None
|
||||
|
||||
outer_ap = getattr(content, "additional_properties", None) or {}
|
||||
outer_state = outer_ap.get(TOOL_RESULT_STATE_KEY)
|
||||
if isinstance(outer_state, dict):
|
||||
merged = dict(outer_state)
|
||||
|
||||
for item in content.items or ():
|
||||
item_ap = getattr(item, "additional_properties", None) or {}
|
||||
item_state = item_ap.get(TOOL_RESULT_STATE_KEY)
|
||||
if isinstance(item_state, dict):
|
||||
if merged is None:
|
||||
merged = dict(item_state)
|
||||
else:
|
||||
merged.update(item_state)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def _emit_tool_result_common(
|
||||
call_id: str,
|
||||
raw_result: Any,
|
||||
flow: FlowState,
|
||||
predictive_handler: PredictiveStateHandler | None = None,
|
||||
*,
|
||||
state_update: Mapping[str, Any] | None = None,
|
||||
) -> list[BaseEvent]:
|
||||
"""Shared helper for emitting ToolCallEnd + ToolCallResult events and performing FlowState cleanup.
|
||||
|
||||
Both ``_emit_tool_result`` (standard function results) and ``_emit_mcp_tool_result``
|
||||
(MCP server tool results) delegate to this function.
|
||||
|
||||
Args:
|
||||
call_id: Tool call identifier.
|
||||
raw_result: The stringified tool result content sent back to the LLM.
|
||||
flow: Current ``FlowState``.
|
||||
predictive_handler: Optional predictive state handler driven by
|
||||
``predict_state_config``.
|
||||
state_update: Optional deterministic state snapshot produced by a tool
|
||||
returning :func:`agent_framework_ag_ui.state_update`. When present,
|
||||
it is merged into ``flow.current_state`` and a ``StateSnapshotEvent``
|
||||
is emitted after the ``ToolCallResult`` event. When both
|
||||
``predictive_handler`` and ``state_update`` are active, predictive
|
||||
updates are applied first, then the deterministic merge, and a
|
||||
single coalesced ``StateSnapshotEvent`` is emitted.
|
||||
"""
|
||||
events: list[BaseEvent] = []
|
||||
|
||||
@@ -271,8 +323,18 @@ def _emit_tool_result_common(
|
||||
|
||||
if predictive_handler:
|
||||
predictive_handler.apply_pending_updates()
|
||||
if flow.current_state:
|
||||
events.append(StateSnapshotEvent(snapshot=flow.current_state))
|
||||
|
||||
if state_update:
|
||||
flow.current_state.update(state_update)
|
||||
logger.debug(
|
||||
"Emitted deterministic tool-result StateSnapshotEvent for call_id=%s (keys=%s)",
|
||||
call_id,
|
||||
list(state_update.keys()),
|
||||
)
|
||||
|
||||
# Emit a single coalesced snapshot when either mechanism updated state.
|
||||
if (predictive_handler or state_update) and flow.current_state:
|
||||
events.append(StateSnapshotEvent(snapshot=flow.current_state))
|
||||
|
||||
flow.tool_call_id = None
|
||||
flow.tool_call_name = None
|
||||
@@ -295,7 +357,14 @@ def _emit_tool_result(
|
||||
if not content.call_id:
|
||||
return []
|
||||
raw_result = content.result if content.result is not None else ""
|
||||
return _emit_tool_result_common(content.call_id, raw_result, flow, predictive_handler)
|
||||
state_update = _extract_tool_result_state(content)
|
||||
return _emit_tool_result_common(
|
||||
content.call_id,
|
||||
raw_result,
|
||||
flow,
|
||||
predictive_handler,
|
||||
state_update=state_update,
|
||||
)
|
||||
|
||||
|
||||
def _emit_approval_request(
|
||||
@@ -460,7 +529,14 @@ def _emit_mcp_tool_result(
|
||||
logger.warning("MCP tool result content missing call_id, skipping")
|
||||
return []
|
||||
raw_output = content.output if content.output is not None else ""
|
||||
return _emit_tool_result_common(content.call_id, raw_output, flow, predictive_handler)
|
||||
state_update = _extract_tool_result_state(content)
|
||||
return _emit_tool_result_common(
|
||||
content.call_id,
|
||||
raw_output,
|
||||
flow,
|
||||
predictive_handler,
|
||||
state_update=state_update,
|
||||
)
|
||||
|
||||
|
||||
def _close_reasoning_block(flow: FlowState) -> list[BaseEvent]:
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Deterministic tool-driven AG-UI state updates.
|
||||
|
||||
Tools wired into the :mod:`agent_framework_ag_ui` endpoint can push a
|
||||
deterministic state update by returning :func:`state_update`. Unlike
|
||||
``predict_state_config`` — which emits ``StateDeltaEvent``s optimistically from
|
||||
LLM-predicted tool call arguments — ``state_update`` runs *after* the tool
|
||||
executes, so the AG-UI state always reflects the tool's actual return value.
|
||||
|
||||
See issue https://github.com/microsoft/agent-framework/issues/3167 for the
|
||||
motivating discussion.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Content
|
||||
|
||||
__all__ = ["TOOL_RESULT_STATE_KEY", "state_update"]
|
||||
|
||||
|
||||
TOOL_RESULT_STATE_KEY = "__ag_ui_tool_result_state__"
|
||||
"""Reserved ``Content.additional_properties`` key used to carry a tool-driven
|
||||
state snapshot from a tool return value through to the AG-UI emitter."""
|
||||
|
||||
|
||||
def state_update(
|
||||
text: str = "",
|
||||
*,
|
||||
state: Mapping[str, Any],
|
||||
) -> Content:
|
||||
"""Build a tool return value that deterministically updates AG-UI shared state.
|
||||
|
||||
Return the result of this helper from an agent tool to push a state update
|
||||
to AG-UI clients using the actual tool output, rather than LLM-predicted
|
||||
tool arguments.
|
||||
|
||||
When the AG-UI endpoint emits the tool result, it will:
|
||||
|
||||
* Forward ``text`` to the LLM as the normal ``function_result`` content.
|
||||
* Merge ``state`` into ``FlowState.current_state``.
|
||||
* Emit a deterministic ``StateSnapshotEvent`` after the ``ToolCallResult``
|
||||
event so frontends observe the updated state deterministically. If
|
||||
predictive state is enabled, a predictive snapshot may be emitted first.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework import tool
|
||||
from agent_framework_ag_ui import state_update
|
||||
|
||||
|
||||
@tool
|
||||
async def get_weather(city: str) -> Content:
|
||||
data = await _fetch_weather(city)
|
||||
return state_update(
|
||||
text=f"Weather in {city}: {data['temp']}°C {data['conditions']}",
|
||||
state={"weather": {"city": city, **data}},
|
||||
)
|
||||
|
||||
Args:
|
||||
text: Text passed back to the LLM as the ``function_result`` content.
|
||||
Defaults to an empty string for tools whose only output is a state
|
||||
update.
|
||||
state: A mapping merged into the AG-UI shared state via JSON-compatible
|
||||
``dict.update`` semantics. Nested dicts are replaced, not deep-merged.
|
||||
|
||||
Returns:
|
||||
A ``Content`` object with ``type="text"``. The state payload rides in
|
||||
``additional_properties`` under :data:`TOOL_RESULT_STATE_KEY` and is
|
||||
extracted by the AG-UI emitter.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``state`` is not a ``Mapping``.
|
||||
"""
|
||||
if not isinstance(state, Mapping):
|
||||
raise TypeError(f"state_update() 'state' must be a Mapping, got {type(state).__name__}")
|
||||
return Content.from_text(
|
||||
text,
|
||||
additional_properties={TOOL_RESULT_STATE_KEY: dict(state)},
|
||||
)
|
||||
@@ -0,0 +1,92 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Deterministic tool-driven AG-UI state example.
|
||||
|
||||
This sample demonstrates how a tool can push a *deterministic* state update
|
||||
to the AG-UI frontend based on its actual return value — in contrast to
|
||||
``predict_state_config`` which fires optimistically from LLM-predicted tool
|
||||
call arguments. See issue https://github.com/microsoft/agent-framework/issues/3167.
|
||||
|
||||
The :func:`agent_framework_ag_ui.state_update` helper wraps a text result
|
||||
together with a state snapshot. When a tool returns one of these, the AG-UI
|
||||
endpoint merges the snapshot into the shared state and emits a
|
||||
``StateSnapshotEvent`` after the tool result.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Agent, Content, SupportsChatGetResponse, tool
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
from agent_framework_ag_ui import state_update
|
||||
|
||||
# Simulated weather database — in the issue's motivating example the tool
|
||||
# would instead call a real weather API.
|
||||
_WEATHER_DB: dict[str, dict[str, Any]] = {
|
||||
"seattle": {"temperature": 11, "conditions": "rainy", "humidity": 75},
|
||||
"san francisco": {"temperature": 14, "conditions": "foggy", "humidity": 85},
|
||||
"new york city": {"temperature": 18, "conditions": "sunny", "humidity": 60},
|
||||
"miami": {"temperature": 29, "conditions": "hot and humid", "humidity": 90},
|
||||
"chicago": {"temperature": 9, "conditions": "windy", "humidity": 65},
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
async def get_weather(location: str) -> Content:
|
||||
"""Fetch current weather for a location and push it into AG-UI shared state.
|
||||
|
||||
Unlike ``predict_state_config`` — which derives state optimistically from
|
||||
LLM-predicted tool call arguments — this tool uses ``state_update`` to
|
||||
forward the *actual* fetched weather to the frontend. The ``text`` goes
|
||||
back to the LLM as the normal tool result, and the ``state`` dict is merged
|
||||
into the AG-UI shared state.
|
||||
|
||||
Args:
|
||||
location: City name to look up.
|
||||
|
||||
Returns:
|
||||
A :class:`Content` carrying both the LLM-visible text result and a
|
||||
deterministic state snapshot.
|
||||
"""
|
||||
key = location.lower()
|
||||
data = _WEATHER_DB.get(
|
||||
key,
|
||||
{"temperature": 21, "conditions": "partly cloudy", "humidity": 50},
|
||||
)
|
||||
weather_record = {"location": location, **data}
|
||||
return state_update(
|
||||
text=(
|
||||
f"The weather in {location} is {data['conditions']} at "
|
||||
f"{data['temperature']}°C with {data['humidity']}% humidity."
|
||||
),
|
||||
state={"weather": weather_record},
|
||||
)
|
||||
|
||||
|
||||
def weather_state_agent(client: SupportsChatGetResponse[Any]) -> AgentFrameworkAgent:
|
||||
"""Create an AG-UI agent with a deterministic tool-driven state tool."""
|
||||
agent = Agent[Any](
|
||||
name="weather_state_agent",
|
||||
instructions=(
|
||||
"You are a weather assistant. When a user asks about the weather "
|
||||
"in a city, call the get_weather tool and use its output to give a "
|
||||
"friendly, concise reply. The tool also updates the shared UI state "
|
||||
"so the frontend can render a weather card from the `weather` key."
|
||||
),
|
||||
client=client,
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
return AgentFrameworkAgent(
|
||||
agent=agent,
|
||||
name="WeatherStateAgent",
|
||||
description="Weather agent that deterministically updates shared state from tool results.",
|
||||
state_schema={
|
||||
"weather": {
|
||||
"type": "object",
|
||||
"description": "Last fetched weather record",
|
||||
},
|
||||
},
|
||||
)
|
||||
@@ -24,6 +24,7 @@ from ..agents.subgraphs_agent import subgraphs_agent
|
||||
from ..agents.task_steps_agent import task_steps_agent_wrapped
|
||||
from ..agents.ui_generator_agent import ui_generator_agent
|
||||
from ..agents.weather_agent import weather_agent
|
||||
from ..agents.weather_state_agent import weather_state_agent
|
||||
|
||||
AnthropicClient: type[Any] | None
|
||||
try:
|
||||
@@ -141,6 +142,14 @@ add_agent_framework_fastapi_endpoint(
|
||||
path="/subgraphs",
|
||||
)
|
||||
|
||||
# Deterministic Tool-Driven State - tool returns state_update() to push snapshot
|
||||
# from actual tool output (see issue #3167).
|
||||
add_agent_framework_fastapi_endpoint(
|
||||
app=app,
|
||||
agent=weather_state_agent(client),
|
||||
path="/deterministic_state",
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the server."""
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Golden event-stream tests for the deterministic tool-driven state scenario.
|
||||
|
||||
Covers issue https://github.com/microsoft/agent-framework/issues/3167 — a tool
|
||||
returning :func:`agent_framework_ag_ui.state_update` must push a deterministic
|
||||
``StateSnapshotEvent`` derived from its actual return value, orthogonal to the
|
||||
optimistic ``predict_state_config`` path. These golden tests pin the user-visible
|
||||
event stream so additive changes cannot silently regress it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import AgentResponseUpdate, Content
|
||||
from conftest import StubAgent
|
||||
from event_stream import EventStream
|
||||
|
||||
from agent_framework_ag_ui import AgentFrameworkAgent, state_update
|
||||
|
||||
STATE_SCHEMA = {
|
||||
"weather": {"type": "object", "description": "Last fetched weather"},
|
||||
}
|
||||
|
||||
|
||||
def _build_agent(updates: list[AgentResponseUpdate], **kwargs: Any) -> AgentFrameworkAgent:
|
||||
stub = StubAgent(updates=updates)
|
||||
kwargs.setdefault("state_schema", STATE_SCHEMA)
|
||||
return AgentFrameworkAgent(agent=stub, **kwargs)
|
||||
|
||||
|
||||
async def _run(agent: AgentFrameworkAgent, payload: dict[str, Any]) -> EventStream:
|
||||
return EventStream([event async for event in agent.run(payload)])
|
||||
|
||||
|
||||
PAYLOAD: dict[str, Any] = {
|
||||
"thread_id": "thread-det-state",
|
||||
"run_id": "run-det-state",
|
||||
"messages": [{"role": "user", "content": "What's the weather in SF?"}],
|
||||
"state": {"weather": {}},
|
||||
}
|
||||
|
||||
|
||||
def _tool_call(call_id: str, name: str, arguments: str) -> AgentResponseUpdate:
|
||||
return AgentResponseUpdate(
|
||||
contents=[Content.from_function_call(name=name, call_id=call_id, arguments=arguments)],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
def _tool_result_with_state(call_id: str, text: str, state: dict[str, Any]) -> AgentResponseUpdate:
|
||||
"""Build a function_result update whose inner item carries a state marker.
|
||||
|
||||
This mirrors what the core framework produces when a real ``@tool`` returns
|
||||
:func:`state_update`: ``parse_result`` keeps the ``Content`` as-is, and
|
||||
``Content.from_function_result`` preserves its ``additional_properties``
|
||||
inside ``items``.
|
||||
"""
|
||||
return AgentResponseUpdate(
|
||||
contents=[
|
||||
Content.from_function_result(
|
||||
call_id=call_id,
|
||||
result=[state_update(text=text, state=state)],
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
# ── Golden stream tests ──
|
||||
|
||||
|
||||
async def test_deterministic_state_emits_snapshot_after_tool_result() -> None:
|
||||
"""The happy path: STATE_SNAPSHOT follows TOOL_CALL_RESULT in order."""
|
||||
updates = [
|
||||
_tool_call("call-1", "get_weather", '{"city": "SF"}'),
|
||||
_tool_result_with_state(
|
||||
"call-1",
|
||||
text="Weather in SF: 14°C foggy",
|
||||
state={"weather": {"city": "SF", "temp": 14, "conditions": "foggy"}},
|
||||
),
|
||||
AgentResponseUpdate(
|
||||
contents=[Content.from_text(text="It's 14°C and foggy in SF.")],
|
||||
role="assistant",
|
||||
),
|
||||
]
|
||||
agent = _build_agent(updates)
|
||||
stream = await _run(agent, PAYLOAD)
|
||||
|
||||
stream.assert_bookends()
|
||||
stream.assert_no_run_error()
|
||||
stream.assert_tool_calls_balanced()
|
||||
stream.assert_text_messages_balanced()
|
||||
|
||||
# Ordered subsequence: the deterministic STATE_SNAPSHOT must follow the
|
||||
# TOOL_CALL_RESULT. This is the central contract for #3167.
|
||||
stream.assert_ordered_types(
|
||||
[
|
||||
"RUN_STARTED",
|
||||
"TOOL_CALL_START",
|
||||
"TOOL_CALL_ARGS",
|
||||
"TOOL_CALL_END",
|
||||
"TOOL_CALL_RESULT",
|
||||
"STATE_SNAPSHOT",
|
||||
"RUN_FINISHED",
|
||||
]
|
||||
)
|
||||
|
||||
# The final STATE_SNAPSHOT must carry the tool-driven state.
|
||||
snapshot = stream.snapshot()
|
||||
assert snapshot["weather"] == {"city": "SF", "temp": 14, "conditions": "foggy"}
|
||||
|
||||
|
||||
async def test_deterministic_state_does_not_fire_for_plain_tool_result() -> None:
|
||||
"""Regression guard: tools returning plain strings must NOT emit a new STATE_SNAPSHOT.
|
||||
|
||||
The initial STATE_SNAPSHOT fires once from the schema + initial payload
|
||||
state. A plain (non-state_update) tool result must not add another one.
|
||||
"""
|
||||
updates = [
|
||||
_tool_call("call-1", "get_weather", '{"city": "SF"}'),
|
||||
AgentResponseUpdate(
|
||||
contents=[Content.from_function_result(call_id="call-1", result="14°C foggy")],
|
||||
role="assistant",
|
||||
),
|
||||
AgentResponseUpdate(
|
||||
contents=[Content.from_text(text="It's 14°C and foggy.")],
|
||||
role="assistant",
|
||||
),
|
||||
]
|
||||
agent = _build_agent(updates)
|
||||
stream = await _run(agent, PAYLOAD)
|
||||
|
||||
stream.assert_bookends()
|
||||
stream.assert_no_run_error()
|
||||
|
||||
snapshots = stream.get("STATE_SNAPSHOT")
|
||||
# Only the initial snapshot (from state_schema + payload state) should exist.
|
||||
# No deterministic snapshot should have been added by the plain tool result.
|
||||
assert len(snapshots) == 1, (
|
||||
f"Expected exactly 1 STATE_SNAPSHOT (initial only) for plain tool result; "
|
||||
f"got {len(snapshots)}. Snapshots: {[s.snapshot for s in snapshots]}"
|
||||
)
|
||||
|
||||
|
||||
async def test_deterministic_state_merges_into_initial_state() -> None:
|
||||
"""The tool-driven snapshot must merge into, not replace, pre-existing state keys."""
|
||||
payload = dict(PAYLOAD)
|
||||
payload["state"] = {"weather": {}, "user_preferences": {"unit": "C"}}
|
||||
|
||||
updates = [
|
||||
_tool_call("call-1", "get_weather", '{"city": "SF"}'),
|
||||
_tool_result_with_state(
|
||||
"call-1",
|
||||
text="Weather: 14°C",
|
||||
state={"weather": {"city": "SF", "temp": 14}},
|
||||
),
|
||||
]
|
||||
agent = _build_agent(updates, state_schema={**STATE_SCHEMA, "user_preferences": {"type": "object"}})
|
||||
stream = await _run(agent, payload)
|
||||
|
||||
stream.assert_bookends()
|
||||
stream.assert_no_run_error()
|
||||
|
||||
final_snapshot = stream.snapshot()
|
||||
assert final_snapshot["weather"] == {"city": "SF", "temp": 14}
|
||||
assert final_snapshot["user_preferences"] == {"unit": "C"}, (
|
||||
"Pre-existing state keys must survive the deterministic merge"
|
||||
)
|
||||
|
||||
|
||||
async def test_deterministic_state_llm_visible_text_is_clean() -> None:
|
||||
"""The LLM-visible TOOL_CALL_RESULT content must not leak the state marker key."""
|
||||
updates = [
|
||||
_tool_call("call-1", "get_weather", '{"city": "SF"}'),
|
||||
_tool_result_with_state(
|
||||
"call-1",
|
||||
text="Weather in SF: 14°C foggy",
|
||||
state={"weather": {"city": "SF", "temp": 14}},
|
||||
),
|
||||
]
|
||||
agent = _build_agent(updates)
|
||||
stream = await _run(agent, PAYLOAD)
|
||||
|
||||
result = stream.first("TOOL_CALL_RESULT")
|
||||
assert result.content == "Weather in SF: 14°C foggy"
|
||||
# The marker key must never appear in the content sent back to the LLM.
|
||||
assert "__ag_ui_tool_result_state__" not in result.content
|
||||
assert "weather" not in result.content # not as a raw state dump
|
||||
|
||||
|
||||
async def test_deterministic_state_multiple_tools_merge_in_order() -> None:
|
||||
"""Two state-updating tools in one run merge in order; later wins on key collisions."""
|
||||
updates = [
|
||||
_tool_call("call-a", "get_weather", '{"city": "SF"}'),
|
||||
_tool_result_with_state(
|
||||
"call-a",
|
||||
text="First result",
|
||||
state={"weather": {"city": "SF", "temp": 14}, "source": "primary"},
|
||||
),
|
||||
_tool_call("call-b", "get_weather_refined", '{"city": "SF"}'),
|
||||
_tool_result_with_state(
|
||||
"call-b",
|
||||
text="Refined result",
|
||||
state={"source": "refined"},
|
||||
),
|
||||
AgentResponseUpdate(
|
||||
contents=[Content.from_text(text="Here you go.")],
|
||||
role="assistant",
|
||||
),
|
||||
]
|
||||
agent = _build_agent(
|
||||
updates,
|
||||
state_schema={**STATE_SCHEMA, "source": {"type": "string"}},
|
||||
)
|
||||
stream = await _run(agent, PAYLOAD)
|
||||
|
||||
stream.assert_bookends()
|
||||
stream.assert_tool_calls_balanced()
|
||||
stream.assert_no_run_error()
|
||||
|
||||
# Two tool-driven snapshots emitted (one per tool) plus the initial snapshot.
|
||||
snapshots = stream.get("STATE_SNAPSHOT")
|
||||
assert len(snapshots) >= 2, f"Expected at least 2 STATE_SNAPSHOTs; got {len(snapshots)}"
|
||||
|
||||
final = stream.snapshot()
|
||||
assert final["weather"] == {"city": "SF", "temp": 14}
|
||||
# Later tool must override earlier tool on the shared key.
|
||||
assert final["source"] == "refined"
|
||||
|
||||
|
||||
async def test_deterministic_state_coexists_with_predict_state_config() -> None:
|
||||
"""Predictive state and deterministic state must coexist without clobbering each other."""
|
||||
predict_config = {
|
||||
"draft": {
|
||||
"tool": "write_draft",
|
||||
"tool_argument": "body",
|
||||
}
|
||||
}
|
||||
updates = [
|
||||
# Predictive tool: its argument "body" populates state.draft optimistically.
|
||||
_tool_call("call-1", "write_draft", '{"body": "Hello world"}'),
|
||||
# Then a deterministic tool result landing a different key.
|
||||
_tool_result_with_state(
|
||||
"call-1",
|
||||
text="Draft saved",
|
||||
state={"weather": {"city": "SF", "temp": 14}},
|
||||
),
|
||||
]
|
||||
agent = _build_agent(
|
||||
updates,
|
||||
state_schema={**STATE_SCHEMA, "draft": {"type": "string"}},
|
||||
predict_state_config=predict_config,
|
||||
require_confirmation=False,
|
||||
)
|
||||
payload = dict(PAYLOAD)
|
||||
payload["state"] = {"weather": {}, "draft": ""}
|
||||
stream = await _run(agent, payload)
|
||||
|
||||
stream.assert_bookends()
|
||||
stream.assert_no_run_error()
|
||||
stream.assert_tool_calls_balanced()
|
||||
|
||||
# The final observed state must contain both the deterministic and predictive contributions.
|
||||
final = stream.snapshot()
|
||||
assert final["weather"] == {"city": "SF", "temp": 14}, f"Deterministic state missing from final snapshot: {final}"
|
||||
@@ -1405,3 +1405,95 @@ async def test_fabricated_rejection_without_pending_approval_is_blocked(streamin
|
||||
for content in msg.contents:
|
||||
if content.type == "function_result" and content.call_id == "fake_reject_001":
|
||||
assert False, "Fabricated rejection response leaked as function_result into LLM messages"
|
||||
|
||||
|
||||
async def test_state_update_end_to_end_via_real_tool_invocation(streaming_chat_client_stub):
|
||||
"""End-to-end coverage for issue #3167: a real ``@tool`` returning ``state_update`` must
|
||||
emit a deterministic STATE_SNAPSHOT through the full pipeline.
|
||||
|
||||
This test exercises the entire chain that a user would hit in production:
|
||||
``FunctionInvocationLayer`` executes the tool, ``FunctionTool.parse_result``
|
||||
preserves the returned ``Content`` with its ``additional_properties`` marker,
|
||||
``Content.from_function_result`` carries the marker through in ``items``,
|
||||
and the AG-UI emitter extracts it via ``_extract_tool_result_state`` and
|
||||
emits the snapshot. A regression anywhere in that chain will fail this test.
|
||||
"""
|
||||
from agent_framework import tool
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
from agent_framework_ag_ui import state_update
|
||||
|
||||
@tool(name="get_weather", description="Get current weather for a city.")
|
||||
async def get_weather(city: str) -> Content:
|
||||
return state_update(
|
||||
text=f"Weather in {city}: 14°C foggy",
|
||||
state={"weather": {"city": city, "temperature": 14, "conditions": "foggy"}},
|
||||
)
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
"""First turn proposes a tool call; second turn (after tool execution) returns text."""
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
yield ChatResponseUpdate(
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
name="get_weather",
|
||||
call_id="call-weather-1",
|
||||
arguments='{"city": "SF"}',
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
yield ChatResponseUpdate(contents=[Content.from_text(text="It's 14°C and foggy in SF.")])
|
||||
|
||||
agent = Agent(
|
||||
client=streaming_chat_client_stub(stream_fn),
|
||||
name="weather_agent",
|
||||
instructions="Answer weather questions.",
|
||||
tools=[get_weather],
|
||||
)
|
||||
wrapper = AgentFrameworkAgent(
|
||||
agent=agent,
|
||||
state_schema={"weather": {"type": "object"}},
|
||||
)
|
||||
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run(
|
||||
{
|
||||
"thread_id": "thread-weather",
|
||||
"run_id": "run-weather",
|
||||
"messages": [{"role": "user", "content": "What's the weather in SF?"}],
|
||||
"state": {"weather": {}},
|
||||
}
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
types = [e.type for e in events]
|
||||
|
||||
# The tool call must be visible in the stream.
|
||||
assert "TOOL_CALL_START" in types, f"Missing TOOL_CALL_START in: {types}"
|
||||
assert "TOOL_CALL_RESULT" in types, f"Missing TOOL_CALL_RESULT in: {types}"
|
||||
|
||||
# A STATE_SNAPSHOT must be emitted after the tool result.
|
||||
tool_result_idx = types.index("TOOL_CALL_RESULT")
|
||||
snapshot_indices_after_result = [i for i, t in enumerate(types) if t == "STATE_SNAPSHOT" and i > tool_result_idx]
|
||||
assert snapshot_indices_after_result, (
|
||||
f"Expected a STATE_SNAPSHOT after TOOL_CALL_RESULT (index {tool_result_idx}); got types: {types}"
|
||||
)
|
||||
|
||||
# The tool's deterministic snapshot carries the actual fetched weather data.
|
||||
final_snapshot = events[snapshot_indices_after_result[-1]].snapshot
|
||||
assert final_snapshot["weather"] == {
|
||||
"city": "SF",
|
||||
"temperature": 14,
|
||||
"conditions": "foggy",
|
||||
}
|
||||
|
||||
# The LLM-visible tool result must carry the plain text, not the marker key.
|
||||
tool_result_event = next(e for e in events if e.type == "TOOL_CALL_RESULT")
|
||||
assert tool_result_event.content == "Weather in SF: 14°C foggy"
|
||||
assert "__ag_ui_tool_result_state__" not in tool_result_event.content
|
||||
|
||||
@@ -18,7 +18,24 @@ def test_core_ag_ui_lazy_exports_include_only_stable_api() -> None:
|
||||
assert hasattr(ag_ui, "AgentFrameworkAgent")
|
||||
assert hasattr(ag_ui, "AGUIChatClient")
|
||||
assert hasattr(ag_ui, "add_agent_framework_fastapi_endpoint")
|
||||
assert hasattr(ag_ui, "state_update")
|
||||
|
||||
assert not hasattr(ag_ui, "WorkflowFactory")
|
||||
assert not hasattr(ag_ui, "AGUIRequest")
|
||||
assert not hasattr(ag_ui, "RunMetadata")
|
||||
|
||||
|
||||
def test_agent_framework_ag_ui_exports_state_update() -> None:
|
||||
"""Runtime package should export the ``state_update`` helper."""
|
||||
from agent_framework_ag_ui import state_update
|
||||
|
||||
assert callable(state_update)
|
||||
|
||||
|
||||
def test_core_ag_ui_lazy_exports_include_event_converter_and_http_service() -> None:
|
||||
"""Core facade must expose AGUIEventConverter, AGUIHttpService, and __version__."""
|
||||
from agent_framework import ag_ui
|
||||
|
||||
assert hasattr(ag_ui, "AGUIEventConverter")
|
||||
assert hasattr(ag_ui, "AGUIHttpService")
|
||||
assert hasattr(ag_ui, "__version__")
|
||||
|
||||
@@ -2,14 +2,20 @@
|
||||
|
||||
"""Tests for _run_common.py edge cases."""
|
||||
|
||||
from ag_ui.core import EventType
|
||||
from agent_framework import Content
|
||||
|
||||
from agent_framework_ag_ui import state_update
|
||||
from agent_framework_ag_ui._orchestration._predictive_state import PredictiveStateHandler
|
||||
from agent_framework_ag_ui._run_common import (
|
||||
FlowState,
|
||||
_emit_mcp_tool_result,
|
||||
_emit_tool_result,
|
||||
_extract_resume_payload,
|
||||
_extract_tool_result_state,
|
||||
_normalize_resume_interrupts,
|
||||
)
|
||||
from agent_framework_ag_ui._state import TOOL_RESULT_STATE_KEY
|
||||
|
||||
|
||||
class TestNormalizeResumeInterrupts:
|
||||
@@ -120,3 +126,223 @@ class TestEmitToolResult:
|
||||
assert "TEXT_MESSAGE_END" in event_types
|
||||
assert flow.message_id is None
|
||||
assert flow.accumulated_text == ""
|
||||
|
||||
|
||||
class TestStateUpdateHelper:
|
||||
"""Tests for the public ``state_update`` helper."""
|
||||
|
||||
def test_builds_text_content_with_state_marker(self):
|
||||
"""state_update returns a text Content carrying state in additional_properties."""
|
||||
c = state_update(text="done", state={"weather": {"temp": 14}})
|
||||
assert c.type == "text"
|
||||
assert c.text == "done"
|
||||
assert c.additional_properties == {
|
||||
TOOL_RESULT_STATE_KEY: {"weather": {"temp": 14}},
|
||||
}
|
||||
|
||||
def test_empty_text_is_allowed(self):
|
||||
"""State-only tools can omit the text argument."""
|
||||
c = state_update(state={"steps": ["a", "b"]})
|
||||
assert c.text == ""
|
||||
assert c.additional_properties[TOOL_RESULT_STATE_KEY] == {"steps": ["a", "b"]}
|
||||
|
||||
def test_non_mapping_state_raises(self):
|
||||
"""Passing a non-mapping value for state raises TypeError."""
|
||||
import pytest
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
state_update(text="t", state=["not", "a", "mapping"]) # type: ignore[arg-type]
|
||||
|
||||
def test_state_is_copied_defensively(self):
|
||||
"""Mutating the caller's dict after ``state_update`` must not mutate the content."""
|
||||
caller_state = {"weather": {"temp": 14}}
|
||||
c = state_update(text="ok", state=caller_state)
|
||||
caller_state["weather"]["temp"] = 99
|
||||
# The top-level dict was copied, so replacing the key in caller_state
|
||||
# would not affect the Content, but nested dicts share references — document
|
||||
# this by asserting only the top-level copy semantics.
|
||||
assert TOOL_RESULT_STATE_KEY in c.additional_properties
|
||||
inner = c.additional_properties[TOOL_RESULT_STATE_KEY]
|
||||
assert inner is not caller_state
|
||||
|
||||
|
||||
class TestExtractToolResultState:
|
||||
"""Tests for ``_extract_tool_result_state``."""
|
||||
|
||||
def test_returns_none_for_plain_string_result(self):
|
||||
content = Content.from_function_result(call_id="c1", result="plain")
|
||||
assert _extract_tool_result_state(content) is None
|
||||
|
||||
def test_extracts_state_from_inner_item(self):
|
||||
tool_return = state_update(text="hi", state={"k": 1})
|
||||
content = Content.from_function_result(call_id="c1", result=[tool_return])
|
||||
assert _extract_tool_result_state(content) == {"k": 1}
|
||||
|
||||
def test_extracts_state_from_outer_additional_properties(self):
|
||||
"""Outer function_result content can also carry state (legacy/advanced use)."""
|
||||
content = Content.from_function_result(
|
||||
call_id="c1",
|
||||
result="hi",
|
||||
additional_properties={TOOL_RESULT_STATE_KEY: {"k": 1}},
|
||||
)
|
||||
assert _extract_tool_result_state(content) == {"k": 1}
|
||||
|
||||
def test_merges_multiple_items(self):
|
||||
a = state_update(text="a", state={"k": 1, "shared": "from_a"})
|
||||
b = state_update(text="b", state={"shared": "from_b", "extra": True})
|
||||
content = Content.from_function_result(call_id="c1", result=[a, b])
|
||||
merged = _extract_tool_result_state(content)
|
||||
assert merged == {"k": 1, "shared": "from_b", "extra": True}
|
||||
|
||||
def test_ignores_non_dict_marker_value(self):
|
||||
"""A garbled marker value must not break extraction (defensive guard)."""
|
||||
bad = Content.from_text(
|
||||
"hi",
|
||||
additional_properties={TOOL_RESULT_STATE_KEY: "not-a-dict"},
|
||||
)
|
||||
content = Content.from_function_result(call_id="c1", result=[bad])
|
||||
assert _extract_tool_result_state(content) is None
|
||||
|
||||
|
||||
class TestEmitToolResultWithState:
|
||||
"""Tests for the deterministic state emission in ``_emit_tool_result``."""
|
||||
|
||||
def test_emits_state_snapshot_after_tool_call_result(self):
|
||||
"""Tool returning state_update produces a StateSnapshotEvent right after the result."""
|
||||
tool_return = state_update(
|
||||
text="Weather: 14°C",
|
||||
state={"weather": {"temp": 14, "conditions": "foggy"}},
|
||||
)
|
||||
content = Content.from_function_result(call_id="call_1", result=[tool_return])
|
||||
flow = FlowState()
|
||||
|
||||
events = _emit_tool_result(content, flow)
|
||||
event_types = [e.type for e in events]
|
||||
|
||||
# Expect TOOL_CALL_END, TOOL_CALL_RESULT, STATE_SNAPSHOT in that order.
|
||||
assert event_types[0] == EventType.TOOL_CALL_END
|
||||
assert event_types[1] == EventType.TOOL_CALL_RESULT
|
||||
state_idx = event_types.index(EventType.STATE_SNAPSHOT)
|
||||
assert state_idx == 2
|
||||
assert events[state_idx].snapshot == {"weather": {"temp": 14, "conditions": "foggy"}}
|
||||
|
||||
def test_updates_flow_current_state(self):
|
||||
tool_return = state_update(text="", state={"a": 1})
|
||||
content = Content.from_function_result(call_id="c1", result=[tool_return])
|
||||
flow = FlowState(current_state={"existing": "value"})
|
||||
|
||||
_emit_tool_result(content, flow)
|
||||
|
||||
# Existing keys must survive (merge semantics), new keys must be added.
|
||||
assert flow.current_state == {"existing": "value", "a": 1}
|
||||
|
||||
def test_merge_overrides_existing_key(self):
|
||||
tool_return = state_update(text="", state={"existing": "new"})
|
||||
content = Content.from_function_result(call_id="c1", result=[tool_return])
|
||||
flow = FlowState(current_state={"existing": "old", "other": 1})
|
||||
|
||||
_emit_tool_result(content, flow)
|
||||
|
||||
assert flow.current_state == {"existing": "new", "other": 1}
|
||||
|
||||
def test_no_state_snapshot_when_result_has_no_state(self):
|
||||
"""Plain tool results must not emit a StateSnapshotEvent."""
|
||||
content = Content.from_function_result(call_id="c1", result="plain")
|
||||
flow = FlowState()
|
||||
|
||||
events = _emit_tool_result(content, flow)
|
||||
assert all(e.type != EventType.STATE_SNAPSHOT for e in events)
|
||||
|
||||
def test_tool_result_content_text_unchanged(self):
|
||||
"""The text sent to the LLM must not leak the state marker."""
|
||||
tool_return = state_update(text="Weather: 14°C", state={"weather": {"temp": 14}})
|
||||
content = Content.from_function_result(call_id="c1", result=[tool_return])
|
||||
flow = FlowState()
|
||||
|
||||
events = _emit_tool_result(content, flow)
|
||||
result_events = [e for e in events if e.type == EventType.TOOL_CALL_RESULT]
|
||||
assert len(result_events) == 1
|
||||
assert result_events[0].content == "Weather: 14°C"
|
||||
assert TOOL_RESULT_STATE_KEY not in result_events[0].content
|
||||
|
||||
def test_coexists_with_active_predictive_state_handler(self):
|
||||
"""Both predictive and deterministic state produce a single coalesced snapshot.
|
||||
|
||||
Predictive state (``predict_state_config``) and deterministic state
|
||||
(``state_update``) are two independent mechanisms. When both are active,
|
||||
a single coalesced ``StateSnapshotEvent`` is emitted containing the
|
||||
merged result of both contributions.
|
||||
"""
|
||||
flow = FlowState(current_state={"preexisting": "value"})
|
||||
handler = PredictiveStateHandler(
|
||||
predict_state_config={"draft": {"tool": "write_draft", "tool_argument": "body"}},
|
||||
current_state=flow.current_state,
|
||||
)
|
||||
|
||||
tool_return = state_update(text="Draft written", state={"draft_final": True})
|
||||
content = Content.from_function_result(call_id="c1", result=[tool_return])
|
||||
|
||||
events = _emit_tool_result(content, flow, predictive_handler=handler)
|
||||
|
||||
# Exactly one coalesced snapshot must be emitted containing all merged keys.
|
||||
snapshots = [e for e in events if e.type == EventType.STATE_SNAPSHOT]
|
||||
assert len(snapshots) == 1
|
||||
assert snapshots[0].snapshot["draft_final"] is True
|
||||
assert snapshots[0].snapshot["preexisting"] == "value"
|
||||
assert flow.current_state["draft_final"] is True
|
||||
assert flow.current_state["preexisting"] == "value"
|
||||
|
||||
def test_predictive_and_deterministic_emit_single_snapshot(self):
|
||||
"""When both predictive_handler and state_update are active, only one snapshot is emitted."""
|
||||
flow = FlowState(current_state={"existing": "yes"})
|
||||
handler = PredictiveStateHandler(
|
||||
predict_state_config={"draft": {"tool": "write_draft", "tool_argument": "body"}},
|
||||
current_state=flow.current_state,
|
||||
)
|
||||
|
||||
tool_return = state_update(text="ok", state={"new_key": 42})
|
||||
content = Content.from_function_result(call_id="c1", result=[tool_return])
|
||||
|
||||
events = _emit_tool_result(content, flow, predictive_handler=handler)
|
||||
|
||||
snapshots = [e for e in events if e.type == EventType.STATE_SNAPSHOT]
|
||||
assert len(snapshots) == 1, f"Expected 1 coalesced snapshot, got {len(snapshots)}"
|
||||
assert snapshots[0].snapshot == {"existing": "yes", "new_key": 42}
|
||||
|
||||
|
||||
class TestEmitMcpToolResultWithState:
|
||||
"""MCP tool results should honour the same state_update marker.
|
||||
|
||||
MCP results come from an external MCP server rather than a locally
|
||||
executed ``@tool`` function, so they do not flow through ``parse_result``
|
||||
and ``content.items`` is typically empty. State is instead carried on the
|
||||
outer content's ``additional_properties`` (e.g. by middleware that
|
||||
inspects the MCP output and attaches a marker). ``_extract_tool_result_state``
|
||||
supports both locations so this path remains usable.
|
||||
"""
|
||||
|
||||
def test_mcp_tool_result_emits_state_snapshot_from_additional_properties(self):
|
||||
content = Content.from_mcp_server_tool_result(
|
||||
call_id="mcp_1",
|
||||
output="server result",
|
||||
additional_properties={TOOL_RESULT_STATE_KEY: {"mcp_ok": True}},
|
||||
)
|
||||
flow = FlowState()
|
||||
|
||||
events = _emit_mcp_tool_result(content, flow)
|
||||
event_types = [e.type for e in events]
|
||||
|
||||
assert EventType.TOOL_CALL_END in event_types
|
||||
assert EventType.TOOL_CALL_RESULT in event_types
|
||||
assert EventType.STATE_SNAPSHOT in event_types
|
||||
assert flow.current_state == {"mcp_ok": True}
|
||||
|
||||
def test_mcp_tool_result_without_state_emits_no_snapshot(self):
|
||||
content = Content.from_mcp_server_tool_result(
|
||||
call_id="mcp_1",
|
||||
output="server result",
|
||||
)
|
||||
flow = FlowState()
|
||||
|
||||
events = _emit_mcp_tool_result(content, flow)
|
||||
assert all(e.type != EventType.STATE_SNAPSHOT for e in events)
|
||||
|
||||
@@ -7,10 +7,13 @@ This module lazily re-exports objects from:
|
||||
|
||||
Supported classes and functions:
|
||||
- AgentFrameworkAgent
|
||||
- AgentFrameworkWorkflow
|
||||
- AGUIChatClient
|
||||
- AGUIEventConverter
|
||||
- AGUIHttpService
|
||||
- add_agent_framework_fastapi_endpoint
|
||||
- state_update
|
||||
- __version__
|
||||
"""
|
||||
|
||||
import importlib
|
||||
@@ -23,6 +26,10 @@ _IMPORTS = [
|
||||
"AgentFrameworkWorkflow",
|
||||
"add_agent_framework_fastapi_endpoint",
|
||||
"AGUIChatClient",
|
||||
"AGUIEventConverter",
|
||||
"AGUIHttpService",
|
||||
"state_update",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from agent_framework_ag_ui import (
|
||||
AGUIHttpService,
|
||||
__version__,
|
||||
add_agent_framework_fastapi_endpoint,
|
||||
state_update,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -18,4 +19,5 @@ __all__ = [
|
||||
"AgentFrameworkWorkflow",
|
||||
"__version__",
|
||||
"add_agent_framework_fastapi_endpoint",
|
||||
"state_update",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user