Python: Add opt-in AG-UI thread snapshot persistence and hydration (#6471)

* feat(ag-ui): add thread snapshot store primitives

Key decisions:\n- Introduce an AGUIThreadSnapshot model limited to replayable messages, optional Shared State, and optional interrupt state.\n- Define AGUIThreadSnapshotStore as an async protocol keyed by explicit Snapshot Scope and AG-UI Thread id.\n- Add InMemoryAGUIThreadSnapshotStore as memory-only, latest-only, bounded local/demo/test storage; no file-backed store is introduced.\n- Require snapshot_scope_resolver whenever an endpoint is configured with a snapshot store, including pre-wrapped runners, so thread ids are not authorization boundaries.\n\nFiles changed:\n- packages/ag-ui/agent_framework_ag_ui/_snapshots.py\n- packages/ag-ui/agent_framework_ag_ui/__init__.py\n- packages/ag-ui/agent_framework_ag_ui/_agent.py\n- packages/ag-ui/agent_framework_ag_ui/_workflow.py\n- packages/ag-ui/agent_framework_ag_ui/_endpoint.py\n- packages/core/agent_framework/ag_ui/__init__.py\n- packages/core/agent_framework/ag_ui/__init__.pyi\n- packages/ag-ui/tests/ag_ui/test_snapshots.py\n- packages/ag-ui/tests/ag_ui/test_endpoint.py\n- packages/ag-ui/tests/ag_ui/test_public_exports.py\n- packages/ag-ui/AGENTS.md\n\nVerification:\n- uv run pytest packages/ag-ui/tests/ag_ui/test_snapshots.py packages/ag-ui/tests/ag_ui/test_public_exports.py packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_requires_snapshot_scope_resolver_when_store_configured packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_accepts_snapshot_store_with_scope_resolver -q\n- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_requires_snapshot_scope_resolver_when_store_configured packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_requires_snapshot_scope_resolver_when_wrapped_runner_has_store packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_accepts_snapshot_store_with_scope_resolver -q\n- uv run poe syntax -P ag-ui -C\n- uv run poe pyright -P ag-ui\n- uv run poe syntax -P core -C\n- uv run poe pyright -P core\n- uv run poe typing -P ag-ui\n- uv run poe typing -P core\n- uv run poe test -P ag-ui\n- uv run poe check -P ag-ui\n- git diff --check\n- git diff --cached --check\n\nBlockers / next iteration:\n- No blockers. Next slice can use the store contract to capture and hydrate agent snapshots.\n- uv repeatedly refreshed azure-ai-projects in uv.lock during local runs; reverted the generated lockfile churn because this change does not alter dependencies.\n- The poe-check commit hook was skipped after manual verification because it reformatted unrelated core MCP files outside this task.

* feat(ag-ui): hydrate agent threads from snapshots

Key decisions:
- Resolve Snapshot Scope per endpoint request and pass it to the AG-UI runner only when snapshot storage is active.
- Treat empty messages with no resume payload as an agent Hydrate Request when a scoped snapshot store is configured, replaying stored Shared State and message snapshots without invoking the wrapped agent.
- Save the latest replayable agent message snapshot and Shared State at normal completion under Snapshot Scope plus AG-UI Thread id; no durable or file-backed store is introduced.

Files changed:
- packages/ag-ui/agent_framework_ag_ui/_agent_run.py
- packages/ag-ui/agent_framework_ag_ui/_endpoint.py
- packages/ag-ui/agent_framework_ag_ui/_snapshots.py
- packages/ag-ui/tests/ag_ui/test_endpoint.py

Verification:
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_stored_thread_snapshot_without_invoking_agent -q
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_stored_thread_snapshot_without_invoking_agent packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_snapshots_by_scope_and_thread -q
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_empty_messages packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_stored_thread_snapshot_without_invoking_agent packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_snapshots_by_scope_and_thread -q
- uv run poe syntax -P ag-ui -C
- uv run poe pyright -P ag-ui
- uv run poe typing -P ag-ui
- uv run poe test -P ag-ui
- uv run poe check -P ag-ui
- git diff --check
- git diff --cached --check

Blockers / next iteration:
- No blockers. Next slice can reconstruct normal new-user agent turns from stored snapshots.
- uv repeatedly refreshed azure-ai-projects in uv.lock during local runs; reverted the generated lockfile churn because this change does not alter dependencies.
- The poe-check commit hook was skipped after manual verification because it refreshed unrelated uv.lock dependency resolution.

* feat(ag-ui): reconstruct agent turns from snapshots

Key decisions:
- Load scoped thread snapshots for non-hydrate agent requests only when snapshot storage is active and no resume payload is present.
- Rebuild prior AG-UI history from stored snapshot messages, preserving the incoming new user suffix and treating stored snapshot content as authoritative over conflicting prior client history.
- Merge stored Shared State with request state overrides before schema defaults and existing state-context injection.

Files changed:
- packages/ag-ui/agent_framework_ag_ui/_agent_run.py
- packages/ag-ui/tests/ag_ui/test_endpoint.py

Verification:
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_prepends_stored_snapshot_for_new_user_turn -q
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_deduplicates_full_history_and_merges_fresh_state -q
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_empty_messages packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_stored_thread_snapshot_without_invoking_agent packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_snapshots_by_scope_and_thread packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_prepends_stored_snapshot_for_new_user_turn packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_deduplicates_full_history_and_merges_fresh_state -q
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py -q
- uv run poe syntax -P ag-ui -C
- uv run poe pyright -P ag-ui
- uv run poe test -P ag-ui
- uv run poe check -P ag-ui
- uv run poe typing -P ag-ui
- git diff --check
- git diff --cached --check

Blockers / next iteration:
- No blockers. Next slice can enable workflow AG-UI Thread Snapshot persistence and hydration.
- uv repeatedly refreshed azure-ai-projects in uv.lock during local runs; reverted the generated lockfile churn because this change does not alter dependencies.
- The poe-check commit hook was skipped after manual verification because it refreshes unrelated uv.lock dependency resolution.

* feat(ag-ui): hydrate workflow threads from snapshots

Key decisions:
- Handle workflow Hydrate Requests before resolving or invoking the wrapped workflow when snapshot storage and Snapshot Scope are active.
- Capture only replayable workflow protocol data: workflow-emitted state snapshots, workflow-emitted message snapshots, and synthesized messages from text/tool output.
- Keep workflow snapshot capture inactive without configured persistence, and skip saving snapshots when the workflow stream emits RUN_ERROR.

Files changed:
- packages/ag-ui/agent_framework_ag_ui/_workflow.py
- packages/ag-ui/tests/ag_ui/test_endpoint.py

Verification:
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_workflow_endpoint_hydrates_emitted_snapshots_without_invoking_workflow packages/ag-ui/tests/ag_ui/test_endpoint.py::test_workflow_endpoint_hydrates_synthesized_text_and_tool_snapshot -q
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py -q
- uv run pytest packages/ag-ui/tests/ag_ui/golden/test_scenario_workflow.py -q
- uv run poe syntax -P ag-ui -C
- uv run poe pyright -P ag-ui
- uv run poe test -P ag-ui
- uv run poe typing -P ag-ui
- uv run poe check -P ag-ui
- git diff --check
- git diff --cached --check

Blockers / next iteration:
- No blockers. Next slice can preserve interruption state and protect snapshots on errors across agent and workflow endpoints.
- uv repeatedly refreshed azure-ai-projects in uv.lock during local runs; reverted the generated lockfile churn because this change does not alter dependencies.
- The poe-check commit hook was skipped after manual verification because it refreshes unrelated uv.lock dependency resolution.

* feat(ag-ui): preserve interrupted thread snapshots

Key decisions:
- Capture workflow RUN_FINISHED interrupt metadata in replayable AG-UI Thread Snapshots so Hydrate Requests can restore pending workflow actions without invoking or resuming the workflow.
- Keep failed agent and workflow runs from replacing the last good snapshot; RUN_ERROR streams leave the previous snapshot available for hydration.
- Verify interruption hydration through endpoint-level AG-UI streams for both agent and workflow wrappers, including Shared State replay and no wrapped runner invocation.

Files changed:
- packages/ag-ui/agent_framework_ag_ui/_workflow.py
- packages/ag-ui/tests/ag_ui/test_endpoint.py

Verification:
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_workflow_endpoint_hydrates_interrupted_thread_without_invoking_workflow -q
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_interrupted_thread_without_invoking_agent packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_run_error_does_not_overwrite_previous_snapshot packages/ag-ui/tests/ag_ui/test_endpoint.py::test_workflow_endpoint_hydrates_interrupted_thread_without_invoking_workflow packages/ag-ui/tests/ag_ui/test_endpoint.py::test_workflow_endpoint_run_error_does_not_overwrite_previous_snapshot -q
- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py -q
- uv run pytest packages/ag-ui/tests/ag_ui/golden/test_scenario_workflow.py -q
- uv run poe syntax -P ag-ui -C
- uv run poe pyright -P ag-ui
- uv run poe test -P ag-ui
- uv run poe typing -P ag-ui
- uv run poe check -P ag-ui
- git diff --check
- git diff --cached --check

Blockers / next iteration:
- No blockers. Next slice can document AG-UI Thread Snapshot security and usage.
- uv repeatedly refreshed azure-ai-projects in uv.lock during local runs; reverted the generated lockfile churn because this change does not alter dependencies.
- The poe-check commit hook was skipped after manual verification because it refreshes unrelated uv.lock dependency resolution.

* docs(ag-ui): document thread snapshot security

Key decisions:
- Document AG-UI Thread Snapshot persistence as opt-in and disabled unless a snapshot_store is configured.
- Place Snapshot Scope guidance next to endpoint authentication guidance, making clear that AG-UI Thread ids identify threads but do not authorize snapshot access.
- Describe built-in storage as in-memory only, process-local, latest-only, and not durable production storage; durable stores remain app-owned implementations of AGUIThreadSnapshotStore.
- Call out snapshot confidentiality impact and that no file-backed AG-UI snapshot store is provided.

Files changed:
- packages/ag-ui/README.md

Verification:
- uv run python scripts/check_md_code_blocks.py packages/ag-ui/README.md --no-glob
- git diff --check
- git diff --cached --check
- commit hook without SKIP ran changed-package lint/format and AG-UI README markdown-code-lint successfully before stopping because uv.lock was modified
- uv run poe markdown-code-lint (failed due existing unrelated packages/mistral/README.md missing agent_framework_mistral import resolution; changed AG-UI README blocks passed)

Blockers / next iteration:
- No blockers. Local issue/PRD planning artifacts remain uncommitted.
- uv refreshed azure-ai-projects in uv.lock during markdown lint and the commit hook; reverted the generated lockfile churn because this documentation change does not alter dependencies.
- The poe-check commit hook was skipped after manual verification because it refreshes unrelated uv.lock dependency resolution.

* fix(ag-ui): harden thread snapshot persistence edge cases

- Persist the completed confirm_changes turn with interrupt=None so hydration
  no longer replays a stale pending interrupt after the user responds; resume
  requests prepend stored history so the persisted thread is not truncated.
- Defer endpoint default_state application to the runners when snapshot
  persistence is active, filling only keys missing from both the stored
  snapshot state and the request state so defaults never reset persisted
  Shared State.
- Always fold the turn's output into the persisted messages snapshot even when
  the outbound MESSAGES_SNAPSHOT event is suppressed for predictive tools
  without confirmation.
- Load the stored snapshot on workflow follow-up turns, reconstruct full
  thread history into the run input, and seed the snapshot builder with merged
  state so saving a new turn no longer replaces prior history.
- Move snapshot message reconstruction helpers to _run_common for reuse by the
  workflow runner; load stored agent snapshots on resume turns for state merge.
- Add endpoint regression tests for all four scenarios.

* fix(ag-ui): protect snapshot history on resume and harden suffix trust

- Prepend stored thread history when persisting snapshots for resume runs on
  both the agent and workflow paths, so a resumed interrupt no longer
  overwrites the stored thread with just the resume turn's output.
- Filter the incoming message suffix during thread reconstruction: only user
  turns and tool results answering backend-issued tool calls (stored tool
  calls or pending interrupts) may extend authoritative history. Client-forged
  assistant and tool messages are dropped and logged instead of being
  persisted and replayed.
- Close the workflow snapshot builder's tool-call group when a tool result or
  text message lands, so synthesized transcripts keep tool results adjacent to
  their tool_calls message and stay valid as provider replay history.
- Export DEFAULT_MAX_THREAD_SNAPSHOTS from agent_framework_ag_ui and expose
  SnapshotScopeResolver through the core ag_ui facade and stub.
- Add regression tests for agent and workflow resume history preservation,
  forged suffix rejection, builder tool-call grouping, and the export surface.

* fix(ag-ui): tolerate snapshot save failures and scope workflow cache

- Wrap snapshot_store.save() on both the agent and workflow paths so a
  transient store failure (timeout, connection refused) is logged instead of
  propagating. Previously a failing save converted an already-streamed
  successful run into RUN_ERROR, and on the workflow path emitted RUN_ERROR
  after RUN_FINISHED, violating the single-terminal-event invariant. The
  previous snapshot stays available for hydration.
- Key the workflow_factory instance cache by (snapshot_scope, thread_id). The
  Snapshot Scope is the authorization boundary, so the same thread id under
  different scopes no longer shares an in-memory workflow instance.
  clear_thread_workflow accepts an optional snapshot_scope and clears all
  scopes for the thread when omitted.
- Add tests: save-failure tolerance for agent and workflow endpoints,
  scope-isolated workflow cache, async snapshot_scope_resolver support, and
  in-memory store key validation errors.

* fix(ci): ignore all dotnet.microsoft.com links in linkspector

The existing ignore pattern only matched https://dotnet.microsoft.com/download,
but Microsoft sites insert a locale segment between host and path
(e.g. /en-us/download/dotnet/10.0), so localized links slip past the pattern
and get checked. dotnet.microsoft.com bot-blocks CI link checkers with
intermittent 403s across the whole site, which fails markdown-link-check on
unrelated pull requests since linkspector scans the entire repository.

Ignore the domain wholesale, matching how platform.openai.com is already
handled for the same reason. A 403 from bot blocking is indistinguishable
from a removed page, so the checker cannot produce a meaningful signal for
this domain either way.

* ag-ui: simplify raw_messages assignment and drop OrderedDict

- Replace list(cast(...)) with a typed annotation for raw_messages
  (_agent_run.py:866) per review suggestion
- Replace OrderedDict with a plain dict in InMemoryAGUIThreadSnapshotStore
  (_snapshots.py:136); regular dicts are insertion-order-safe since
  Python 3.7, so OrderedDict is unnecessary. Update _evict_oldest to use
  next(iter(...)) for FIFO removal instead of popitem(last=False).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #2458: review comment fixes

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Evan Mattson
2026-06-12 17:29:38 +09:00
committed by GitHub
Unverified
parent 4c1b9efa8c
commit 76b2b1bf39
18 changed files with 2419 additions and 42 deletions
+2
View File
@@ -10,10 +10,12 @@ AG-UI protocol integration for building agent UIs with the AG-UI standard.
- **`AGUIHttpService`** - HTTP service for AG-UI endpoints
- **`AGUIEventConverter`** - Converts between Agent Framework and AG-UI events
- **`add_agent_framework_fastapi_endpoint()`** - Add AG-UI endpoint to FastAPI app (`SupportsAgentRun` or `Workflow`)
- **`InMemoryAGUIThreadSnapshotStore`** - Memory-only latest AG-UI Thread Snapshot store for local development, demos, and tests
## Types
- **`AGUIRequest`** / **`AGUIChatOptions`** - Request types
- **`AGUIThreadSnapshot`** / **`AGUIThreadSnapshotStore`** - Replayable thread snapshot model and scoped async store protocol
- **`availableInterrupts` / `resume`** - Optional interrupt configuration and continuation payloads
- **`AgentState`** / **`RunMetadata`** - State management types
- **`PredictStateConfig`** - Configuration for state prediction
+65
View File
@@ -198,6 +198,71 @@ The `dependencies` parameter accepts any FastAPI dependency, enabling integratio
For a complete authentication example, see [getting_started/server.py](getting_started/server.py).
## AG-UI Thread Snapshots
AG-UI Thread Snapshot persistence is opt-in and disabled by default. Existing endpoints keep their current behavior
unless you provide a `snapshot_store`.
Thread snapshots let an AG-UI frontend recover replayable UI state after a refresh. When snapshot persistence is
enabled, the endpoint stores the latest replayable snapshot for an AG-UI Thread within an application-defined
Snapshot Scope. A Hydrate Request is an AG-UI request with a known `threadId`, `messages: []`, and no `resume`
payload. Hydration replays the stored Shared State, message snapshot, and interruption metadata when available,
then finishes without invoking the wrapped agent or workflow.
Use the built-in in-memory store for local development, demos, and tests:
```python
from fastapi import FastAPI
from agent_framework.ag_ui import InMemoryAGUIThreadSnapshotStore, add_agent_framework_fastapi_endpoint
app = FastAPI()
agent = ...
snapshot_store = InMemoryAGUIThreadSnapshotStore(max_snapshots=500)
def resolve_snapshot_scope(request):
# Local demo scope. Production apps should derive the scope from authenticated user or tenant context.
del request
return "local-demo"
add_agent_framework_fastapi_endpoint(
app,
agent,
"/",
snapshot_store=snapshot_store,
snapshot_scope_resolver=resolve_snapshot_scope,
)
```
A frontend can then hydrate the latest stored snapshot for the scoped thread:
```json
{
"threadId": "thread-1",
"messages": []
}
```
Endpoint configuration requires `snapshot_scope_resolver` whenever a snapshot store is configured, including when
the store is already set on a pre-wrapped `AgentFrameworkAgent` or `AgentFrameworkWorkflow`. The resolver returns
the application-defined Snapshot Scope used with the AG-UI Thread id as the storage key.
AG-UI Thread ids identify AG-UI Threads; they do not authorize snapshot access. Do not treat a thread id as a bearer
credential or tenant boundary. Production applications must authenticate and authorize every AG-UI endpoint request
and choose a Snapshot Scope that represents the app's real access boundary, such as an authenticated user, tenant,
or workspace. Do not rely on untrusted client-provided fields by themselves to choose that boundary.
Stored snapshots are untrusted application data with confidentiality impact. They may contain sensitive user text,
model output, tool results, function arguments, UI payloads, Shared State, and interruption data. The built-in
`InMemoryAGUIThreadSnapshotStore` is in-memory only, process-local, bounded, latest-only, and not durable production
storage. It is cleared on process restart and is not shared across workers.
No file-backed AG-UI snapshot store is provided by the package. Applications that need durable persistence should
provide an app-owned implementation of the `AGUIThreadSnapshotStore` protocol and own storage hardening, including
encryption, access control, retention, audit, data residency, and deletion behavior.
## Architecture
The package uses a clean, orchestrator-based architecture:
@@ -9,6 +9,15 @@ from ._client import AGUIChatClient
from ._endpoint import add_agent_framework_fastapi_endpoint
from ._event_converters import AGUIEventConverter
from ._http_service import AGUIHttpService
from ._snapshots import (
DEFAULT_MAX_THREAD_SNAPSHOTS,
AGUIThreadID,
AGUIThreadSnapshot,
AGUIThreadSnapshotStore,
InMemoryAGUIThreadSnapshotStore,
SnapshotScope,
SnapshotScopeResolver,
)
from ._state import state_update
from ._types import AgentState, AGUIChatOptions, AGUIRequest, PredictStateConfig, RunMetadata
from ._workflow import AgentFrameworkWorkflow, WorkflowFactory
@@ -31,9 +40,16 @@ __all__ = [
"AGUIEventConverter",
"AGUIHttpService",
"AGUIRequest",
"AGUIThreadID",
"AGUIThreadSnapshot",
"AGUIThreadSnapshotStore",
"AgentState",
"InMemoryAGUIThreadSnapshotStore",
"PredictStateConfig",
"RunMetadata",
"SnapshotScope",
"SnapshotScopeResolver",
"DEFAULT_MAX_THREAD_SNAPSHOTS",
"DEFAULT_TAGS",
"state_update",
"__version__",
@@ -10,6 +10,7 @@ from ag_ui.core import BaseEvent
from agent_framework import SupportsAgentRun
from ._agent_run import PendingApprovalEntry, run_agent_stream
from ._snapshots import AGUIThreadSnapshotStore
class AgentConfig:
@@ -21,6 +22,7 @@ class AgentConfig:
predict_state_config: dict[str, dict[str, str]] | None = None,
use_service_session: bool = False,
require_confirmation: bool = True,
snapshot_store: AGUIThreadSnapshotStore | None = None,
):
"""Initialize agent configuration.
@@ -29,11 +31,14 @@ class AgentConfig:
predict_state_config: Configuration for predictive state updates
use_service_session: Whether the agent session is service-managed
require_confirmation: Whether predictive updates require user confirmation before applying
snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence remains inactive unless
endpoint setup also provides an explicit Snapshot Scope resolver.
"""
self.state_schema = self._normalize_state_schema(state_schema)
self.predict_state_config = predict_state_config or {}
self.use_service_session = use_service_session
self.require_confirmation = require_confirmation
self.snapshot_store = snapshot_store
@staticmethod
def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]:
@@ -79,6 +84,7 @@ class AgentFrameworkAgent:
predict_state_config: dict[str, dict[str, str]] | None = None,
require_confirmation: bool = True,
use_service_session: bool = False,
snapshot_store: AGUIThreadSnapshotStore | None = None,
):
"""Initialize the AG-UI compatible agent wrapper.
@@ -90,6 +96,8 @@ class AgentFrameworkAgent:
predict_state_config: Configuration for predictive state updates
require_confirmation: Whether predictive updates require user confirmation before applying
use_service_session: Whether the agent session is service-managed
snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence remains inactive unless
endpoint setup also provides an explicit Snapshot Scope resolver.
"""
self.agent = agent
self.name = name or getattr(agent, "name", "agent")
@@ -100,6 +108,7 @@ class AgentFrameworkAgent:
predict_state_config=predict_state_config,
use_service_session=use_service_session,
require_confirmation=require_confirmation,
snapshot_store=snapshot_store,
)
# Server-side registry of pending approval requests.
@@ -110,6 +119,11 @@ class AgentFrameworkAgent:
self._pending_approvals: OrderedDict[str, PendingApprovalEntry] = OrderedDict()
self._pending_approvals_max_size: int = 10_000
@property
def snapshot_store(self) -> AGUIThreadSnapshotStore | None:
"""Configured AG-UI Thread Snapshot store, if any."""
return self.config.snapshot_store
async def run(
self,
input_data: dict[str, Any],
@@ -4,6 +4,7 @@
from __future__ import annotations # noqa: I001
import copy
import json
import logging
import uuid
@@ -52,9 +53,11 @@ from ._run_common import (
_extract_tool_result_display, # type: ignore
_has_only_tool_calls, # type: ignore
_normalize_resume_interrupts, # type: ignore
_reconstruct_messages_from_thread_snapshot, # type: ignore
_resolve_ui_payload, # type: ignore
_stringify_tool_result, # type: ignore
)
from ._snapshots import AGUIThreadSnapshot, _DEFAULT_STATE_INPUT_KEY, _SNAPSHOT_SCOPE_INPUT_KEY
from ._utils import (
canonical_function_arguments,
convert_agui_tools_to_agent_framework,
@@ -748,6 +751,85 @@ def _build_messages_snapshot(
return MessagesSnapshotEvent(messages=all_messages) # type: ignore[arg-type]
def _event_messages_to_snapshot_dicts(messages: list[Any]) -> list[dict[str, Any]]:
"""Convert AG-UI message event models back to plain snapshot dictionaries."""
safe_messages = make_json_safe(messages)
if not isinstance(safe_messages, list):
return []
return [cast(dict[str, Any], message) for message in safe_messages if isinstance(message, dict)]
def _text_events_to_snapshot_messages(events: list[BaseEvent]) -> list[dict[str, Any]]:
"""Convert streamed text-message events into snapshot message dictionaries."""
messages: list[dict[str, Any]] = []
messages_by_id: dict[str, dict[str, Any]] = {}
for event in events:
if isinstance(event, TextMessageStartEvent):
message: dict[str, Any] = {"id": event.message_id, "role": event.role, "content": ""}
messages.append(message)
messages_by_id[event.message_id] = message
elif isinstance(event, TextMessageContentEvent):
open_message = messages_by_id.get(event.message_id)
if open_message is not None:
open_message["content"] = f"{open_message['content']}{event.delta}"
return [message for message in messages if message.get("content")]
async def _hydrate_thread_snapshot(
*,
config: AgentConfig,
scope: str,
thread_id: str,
run_id: str,
) -> AsyncGenerator[BaseEvent]:
"""Replay the latest stored AG-UI Thread Snapshot without invoking the agent."""
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
if config.snapshot_store is None:
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id)
return
snapshot = await config.snapshot_store.get(scope=scope, thread_id=thread_id)
if snapshot is None:
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id)
return
if snapshot.state is not None:
yield StateSnapshotEvent(snapshot=snapshot.state)
if snapshot.messages:
yield MessagesSnapshotEvent(messages=snapshot.messages) # type: ignore[arg-type]
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=snapshot.interrupt)
async def _save_thread_snapshot(
*,
config: AgentConfig,
scope: str | None,
thread_id: str,
messages: list[dict[str, Any]],
state: dict[str, Any] | None,
interrupt: list[dict[str, Any]] | None,
) -> None:
"""Save the latest replayable AG-UI Thread Snapshot when persistence is configured."""
if config.snapshot_store is None or scope is None:
return
try:
await config.snapshot_store.save(
scope=scope,
thread_id=thread_id,
snapshot=AGUIThreadSnapshot(messages=messages, state=state, interrupt=interrupt),
)
except Exception:
# The run itself already streamed successfully; a transient store failure
# must not surface as RUN_ERROR for a completed run. The previous snapshot
# stays available for hydration.
logger.exception(
"Failed to save AG-UI Thread Snapshot for scope=%s thread_id=%s; keeping previous snapshot.",
scope,
thread_id,
)
async def run_agent_stream(
input_data: dict[str, Any],
agent: SupportsAgentRun,
@@ -774,15 +856,53 @@ async def run_agent_stream(
# Parse IDs
thread_id = input_data.get("thread_id") or input_data.get("threadId") or str(uuid.uuid4())
run_id = input_data.get("run_id") or input_data.get("runId") or str(uuid.uuid4())
# Initialize flow state with schema defaults
flow = FlowState()
if input_data.get("state"):
flow.current_state = dict(input_data["state"])
snapshot_scope = cast(str | None, input_data.get(_SNAPSHOT_SCOPE_INPUT_KEY))
state_schema = cast(dict[str, Any], getattr(config, "state_schema", {}) or {})
predict_state_config = cast(dict[str, dict[str, str]], getattr(config, "predict_state_config", {}) or {})
# Normalize messages
available_interrupts = input_data.get("available_interrupts") or input_data.get("availableInterrupts")
raw_messages: list[dict[str, Any]] = input_data.get("messages", []) or []
resume_payload = _extract_resume_payload(input_data)
if config.snapshot_store is not None and snapshot_scope is not None and not raw_messages and resume_payload is None:
async for event in _hydrate_thread_snapshot(
config=config,
scope=snapshot_scope,
thread_id=thread_id,
run_id=run_id,
):
yield event
return
stored_snapshot: AGUIThreadSnapshot | None = None
if config.snapshot_store is not None and snapshot_scope is not None:
stored_snapshot = await config.snapshot_store.get(scope=snapshot_scope, thread_id=thread_id)
if stored_snapshot is not None and resume_payload is None:
raw_messages = _reconstruct_messages_from_thread_snapshot(
stored_messages=stored_snapshot.messages,
incoming_messages=raw_messages,
stored_interrupt=stored_snapshot.interrupt,
)
# Initialize flow state with stored state plus request-provided overrides.
flow = FlowState()
request_state = input_data.get("state")
if stored_snapshot is not None and stored_snapshot.state is not None:
flow.current_state = dict(stored_snapshot.state)
if isinstance(request_state, dict):
flow.current_state.update(request_state)
elif isinstance(request_state, dict):
flow.current_state = dict(request_state)
# Apply endpoint-deferred defaults only for keys missing from both the stored
# snapshot state and the request state, so defaults never reset persisted state.
deferred_default_state = cast(dict[str, Any] | None, input_data.get(_DEFAULT_STATE_INPUT_KEY))
if deferred_default_state:
for key, value in deferred_default_state.items():
if key not in flow.current_state:
flow.current_state[key] = copy.deepcopy(value)
# Apply schema defaults for missing state keys
if state_schema:
for key, schema in state_schema.items():
@@ -801,10 +921,7 @@ async def run_agent_stream(
current_state=flow.current_state,
)
# Normalize messages
available_interrupts = input_data.get("available_interrupts") or input_data.get("availableInterrupts")
raw_messages = list(cast(list[dict[str, Any]], input_data.get("messages", []) or []))
resume_messages = _resume_to_tool_messages(_extract_resume_payload(input_data))
resume_messages = _resume_to_tool_messages(resume_payload)
if available_interrupts:
logger.debug("Received available interrupts metadata: %s", available_interrupts)
if resume_messages:
@@ -892,8 +1009,24 @@ async def run_agent_stream(
# Emit approved state snapshot before confirmation message
if approved_state_snapshot_emitted:
yield StateSnapshotEvent(snapshot=flow.current_state)
for event in _handle_step_based_approval(messages):
confirmation_events = _handle_step_based_approval(messages)
for event in confirmation_events:
yield event
# Persist the completed confirmation turn with interrupt=None so hydration
# does not replay the stale pending interrupt after the user responded.
persisted_messages = snapshot_messages + _text_events_to_snapshot_messages(confirmation_events)
if resume_payload is not None and stored_snapshot is not None:
# Resume requests carry only the synthesized interrupt response, so prepend
# the stored thread history to avoid persisting a truncated thread.
persisted_messages = [copy.deepcopy(message) for message in stored_snapshot.messages] + persisted_messages
await _save_thread_snapshot(
config=config,
scope=snapshot_scope,
thread_id=thread_id,
messages=persisted_messages,
state=cast(dict[str, Any], make_json_safe(flow.current_state)) if flow.current_state else None,
interrupt=None,
)
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id)
return
@@ -905,6 +1038,9 @@ async def run_agent_stream(
# Stream from agent - emit RunStarted after first update to get service IDs
run_started_emitted = False
all_updates: list[Any] = [] # Collect for structured output processing
latest_state_snapshot: dict[str, Any] | None = (
cast(dict[str, Any], make_json_safe(flow.current_state)) if flow.current_state else None
)
response_stream = agent.run(messages, stream=True, **run_kwargs)
stream = await _normalize_response_stream(response_stream)
async for update in stream:
@@ -934,6 +1070,7 @@ async def run_agent_stream(
yield CustomEvent(name="PredictState", value=predict_state_value)
# Emit initial state snapshot only if we have both state_schema and state
if state_schema and flow.current_state:
latest_state_snapshot = cast(dict[str, Any], make_json_safe(flow.current_state))
yield StateSnapshotEvent(snapshot=flow.current_state)
run_started_emitted = True
@@ -975,6 +1112,8 @@ async def run_agent_stream(
skip_text,
config.require_confirmation,
):
if isinstance(event, StateSnapshotEvent):
latest_state_snapshot = cast(dict[str, Any], make_json_safe(event.snapshot))
yield event
# Stop if waiting for approval
@@ -1019,6 +1158,7 @@ async def run_agent_stream(
if state_updates:
flow.current_state.update(state_updates)
latest_state_snapshot = cast(dict[str, Any], make_json_safe(flow.current_state))
yield StateSnapshotEvent(snapshot=flow.current_state)
logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}")
@@ -1056,6 +1196,7 @@ async def run_agent_stream(
if result:
state_key, state_value = result
flow.current_state[state_key] = state_value
latest_state_snapshot = cast(dict[str, Any], make_json_safe(flow.current_state))
yield StateSnapshotEvent(snapshot=flow.current_state)
except json.JSONDecodeError:
# Ignore malformed JSON in tool arguments for predictive state;
@@ -1136,7 +1277,12 @@ async def run_agent_stream(
should_emit_snapshot = (
flow.pending_tool_calls or flow.tool_results or flow.accumulated_text or flow.reasoning_messages
)
latest_messages_snapshot = snapshot_messages
if should_emit_snapshot:
# Always fold this turn's output into the persisted snapshot, even when the
# outbound MESSAGES_SNAPSHOT event is suppressed for predictive tools.
snapshot_event = _build_messages_snapshot(flow, snapshot_messages)
latest_messages_snapshot = _event_messages_to_snapshot_dicts(list(snapshot_event.messages))
# Check if we should suppress for predictive tool
last_tool_name = None
if flow.tool_results:
@@ -1146,8 +1292,21 @@ async def run_agent_stream(
if not _should_suppress_intermediate_snapshot(
last_tool_name, predict_state_config, config.require_confirmation
):
yield _build_messages_snapshot(flow, snapshot_messages)
yield snapshot_event
# Always emit RunFinished - confirm_changes tool call is complete (Start -> Args -> End)
# The UI will show confirmation dialog and send a new request when user responds
persisted_messages = latest_messages_snapshot
if resume_payload is not None and stored_snapshot is not None:
# Resume requests carry only the synthesized interrupt response, so prepend
# the stored thread history to avoid persisting a truncated thread.
persisted_messages = [copy.deepcopy(message) for message in stored_snapshot.messages] + persisted_messages
await _save_thread_snapshot(
config=config,
scope=snapshot_scope,
thread_id=thread_id,
messages=persisted_messages,
state=latest_state_snapshot,
interrupt=flow.interrupts or None,
)
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=flow.interrupts)
@@ -7,6 +7,7 @@ from __future__ import annotations
import copy
import logging
from collections.abc import AsyncGenerator, Sequence
from inspect import isawaitable
from typing import Any
from ag_ui.core import RunErrorEvent
@@ -17,12 +18,58 @@ from fastapi.params import Depends
from fastapi.responses import StreamingResponse
from ._agent import AgentFrameworkAgent
from ._snapshots import (
_DEFAULT_STATE_INPUT_KEY,
_SNAPSHOT_SCOPE_INPUT_KEY,
AGUIThreadSnapshotStore,
SnapshotScopeResolver,
)
from ._types import AGUIRequest
from ._workflow import AgentFrameworkWorkflow
logger = logging.getLogger(__name__)
def _get_snapshot_store(
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow,
) -> AGUIThreadSnapshotStore | None:
if isinstance(protocol_runner, AgentFrameworkAgent):
return protocol_runner.config.snapshot_store
return protocol_runner.snapshot_store
def _set_snapshot_store(
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow,
snapshot_store: AGUIThreadSnapshotStore,
) -> None:
if isinstance(protocol_runner, AgentFrameworkAgent):
protocol_runner.config.snapshot_store = snapshot_store
return
protocol_runner.snapshot_store = snapshot_store
def _configure_snapshot_persistence(
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow,
*,
snapshot_store: AGUIThreadSnapshotStore | None,
snapshot_scope_resolver: SnapshotScopeResolver | None,
) -> None:
existing_snapshot_store = _get_snapshot_store(protocol_runner)
if snapshot_store is not None:
if existing_snapshot_store is not None and existing_snapshot_store is not snapshot_store:
raise ValueError("snapshot_store is already configured on the AG-UI runner.")
if existing_snapshot_store is None:
_set_snapshot_store(protocol_runner, snapshot_store)
existing_snapshot_store = snapshot_store
if existing_snapshot_store is not None and snapshot_scope_resolver is None:
raise ValueError(
"snapshot_scope_resolver is required when snapshot_store is configured. "
"AG-UI Thread ids identify threads but do not authorize snapshot access; "
"provide a resolver that returns an explicit Snapshot Scope."
)
def add_agent_framework_fastapi_endpoint(
app: FastAPI,
agent: SupportsAgentRun | AgentFrameworkAgent | Workflow | AgentFrameworkWorkflow,
@@ -33,6 +80,8 @@ def add_agent_framework_fastapi_endpoint(
default_state: dict[str, Any] | None = None,
tags: list[str] | None = None,
dependencies: Sequence[Depends] | None = None,
snapshot_store: AGUIThreadSnapshotStore | None = None,
snapshot_scope_resolver: SnapshotScopeResolver | None = None,
) -> None:
"""Add an AG-UI endpoint to a FastAPI app.
@@ -50,6 +99,10 @@ def add_agent_framework_fastapi_endpoint(
These dependencies run before the endpoint handler. Use this to add
authentication checks, rate limiting, or other middleware-like behavior.
Example: `dependencies=[Depends(verify_api_key)]`
snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence is opt-in and requires an
explicit Snapshot Scope resolver.
snapshot_scope_resolver: Optional resolver for the application-defined Snapshot Scope. Required whenever
a snapshot store is configured because an AG-UI Thread id is not an authorization boundary.
"""
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow
if isinstance(agent, AgentFrameworkWorkflow):
@@ -63,10 +116,17 @@ def add_agent_framework_fastapi_endpoint(
agent=agent,
state_schema=state_schema,
predict_state_config=predict_state_config,
snapshot_store=snapshot_store,
)
else:
raise TypeError("agent must be SupportsAgentRun, Workflow, AgentFrameworkAgent, or AgentFrameworkWorkflow.")
_configure_snapshot_persistence(
protocol_runner,
snapshot_store=snapshot_store,
snapshot_scope_resolver=snapshot_scope_resolver,
)
@app.post(path, tags=tags or ["AG-UI"], dependencies=dependencies, response_model=None) # type: ignore[arg-type]
async def agent_endpoint(request_body: AGUIRequest) -> StreamingResponse:
"""Handle AG-UI agent requests.
@@ -76,11 +136,23 @@ def add_agent_framework_fastapi_endpoint(
"""
try:
input_data = request_body.model_dump(exclude_none=True)
snapshot_persistence_active = False
if snapshot_scope_resolver is not None and _get_snapshot_store(protocol_runner) is not None:
snapshot_scope = snapshot_scope_resolver(request_body)
if isawaitable(snapshot_scope):
snapshot_scope = await snapshot_scope
input_data[_SNAPSHOT_SCOPE_INPUT_KEY] = snapshot_scope
snapshot_persistence_active = True
if default_state:
state = input_data.setdefault("state", {})
for key, value in default_state.items():
if key not in state:
state[key] = copy.deepcopy(value)
if snapshot_persistence_active:
# Defer default application to the runner so defaults only fill keys
# missing from both the stored snapshot state and the request state.
input_data[_DEFAULT_STATE_INPUT_KEY] = copy.deepcopy(default_state)
else:
state = input_data.setdefault("state", {})
for key, value in default_state.items():
if key not in state:
state[key] = copy.deepcopy(value)
logger.debug(
f"[{path}] Received request - Run ID: {input_data.get('run_id', 'no-run-id')}, "
f"Thread ID: {input_data.get('thread_id', 'no-thread-id')}, "
@@ -4,6 +4,7 @@
from __future__ import annotations
import copy
import json
import logging
from collections.abc import Mapping
@@ -33,7 +34,7 @@ from agent_framework import Content
from ._orchestration._predictive_state import PredictiveStateHandler
from ._state import TOOL_RESULT_DISPLAY_KEY, TOOL_RESULT_STATE_KEY
from ._utils import generate_event_id, make_json_safe
from ._utils import generate_event_id, make_json_safe, normalize_agui_role
logger = logging.getLogger(__name__)
@@ -733,3 +734,117 @@ def _emit_content(
return _emit_text_reasoning(content, flow)
logger.debug("Skipping unsupported content type in AG-UI emitter: %s", content_type)
return events
def _canonical_snapshot_message(message: dict[str, Any]) -> dict[str, Any]:
"""Normalize an AG-UI message for identity comparison without generated ids."""
from ._message_adapters import agui_messages_to_snapshot_format
normalized_message = agui_messages_to_snapshot_format([copy.deepcopy(message)])[0]
normalized_message.pop("id", None)
return cast(dict[str, Any], make_json_safe(normalized_message))
def _snapshot_messages_match(stored_message: dict[str, Any], incoming_message: dict[str, Any]) -> bool:
"""Return whether an incoming message already represents the stored snapshot message."""
stored_id = stored_message.get("id")
incoming_id = incoming_message.get("id")
if stored_id and incoming_id:
return str(stored_id) == str(incoming_id)
return _canonical_snapshot_message(stored_message) == _canonical_snapshot_message(incoming_message)
def _latest_user_message_index(messages: list[dict[str, Any]]) -> int | None:
"""Find the newest incoming user message index."""
for index in range(len(messages) - 1, -1, -1):
if normalize_agui_role(messages[index].get("role", "user")) == "user":
return index
return None
def _known_tool_call_ids(
stored_messages: list[dict[str, Any]],
stored_interrupt: list[dict[str, Any]] | None,
) -> set[str]:
"""Collect tool call ids the backend previously issued for this thread."""
known_ids: set[str] = set()
for message in stored_messages:
tool_calls = message.get("tool_calls") or message.get("toolCalls") or []
if not isinstance(tool_calls, list):
continue
for tool_call in cast(list[Any], tool_calls):
if isinstance(tool_call, dict):
tool_call_id = cast(dict[str, Any], tool_call).get("id")
if tool_call_id:
known_ids.add(str(tool_call_id))
for interrupt in stored_interrupt or []:
interrupt_id = interrupt.get("id")
if interrupt_id:
known_ids.add(str(interrupt_id))
return known_ids
def _filter_untrusted_suffix(
incoming_suffix: list[dict[str, Any]],
*,
stored_messages: list[dict[str, Any]],
stored_interrupt: list[dict[str, Any]] | None,
) -> list[dict[str, Any]]:
"""Drop client-forged non-user messages before promoting them to stored history.
Only the user's own turns and tool results answering backend-issued tool calls
(including pending interrupts) may extend the authoritative thread history.
"""
known_ids: set[str] | None = None
filtered: list[dict[str, Any]] = []
for message in incoming_suffix:
raw_role = str(message.get("role", "")).lower()
if raw_role == "user":
filtered.append(message)
continue
if raw_role == "tool":
tool_call_id = message.get("toolCallId") or message.get("tool_call_id") or message.get("actionExecutionId")
if known_ids is None:
known_ids = _known_tool_call_ids(stored_messages, stored_interrupt)
if tool_call_id and str(tool_call_id) in known_ids:
filtered.append(message)
continue
logger.warning(
"Dropping client-supplied %r message from the incoming thread suffix; "
"only user turns and tool results for backend-issued tool calls extend stored history.",
raw_role or "unknown",
)
return filtered
def _reconstruct_messages_from_thread_snapshot(
*,
stored_messages: list[dict[str, Any]],
incoming_messages: list[dict[str, Any]],
stored_interrupt: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
"""Combine backend-owned prior history with the request-owned new user turn."""
if not stored_messages or not incoming_messages:
return incoming_messages
incoming_suffix: list[dict[str, Any]]
if len(incoming_messages) >= len(stored_messages) and all(
_snapshot_messages_match(stored_message, incoming_message)
for stored_message, incoming_message in zip(stored_messages, incoming_messages)
):
incoming_suffix = incoming_messages[len(stored_messages) :]
else:
latest_user_index = _latest_user_message_index(incoming_messages)
if latest_user_index is None:
return incoming_messages
incoming_suffix = incoming_messages[latest_user_index:]
incoming_suffix = _filter_untrusted_suffix(
incoming_suffix,
stored_messages=stored_messages,
stored_interrupt=stored_interrupt,
)
return [copy.deepcopy(message) for message in stored_messages] + [
copy.deepcopy(message) for message in incoming_suffix
]
@@ -0,0 +1,202 @@
# Copyright (c) Microsoft. All rights reserved.
"""AG-UI Thread Snapshot storage primitives."""
from __future__ import annotations
import copy
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, runtime_checkable
if TYPE_CHECKING:
from ._types import AGUIRequest
SnapshotScope: TypeAlias = str
"""Application-defined scope for authorizing access to AG-UI Thread Snapshots."""
AGUIThreadID: TypeAlias = str
"""AG-UI Thread identifier within a Snapshot Scope."""
SnapshotScopeResolver: TypeAlias = Callable[["AGUIRequest"], str | Awaitable[str]]
"""Callable that resolves the Snapshot Scope for an AG-UI endpoint request."""
_SnapshotKey: TypeAlias = tuple[SnapshotScope, AGUIThreadID]
DEFAULT_MAX_THREAD_SNAPSHOTS = 1_000
_SNAPSHOT_SCOPE_INPUT_KEY = "__ag_ui_snapshot_scope"
_DEFAULT_STATE_INPUT_KEY = "__ag_ui_default_state"
@dataclass(slots=True)
class AGUIThreadSnapshot:
"""Replayable AG-UI Thread state.
AG-UI Thread Snapshots intentionally contain only data that can be replayed
to a UI: message snapshots, optional Shared State, and optional interruption
state. They do not include raw events, request metadata, auth claims,
diagnostics, traces, or provider responses.
Attributes:
messages: Replayable AG-UI message snapshots.
state: Optional AG-UI Shared State snapshot.
interrupt: Optional interruption state from ``RUN_FINISHED.interrupt``.
"""
messages: list[dict[str, Any]] = field(default_factory=list)
state: dict[str, Any] | None = None
interrupt: list[dict[str, Any]] | None = None
@runtime_checkable
class AGUIThreadSnapshotStore(Protocol):
"""Async store for latest AG-UI Thread Snapshots keyed by scope and thread id."""
async def save(
self,
*,
scope: SnapshotScope,
thread_id: AGUIThreadID,
snapshot: AGUIThreadSnapshot,
) -> None:
"""Save the latest snapshot for an AG-UI Thread within a Snapshot Scope.
Args:
scope: Application-defined Snapshot Scope. This is part of the
storage key and must represent the app's authorization boundary.
thread_id: AG-UI Thread id within the scope.
snapshot: Snapshot to save.
"""
...
async def get(
self,
*,
scope: SnapshotScope,
thread_id: AGUIThreadID,
) -> AGUIThreadSnapshot | None:
"""Get the latest snapshot for an AG-UI Thread within a Snapshot Scope.
Args:
scope: Application-defined Snapshot Scope.
thread_id: AG-UI Thread id within the scope.
Returns:
The latest snapshot, or ``None`` when no snapshot exists for the key.
"""
...
async def delete(
self,
*,
scope: SnapshotScope,
thread_id: AGUIThreadID,
) -> bool:
"""Delete the latest snapshot for an AG-UI Thread within a Snapshot Scope.
Args:
scope: Application-defined Snapshot Scope.
thread_id: AG-UI Thread id within the scope.
Returns:
``True`` when a snapshot was deleted, otherwise ``False``.
"""
...
async def clear(self, *, scope: SnapshotScope | None = None) -> None:
"""Clear saved snapshots.
Args:
scope: Optional Snapshot Scope to clear. When omitted, all in-memory
snapshots are cleared.
"""
...
class InMemoryAGUIThreadSnapshotStore:
"""Bounded memory-only latest snapshot store for local development, demos, and tests.
This store keeps at most one snapshot per ``(scope, thread_id)`` key. It is
process-local and not durable production storage.
"""
def __init__(self, *, max_snapshots: int = DEFAULT_MAX_THREAD_SNAPSHOTS) -> None:
"""Initialize the in-memory snapshot store.
Keyword Args:
max_snapshots: Maximum number of scoped thread snapshots to retain.
Raises:
ValueError: If ``max_snapshots`` is less than 1.
"""
if max_snapshots < 1:
raise ValueError("max_snapshots must be greater than 0.")
self._max_snapshots = max_snapshots
self._snapshots: dict[_SnapshotKey, AGUIThreadSnapshot] = {}
async def save(
self,
*,
scope: SnapshotScope,
thread_id: AGUIThreadID,
snapshot: AGUIThreadSnapshot,
) -> None:
"""Save the latest snapshot for an AG-UI Thread within a Snapshot Scope."""
key = self._key(scope=scope, thread_id=thread_id)
if key in self._snapshots:
del self._snapshots[key]
self._snapshots[key] = copy.deepcopy(snapshot)
self._evict_oldest()
async def get(
self,
*,
scope: SnapshotScope,
thread_id: AGUIThreadID,
) -> AGUIThreadSnapshot | None:
"""Get the latest snapshot for an AG-UI Thread within a Snapshot Scope."""
snapshot = self._snapshots.get(self._key(scope=scope, thread_id=thread_id))
return copy.deepcopy(snapshot) if snapshot is not None else None
async def delete(
self,
*,
scope: SnapshotScope,
thread_id: AGUIThreadID,
) -> bool:
"""Delete the latest snapshot for an AG-UI Thread within a Snapshot Scope."""
key = self._key(scope=scope, thread_id=thread_id)
if key not in self._snapshots:
return False
del self._snapshots[key]
return True
async def clear(self, *, scope: SnapshotScope | None = None) -> None:
"""Clear saved snapshots, optionally limited to one Snapshot Scope."""
if scope is None:
self._snapshots.clear()
return
normalized_scope = self._normalize_key_part(scope, "scope")
for key in list(self._snapshots):
if key[0] == normalized_scope:
del self._snapshots[key]
@classmethod
def _key(cls, *, scope: SnapshotScope, thread_id: AGUIThreadID) -> _SnapshotKey:
return (
cls._normalize_key_part(scope, "scope"),
cls._normalize_key_part(thread_id, "thread_id"),
)
@staticmethod
def _normalize_key_part(value: str, name: str) -> str:
if not isinstance(value, str):
raise TypeError(f"{name} must be a string.")
if not value:
raise ValueError(f"{name} must be a non-empty string.")
return value
def _evict_oldest(self) -> None:
while len(self._snapshots) > self._max_snapshots:
del self._snapshots[next(iter(self._snapshots))]
@@ -4,18 +4,203 @@
from __future__ import annotations
import copy
import logging
import uuid
from collections.abc import AsyncGenerator, Callable
from typing import Any
from typing import Any, cast
from ag_ui.core import BaseEvent
from ag_ui.core import (
BaseEvent,
MessagesSnapshotEvent,
RunErrorEvent,
RunFinishedEvent,
RunStartedEvent,
StateSnapshotEvent,
TextMessageContentEvent,
TextMessageEndEvent,
TextMessageStartEvent,
ToolCallArgsEvent,
ToolCallResultEvent,
ToolCallStartEvent,
)
from agent_framework import Workflow
from ._message_adapters import agui_messages_to_snapshot_format
from ._run_common import (
_build_run_finished_event,
_extract_resume_payload,
_reconstruct_messages_from_thread_snapshot,
)
from ._snapshots import (
_DEFAULT_STATE_INPUT_KEY,
_SNAPSHOT_SCOPE_INPUT_KEY,
AGUIThreadSnapshot,
AGUIThreadSnapshotStore,
)
from ._utils import generate_event_id, make_json_safe
from ._workflow_run import run_workflow_stream
logger = logging.getLogger(__name__)
WorkflowFactory = Callable[[str], Workflow]
def _event_messages_to_snapshot_dicts(messages: list[Any]) -> list[dict[str, Any]]:
"""Convert AG-UI message event models to plain snapshot dictionaries."""
safe_messages = make_json_safe(messages)
if not isinstance(safe_messages, list):
return []
return [cast(dict[str, Any], message) for message in safe_messages if isinstance(message, dict)]
class _WorkflowSnapshotBuilder:
"""Capture replayable workflow protocol output without retaining raw events."""
def __init__(self, raw_messages: list[dict[str, Any]]) -> None:
self._synthesized_messages = agui_messages_to_snapshot_format(raw_messages)
self._emitted_messages: list[dict[str, Any]] | None = None
self._open_text_message: dict[str, Any] | None = None
self._tool_call_message: dict[str, Any] | None = None
self._tool_calls_by_id: dict[str, dict[str, Any]] = {}
self.state: dict[str, Any] | None = None
self.interrupt: list[dict[str, Any]] | None = None
def observe(self, event: BaseEvent) -> None:
"""Fold one replayable AG-UI event into the latest snapshot state."""
if isinstance(event, StateSnapshotEvent):
state = make_json_safe(event.snapshot)
if isinstance(state, dict):
self.state = cast(dict[str, Any], state)
return
if isinstance(event, MessagesSnapshotEvent):
self._emitted_messages = _event_messages_to_snapshot_dicts(list(event.messages))
return
if isinstance(event, RunFinishedEvent):
interrupt = make_json_safe(getattr(event, "interrupt", None))
if isinstance(interrupt, list):
self.interrupt = [cast(dict[str, Any], item) for item in interrupt if isinstance(item, dict)]
return
if self._emitted_messages is not None:
return
if isinstance(event, TextMessageStartEvent):
self._observe_text_start(event)
elif isinstance(event, TextMessageContentEvent):
self._observe_text_content(event)
elif isinstance(event, TextMessageEndEvent):
self._observe_text_end(event)
elif isinstance(event, ToolCallStartEvent):
self._observe_tool_call_start(event)
elif isinstance(event, ToolCallArgsEvent):
self._observe_tool_call_args(event)
elif isinstance(event, ToolCallResultEvent):
self._observe_tool_call_result(event)
def build(self) -> AGUIThreadSnapshot:
"""Return the replayable thread snapshot."""
self._flush_open_text_message()
messages = self._emitted_messages if self._emitted_messages is not None else self._synthesized_messages
return AGUIThreadSnapshot(messages=messages, state=self.state, interrupt=self.interrupt)
def _observe_text_start(self, event: TextMessageStartEvent) -> None:
if self._open_text_message is not None and self._open_text_message.get("id") != event.message_id:
self._flush_open_text_message()
self._open_text_message = {"id": event.message_id, "role": event.role, "content": ""}
def _observe_text_content(self, event: TextMessageContentEvent) -> None:
if self._open_text_message is None or self._open_text_message.get("id") != event.message_id:
self._open_text_message = {"id": event.message_id, "role": "assistant", "content": ""}
self._open_text_message["content"] = f"{self._open_text_message.get('content', '')}{event.delta}"
def _observe_text_end(self, event: TextMessageEndEvent) -> None:
if self._open_text_message is None or self._open_text_message.get("id") != event.message_id:
return
self._flush_open_text_message()
def _observe_tool_call_start(self, event: ToolCallStartEvent) -> None:
parent_message_id = event.parent_message_id
if (
self._open_text_message is not None
and parent_message_id is not None
and self._open_text_message.get("id") == parent_message_id
and self._open_text_message.get("content")
):
self._open_text_message["id"] = generate_event_id()
self._flush_open_text_message()
if self._tool_call_message is None or (
parent_message_id is not None and self._tool_call_message.get("id") != parent_message_id
):
self._tool_call_message = {
"id": parent_message_id or generate_event_id(),
"role": "assistant",
"tool_calls": [],
}
self._synthesized_messages.append(self._tool_call_message)
tool_call = {
"id": event.tool_call_id,
"type": "function",
"function": {"name": event.tool_call_name, "arguments": ""},
}
cast(list[dict[str, Any]], self._tool_call_message["tool_calls"]).append(tool_call)
self._tool_calls_by_id[event.tool_call_id] = tool_call
def _observe_tool_call_args(self, event: ToolCallArgsEvent) -> None:
tool_call = self._tool_calls_by_id.get(event.tool_call_id)
if tool_call is None:
return
function_payload = cast(dict[str, Any], tool_call["function"])
function_payload["arguments"] = f"{function_payload.get('arguments', '')}{event.delta}"
def _observe_tool_call_result(self, event: ToolCallResultEvent) -> None:
self._synthesized_messages.append(
{
"id": event.message_id,
"role": "tool",
"toolCallId": event.tool_call_id,
"content": event.content,
}
)
# A result closes the current tool-call group; later tool calls start a new
# assistant message so replayed transcripts keep results adjacent to their
# tool_calls message, which provider APIs require.
self._tool_call_message = None
def _flush_open_text_message(self) -> None:
if self._open_text_message is None:
return
if self._open_text_message.get("content"):
self._synthesized_messages.append(self._open_text_message)
# Text between tool calls closes the current tool-call group as well.
self._tool_call_message = None
self._open_text_message = None
async def _hydrate_workflow_thread_snapshot(
*,
snapshot_store: AGUIThreadSnapshotStore,
scope: str,
thread_id: str,
run_id: str,
) -> AsyncGenerator[BaseEvent]:
"""Replay the latest stored workflow AG-UI Thread Snapshot without invoking the workflow."""
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
snapshot = await snapshot_store.get(scope=scope, thread_id=thread_id)
if snapshot is None:
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id)
return
if snapshot.state is not None:
yield StateSnapshotEvent(snapshot=snapshot.state)
if snapshot.messages:
yield MessagesSnapshotEvent(messages=snapshot.messages) # type: ignore[arg-type]
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=snapshot.interrupt)
class AgentFrameworkWorkflow:
"""Base AG-UI workflow wrapper.
@@ -29,15 +214,30 @@ class AgentFrameworkWorkflow:
workflow_factory: WorkflowFactory | None = None,
name: str | None = None,
description: str | None = None,
snapshot_store: AGUIThreadSnapshotStore | None = None,
) -> None:
"""Initialize the AG-UI workflow wrapper.
Args:
workflow: Optional workflow instance to expose.
workflow_factory: Optional factory for thread-scoped workflow instances.
name: Optional workflow name.
description: Optional workflow description.
snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence remains inactive unless
endpoint setup also provides an explicit Snapshot Scope resolver.
"""
if workflow is not None and workflow_factory is not None:
raise ValueError("Pass either workflow= or workflow_factory=, not both.")
self.workflow = workflow
self._workflow_factory = workflow_factory
self._workflow_by_thread: dict[str, Workflow] = {}
# Cache keyed by (snapshot_scope, thread_id): the Snapshot Scope is the
# authorization boundary, so the same thread id under different scopes
# must never share an in-memory workflow instance.
self._workflow_by_thread: dict[tuple[str | None, str], Workflow] = {}
self.name = name if name is not None else getattr(workflow, "name", "workflow")
self.description = description if description is not None else getattr(workflow, "description", "")
self.snapshot_store = snapshot_store
@staticmethod
def _thread_id_from_input(input_data: dict[str, Any]) -> str:
@@ -47,7 +247,7 @@ class AgentFrameworkWorkflow:
return str(thread_id)
return str(uuid.uuid4())
def _resolve_workflow(self, thread_id: str) -> Workflow:
def _resolve_workflow(self, thread_id: str, snapshot_scope: str | None = None) -> Workflow:
"""Get the workflow instance for the current run."""
if self.workflow is not None:
return self.workflow
@@ -55,17 +255,22 @@ class AgentFrameworkWorkflow:
if self._workflow_factory is None:
raise NotImplementedError("No workflow is attached. Override run or pass workflow=/workflow_factory=.")
workflow = self._workflow_by_thread.get(thread_id)
cache_key = (snapshot_scope, thread_id)
workflow = self._workflow_by_thread.get(cache_key)
if workflow is None:
workflow = self._workflow_factory(thread_id)
if not isinstance(workflow, Workflow):
raise TypeError("workflow_factory must return a Workflow instance.")
self._workflow_by_thread[thread_id] = workflow
self._workflow_by_thread[cache_key] = workflow
return workflow
def clear_thread_workflow(self, thread_id: str) -> None:
"""Drop a single cached thread workflow instance."""
self._workflow_by_thread.pop(thread_id, None)
def clear_thread_workflow(self, thread_id: str, snapshot_scope: str | None = None) -> None:
"""Drop cached workflow instances for a thread, optionally limited to one Snapshot Scope."""
if snapshot_scope is not None:
self._workflow_by_thread.pop((snapshot_scope, thread_id), None)
return
for key in [key for key in self._workflow_by_thread if key[1] == thread_id]:
del self._workflow_by_thread[key]
def clear_workflow_cache(self) -> None:
"""Drop all cached thread workflow instances."""
@@ -77,6 +282,96 @@ class AgentFrameworkWorkflow:
Subclasses may override this to provide custom AG-UI streams.
"""
thread_id = self._thread_id_from_input(input_data)
workflow = self._resolve_workflow(thread_id)
run_id = str(input_data.get("run_id") or input_data.get("runId") or uuid.uuid4())
snapshot_scope = cast(str | None, input_data.get(_SNAPSHOT_SCOPE_INPUT_KEY))
raw_messages = list(cast(list[dict[str, Any]], input_data.get("messages", []) or []))
resume_payload = _extract_resume_payload(input_data)
snapshot_store = self.snapshot_store
if snapshot_store is not None and snapshot_scope is not None and not raw_messages and resume_payload is None:
async for event in _hydrate_workflow_thread_snapshot(
snapshot_store=snapshot_store,
scope=snapshot_scope,
thread_id=thread_id,
run_id=run_id,
):
yield event
return
# Load the stored snapshot for follow-up turns so the workflow runs with the
# full persisted thread history instead of just the latest request messages.
stored_snapshot: AGUIThreadSnapshot | None = None
if snapshot_store is not None and snapshot_scope is not None:
stored_snapshot = await snapshot_store.get(scope=snapshot_scope, thread_id=thread_id)
if stored_snapshot is not None and resume_payload is None:
raw_messages = _reconstruct_messages_from_thread_snapshot(
stored_messages=stored_snapshot.messages,
incoming_messages=raw_messages,
stored_interrupt=stored_snapshot.interrupt,
)
input_data["messages"] = raw_messages
# Merge stored state with request overrides, then fill endpoint-deferred
# defaults only for keys missing from both.
request_state = input_data.get("state")
deferred_default_state = cast(dict[str, Any] | None, input_data.get(_DEFAULT_STATE_INPUT_KEY))
effective_state: dict[str, Any] = {}
if stored_snapshot is not None and stored_snapshot.state is not None:
effective_state.update(stored_snapshot.state)
if isinstance(request_state, dict):
effective_state.update(cast(dict[str, Any], request_state))
if deferred_default_state:
for key, value in deferred_default_state.items():
if key not in effective_state:
effective_state[key] = copy.deepcopy(value)
if effective_state:
input_data["state"] = effective_state
workflow = self._resolve_workflow(thread_id, snapshot_scope)
builder_seed_messages = raw_messages
if resume_payload is not None and stored_snapshot is not None:
# Resume requests carry only the synthesized interrupt response, so seed
# the builder with stored history to avoid persisting a truncated thread.
builder_seed_messages = [
copy.deepcopy(message) for message in stored_snapshot.messages
] + builder_seed_messages
snapshot_builder = (
_WorkflowSnapshotBuilder(builder_seed_messages)
if snapshot_store is not None and snapshot_scope is not None
else None
)
if snapshot_builder is not None and effective_state:
# Seed builder state so a run that emits no StateSnapshotEvent still
# persists the latest known Shared State instead of dropping it.
state_snapshot = make_json_safe(effective_state)
if isinstance(state_snapshot, dict):
snapshot_builder.state = cast(dict[str, Any], state_snapshot)
run_error_emitted = False
async for event in run_workflow_stream(input_data, workflow):
if snapshot_builder is not None:
snapshot_builder.observe(event)
if isinstance(event, RunErrorEvent):
run_error_emitted = True
yield event
if (
snapshot_builder is not None
and not run_error_emitted
and snapshot_store is not None
and snapshot_scope is not None
):
try:
await snapshot_store.save(
scope=snapshot_scope,
thread_id=thread_id,
snapshot=snapshot_builder.build(),
)
except Exception:
# RUN_FINISHED has already been yielded; a store failure must not
# surface as a second terminal RUN_ERROR event. The previous
# snapshot stays available for hydration.
logger.exception(
"Failed to save AG-UI Thread Snapshot for scope=%s thread_id=%s; keeping previous snapshot.",
snapshot_scope,
thread_id,
)
File diff suppressed because it is too large Load Diff
@@ -32,6 +32,21 @@ def test_agent_framework_ag_ui_exports_state_update() -> None:
assert callable(state_update)
def test_agent_framework_ag_ui_exports_snapshot_primitives() -> None:
"""Runtime package should export AG-UI Thread Snapshot primitives."""
from agent_framework_ag_ui import (
DEFAULT_MAX_THREAD_SNAPSHOTS,
AGUIThreadSnapshot,
AGUIThreadSnapshotStore,
InMemoryAGUIThreadSnapshotStore,
)
assert AGUIThreadSnapshot.__name__ == "AGUIThreadSnapshot"
assert AGUIThreadSnapshotStore.__name__ == "AGUIThreadSnapshotStore"
assert InMemoryAGUIThreadSnapshotStore.__name__ == "InMemoryAGUIThreadSnapshotStore"
assert DEFAULT_MAX_THREAD_SNAPSHOTS >= 1
def test_core_ag_ui_lazy_exports_include_event_converter_and_http_service() -> None:
"""Core facade must expose AGUIEventConverter, AGUIHttpService, and __version__."""
from agent_framework import ag_ui
@@ -39,3 +54,13 @@ def test_core_ag_ui_lazy_exports_include_event_converter_and_http_service() -> N
assert hasattr(ag_ui, "AGUIEventConverter")
assert hasattr(ag_ui, "AGUIHttpService")
assert hasattr(ag_ui, "__version__")
def test_core_ag_ui_lazy_exports_include_snapshot_primitives() -> None:
"""Core facade must expose snapshot primitives needed for endpoint configuration."""
from agent_framework import ag_ui
assert hasattr(ag_ui, "AGUIThreadSnapshot")
assert hasattr(ag_ui, "AGUIThreadSnapshotStore")
assert hasattr(ag_ui, "InMemoryAGUIThreadSnapshotStore")
assert hasattr(ag_ui, "SnapshotScopeResolver")
@@ -0,0 +1,160 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for AG-UI thread snapshot storage primitives."""
from dataclasses import fields
from agent_framework_ag_ui import AGUIThreadSnapshot, AGUIThreadSnapshotStore, InMemoryAGUIThreadSnapshotStore
def test_thread_snapshot_model_contains_only_replayable_snapshot_fields() -> None:
"""The public snapshot model is limited to messages, Shared State, and interruption state."""
assert [field.name for field in fields(AGUIThreadSnapshot)] == ["messages", "state", "interrupt"]
def test_in_memory_snapshot_store_satisfies_snapshot_store_protocol() -> None:
"""The built-in store conforms to the public async store protocol."""
assert isinstance(InMemoryAGUIThreadSnapshotStore(), AGUIThreadSnapshotStore)
async def test_in_memory_snapshot_store_replaces_latest_snapshot() -> None:
"""Saving the same scoped thread key replaces the previous snapshot."""
store = InMemoryAGUIThreadSnapshotStore()
await store.save(
scope="tenant-a",
thread_id="thread-1",
snapshot=AGUIThreadSnapshot(messages=[{"id": "first"}], state={"count": 1}),
)
await store.save(
scope="tenant-a",
thread_id="thread-1",
snapshot=AGUIThreadSnapshot(messages=[{"id": "second"}], state={"count": 2}),
)
snapshot = await store.get(scope="tenant-a", thread_id="thread-1")
assert snapshot is not None
assert snapshot.messages == [{"id": "second"}]
assert snapshot.state == {"count": 2}
async def test_in_memory_snapshot_store_keeps_scopes_separate() -> None:
"""The same AG-UI Thread id in different Snapshot Scopes addresses different snapshots."""
store = InMemoryAGUIThreadSnapshotStore()
await store.save(
scope="tenant-a",
thread_id="thread-1",
snapshot=AGUIThreadSnapshot(messages=[{"id": "a", "role": "user", "content": "from a"}]),
)
await store.save(
scope="tenant-b",
thread_id="thread-1",
snapshot=AGUIThreadSnapshot(messages=[{"id": "b", "role": "user", "content": "from b"}]),
)
tenant_a_snapshot = await store.get(scope="tenant-a", thread_id="thread-1")
tenant_b_snapshot = await store.get(scope="tenant-b", thread_id="thread-1")
assert tenant_a_snapshot is not None
assert tenant_b_snapshot is not None
assert tenant_a_snapshot.messages == [{"id": "a", "role": "user", "content": "from a"}]
assert tenant_b_snapshot.messages == [{"id": "b", "role": "user", "content": "from b"}]
async def test_in_memory_snapshot_store_deletes_and_clears_snapshots() -> None:
"""Delete removes one scoped thread key, while clear can remove a scope or the whole store."""
store = InMemoryAGUIThreadSnapshotStore()
await store.save(scope="tenant-a", thread_id="thread-1", snapshot=AGUIThreadSnapshot(messages=[{"id": "a1"}]))
await store.save(scope="tenant-a", thread_id="thread-2", snapshot=AGUIThreadSnapshot(messages=[{"id": "a2"}]))
await store.save(scope="tenant-b", thread_id="thread-1", snapshot=AGUIThreadSnapshot(messages=[{"id": "b1"}]))
assert await store.delete(scope="tenant-a", thread_id="thread-1") is True
assert await store.delete(scope="tenant-a", thread_id="thread-1") is False
assert await store.get(scope="tenant-a", thread_id="thread-1") is None
assert await store.get(scope="tenant-a", thread_id="thread-2") is not None
await store.clear(scope="tenant-a")
assert await store.get(scope="tenant-a", thread_id="thread-2") is None
assert await store.get(scope="tenant-b", thread_id="thread-1") is not None
await store.clear()
assert await store.get(scope="tenant-b", thread_id="thread-1") is None
async def test_in_memory_snapshot_store_evicts_oldest_snapshot_when_bounded() -> None:
"""The memory store bounds retained scoped thread snapshots."""
store = InMemoryAGUIThreadSnapshotStore(max_snapshots=2)
await store.save(scope="tenant-a", thread_id="thread-1", snapshot=AGUIThreadSnapshot(messages=[{"id": "first"}]))
await store.save(scope="tenant-a", thread_id="thread-2", snapshot=AGUIThreadSnapshot(messages=[{"id": "second"}]))
await store.save(scope="tenant-a", thread_id="thread-3", snapshot=AGUIThreadSnapshot(messages=[{"id": "third"}]))
assert await store.get(scope="tenant-a", thread_id="thread-1") is None
assert await store.get(scope="tenant-a", thread_id="thread-2") is not None
assert await store.get(scope="tenant-a", thread_id="thread-3") is not None
def test_workflow_snapshot_builder_splits_tool_call_groups() -> None:
"""Tool calls separated by results or text synthesize provider-valid message groups."""
from ag_ui.core import (
TextMessageContentEvent,
TextMessageEndEvent,
TextMessageStartEvent,
ToolCallArgsEvent,
ToolCallResultEvent,
ToolCallStartEvent,
)
from agent_framework_ag_ui._workflow import _WorkflowSnapshotBuilder
builder = _WorkflowSnapshotBuilder([])
builder.observe(ToolCallStartEvent(tool_call_id="call-a", tool_call_name="toolA"))
builder.observe(ToolCallArgsEvent(tool_call_id="call-a", delta='{"x": 1}'))
builder.observe(ToolCallResultEvent(message_id="result-a", tool_call_id="call-a", content="resA"))
builder.observe(TextMessageStartEvent(message_id="text-1", role="assistant"))
builder.observe(TextMessageContentEvent(message_id="text-1", delta="thinking"))
builder.observe(TextMessageEndEvent(message_id="text-1"))
builder.observe(ToolCallStartEvent(tool_call_id="call-b", tool_call_name="toolB"))
builder.observe(ToolCallResultEvent(message_id="result-b", tool_call_id="call-b", content="resB"))
messages = builder.build().messages
shapes = [
(
message.get("role"),
[tool_call["id"] for tool_call in message.get("tool_calls", [])] or message.get("toolCallId"),
)
for message in messages
]
assert shapes == [
("assistant", ["call-a"]),
("tool", "call-a"),
("assistant", None),
("assistant", ["call-b"]),
("tool", "call-b"),
]
async def test_in_memory_snapshot_store_rejects_invalid_keys() -> None:
"""Key parts must be non-empty strings for every store operation."""
import pytest
store = InMemoryAGUIThreadSnapshotStore()
snapshot = AGUIThreadSnapshot()
with pytest.raises(ValueError):
await store.save(scope="", thread_id="thread-1", snapshot=snapshot)
with pytest.raises(ValueError):
await store.save(scope="tenant-a", thread_id="", snapshot=snapshot)
with pytest.raises(TypeError):
await store.save(scope=123, thread_id="thread-1", snapshot=snapshot) # type: ignore[arg-type]
with pytest.raises(ValueError):
await store.get(scope="tenant-a", thread_id="")
with pytest.raises(TypeError):
await store.delete(scope=None, thread_id="thread-1") # type: ignore[arg-type]
with pytest.raises(ValueError):
await store.clear(scope="")
@@ -125,7 +125,6 @@ from ._harness._todo import (
TodoSessionStore,
TodoStore,
)
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback
from ._harness._tool_approval import (
DEFAULT_TOOL_APPROVAL_SOURCE_ID,
ToolApprovalMiddleware,
@@ -135,6 +134,7 @@ from ._harness._tool_approval import (
create_always_approve_tool_response,
create_always_approve_tool_with_arguments_response,
)
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback
from ._middleware import (
AgentContext,
AgentMiddleware,
@@ -1989,9 +1989,7 @@ def _store_already_approved_approval_requests(
return
existing_groups = state.get(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY)
pending_groups: list[Any] = (
list(cast(Iterable[Any], existing_groups)) if isinstance(existing_groups, list) else []
)
pending_groups: list[Any] = list(cast(Iterable[Any], existing_groups)) if isinstance(existing_groups, list) else []
pending_groups.append({
"approval_request_ids": visible_ids,
"approval_requests": [request.to_dict() for request in already_approved_requests],
@@ -11,6 +11,10 @@ Supported classes and functions:
- AGUIChatClient
- AGUIEventConverter
- AGUIHttpService
- AGUIThreadSnapshot
- AGUIThreadSnapshotStore
- InMemoryAGUIThreadSnapshotStore
- SnapshotScopeResolver
- add_agent_framework_fastapi_endpoint
- state_update
- __version__
@@ -28,6 +32,10 @@ _IMPORTS = [
"AGUIChatClient",
"AGUIEventConverter",
"AGUIHttpService",
"AGUIThreadSnapshot",
"AGUIThreadSnapshotStore",
"InMemoryAGUIThreadSnapshotStore",
"SnapshotScopeResolver",
"state_update",
"__version__",
]
@@ -6,6 +6,10 @@ from agent_framework_ag_ui import (
AGUIChatClient,
AGUIEventConverter,
AGUIHttpService,
AGUIThreadSnapshot,
AGUIThreadSnapshotStore,
InMemoryAGUIThreadSnapshotStore,
SnapshotScopeResolver,
__version__,
add_agent_framework_fastapi_endpoint,
state_update,
@@ -15,8 +19,12 @@ __all__ = [
"AGUIChatClient",
"AGUIEventConverter",
"AGUIHttpService",
"AGUIThreadSnapshot",
"AGUIThreadSnapshotStore",
"AgentFrameworkAgent",
"AgentFrameworkWorkflow",
"InMemoryAGUIThreadSnapshotStore",
"SnapshotScopeResolver",
"__version__",
"add_agent_framework_fastapi_endpoint",
"state_update",