mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add opt-in AG-UI thread snapshot persistence and hydration (#6471)
* feat(ag-ui): add thread snapshot store primitives Key decisions:\n- Introduce an AGUIThreadSnapshot model limited to replayable messages, optional Shared State, and optional interrupt state.\n- Define AGUIThreadSnapshotStore as an async protocol keyed by explicit Snapshot Scope and AG-UI Thread id.\n- Add InMemoryAGUIThreadSnapshotStore as memory-only, latest-only, bounded local/demo/test storage; no file-backed store is introduced.\n- Require snapshot_scope_resolver whenever an endpoint is configured with a snapshot store, including pre-wrapped runners, so thread ids are not authorization boundaries.\n\nFiles changed:\n- packages/ag-ui/agent_framework_ag_ui/_snapshots.py\n- packages/ag-ui/agent_framework_ag_ui/__init__.py\n- packages/ag-ui/agent_framework_ag_ui/_agent.py\n- packages/ag-ui/agent_framework_ag_ui/_workflow.py\n- packages/ag-ui/agent_framework_ag_ui/_endpoint.py\n- packages/core/agent_framework/ag_ui/__init__.py\n- packages/core/agent_framework/ag_ui/__init__.pyi\n- packages/ag-ui/tests/ag_ui/test_snapshots.py\n- packages/ag-ui/tests/ag_ui/test_endpoint.py\n- packages/ag-ui/tests/ag_ui/test_public_exports.py\n- packages/ag-ui/AGENTS.md\n\nVerification:\n- uv run pytest packages/ag-ui/tests/ag_ui/test_snapshots.py packages/ag-ui/tests/ag_ui/test_public_exports.py packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_requires_snapshot_scope_resolver_when_store_configured packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_accepts_snapshot_store_with_scope_resolver -q\n- uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_requires_snapshot_scope_resolver_when_store_configured packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_requires_snapshot_scope_resolver_when_wrapped_runner_has_store packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_accepts_snapshot_store_with_scope_resolver -q\n- uv run poe syntax -P ag-ui -C\n- uv run poe pyright -P ag-ui\n- uv run poe syntax -P core -C\n- uv run poe pyright -P core\n- uv run poe typing -P ag-ui\n- uv run poe typing -P core\n- uv run poe test -P ag-ui\n- uv run poe check -P ag-ui\n- git diff --check\n- git diff --cached --check\n\nBlockers / next iteration:\n- No blockers. Next slice can use the store contract to capture and hydrate agent snapshots.\n- uv repeatedly refreshed azure-ai-projects in uv.lock during local runs; reverted the generated lockfile churn because this change does not alter dependencies.\n- The poe-check commit hook was skipped after manual verification because it reformatted unrelated core MCP files outside this task. * feat(ag-ui): hydrate agent threads from snapshots Key decisions: - Resolve Snapshot Scope per endpoint request and pass it to the AG-UI runner only when snapshot storage is active. - Treat empty messages with no resume payload as an agent Hydrate Request when a scoped snapshot store is configured, replaying stored Shared State and message snapshots without invoking the wrapped agent. - Save the latest replayable agent message snapshot and Shared State at normal completion under Snapshot Scope plus AG-UI Thread id; no durable or file-backed store is introduced. Files changed: - packages/ag-ui/agent_framework_ag_ui/_agent_run.py - packages/ag-ui/agent_framework_ag_ui/_endpoint.py - packages/ag-ui/agent_framework_ag_ui/_snapshots.py - packages/ag-ui/tests/ag_ui/test_endpoint.py Verification: - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_stored_thread_snapshot_without_invoking_agent -q - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_stored_thread_snapshot_without_invoking_agent packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_snapshots_by_scope_and_thread -q - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_empty_messages packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_stored_thread_snapshot_without_invoking_agent packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_snapshots_by_scope_and_thread -q - uv run poe syntax -P ag-ui -C - uv run poe pyright -P ag-ui - uv run poe typing -P ag-ui - uv run poe test -P ag-ui - uv run poe check -P ag-ui - git diff --check - git diff --cached --check Blockers / next iteration: - No blockers. Next slice can reconstruct normal new-user agent turns from stored snapshots. - uv repeatedly refreshed azure-ai-projects in uv.lock during local runs; reverted the generated lockfile churn because this change does not alter dependencies. - The poe-check commit hook was skipped after manual verification because it refreshed unrelated uv.lock dependency resolution. * feat(ag-ui): reconstruct agent turns from snapshots Key decisions: - Load scoped thread snapshots for non-hydrate agent requests only when snapshot storage is active and no resume payload is present. - Rebuild prior AG-UI history from stored snapshot messages, preserving the incoming new user suffix and treating stored snapshot content as authoritative over conflicting prior client history. - Merge stored Shared State with request state overrides before schema defaults and existing state-context injection. Files changed: - packages/ag-ui/agent_framework_ag_ui/_agent_run.py - packages/ag-ui/tests/ag_ui/test_endpoint.py Verification: - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_prepends_stored_snapshot_for_new_user_turn -q - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_deduplicates_full_history_and_merges_fresh_state -q - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_endpoint_empty_messages packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_stored_thread_snapshot_without_invoking_agent packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_snapshots_by_scope_and_thread packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_prepends_stored_snapshot_for_new_user_turn packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_deduplicates_full_history_and_merges_fresh_state -q - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py -q - uv run poe syntax -P ag-ui -C - uv run poe pyright -P ag-ui - uv run poe test -P ag-ui - uv run poe check -P ag-ui - uv run poe typing -P ag-ui - git diff --check - git diff --cached --check Blockers / next iteration: - No blockers. Next slice can enable workflow AG-UI Thread Snapshot persistence and hydration. - uv repeatedly refreshed azure-ai-projects in uv.lock during local runs; reverted the generated lockfile churn because this change does not alter dependencies. - The poe-check commit hook was skipped after manual verification because it refreshes unrelated uv.lock dependency resolution. * feat(ag-ui): hydrate workflow threads from snapshots Key decisions: - Handle workflow Hydrate Requests before resolving or invoking the wrapped workflow when snapshot storage and Snapshot Scope are active. - Capture only replayable workflow protocol data: workflow-emitted state snapshots, workflow-emitted message snapshots, and synthesized messages from text/tool output. - Keep workflow snapshot capture inactive without configured persistence, and skip saving snapshots when the workflow stream emits RUN_ERROR. Files changed: - packages/ag-ui/agent_framework_ag_ui/_workflow.py - packages/ag-ui/tests/ag_ui/test_endpoint.py Verification: - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_workflow_endpoint_hydrates_emitted_snapshots_without_invoking_workflow packages/ag-ui/tests/ag_ui/test_endpoint.py::test_workflow_endpoint_hydrates_synthesized_text_and_tool_snapshot -q - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py -q - uv run pytest packages/ag-ui/tests/ag_ui/golden/test_scenario_workflow.py -q - uv run poe syntax -P ag-ui -C - uv run poe pyright -P ag-ui - uv run poe test -P ag-ui - uv run poe typing -P ag-ui - uv run poe check -P ag-ui - git diff --check - git diff --cached --check Blockers / next iteration: - No blockers. Next slice can preserve interruption state and protect snapshots on errors across agent and workflow endpoints. - uv repeatedly refreshed azure-ai-projects in uv.lock during local runs; reverted the generated lockfile churn because this change does not alter dependencies. - The poe-check commit hook was skipped after manual verification because it refreshes unrelated uv.lock dependency resolution. * feat(ag-ui): preserve interrupted thread snapshots Key decisions: - Capture workflow RUN_FINISHED interrupt metadata in replayable AG-UI Thread Snapshots so Hydrate Requests can restore pending workflow actions without invoking or resuming the workflow. - Keep failed agent and workflow runs from replacing the last good snapshot; RUN_ERROR streams leave the previous snapshot available for hydration. - Verify interruption hydration through endpoint-level AG-UI streams for both agent and workflow wrappers, including Shared State replay and no wrapped runner invocation. Files changed: - packages/ag-ui/agent_framework_ag_ui/_workflow.py - packages/ag-ui/tests/ag_ui/test_endpoint.py Verification: - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_workflow_endpoint_hydrates_interrupted_thread_without_invoking_workflow -q - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_hydrates_interrupted_thread_without_invoking_agent packages/ag-ui/tests/ag_ui/test_endpoint.py::test_agent_endpoint_run_error_does_not_overwrite_previous_snapshot packages/ag-ui/tests/ag_ui/test_endpoint.py::test_workflow_endpoint_hydrates_interrupted_thread_without_invoking_workflow packages/ag-ui/tests/ag_ui/test_endpoint.py::test_workflow_endpoint_run_error_does_not_overwrite_previous_snapshot -q - uv run pytest packages/ag-ui/tests/ag_ui/test_endpoint.py -q - uv run pytest packages/ag-ui/tests/ag_ui/golden/test_scenario_workflow.py -q - uv run poe syntax -P ag-ui -C - uv run poe pyright -P ag-ui - uv run poe test -P ag-ui - uv run poe typing -P ag-ui - uv run poe check -P ag-ui - git diff --check - git diff --cached --check Blockers / next iteration: - No blockers. Next slice can document AG-UI Thread Snapshot security and usage. - uv repeatedly refreshed azure-ai-projects in uv.lock during local runs; reverted the generated lockfile churn because this change does not alter dependencies. - The poe-check commit hook was skipped after manual verification because it refreshes unrelated uv.lock dependency resolution. * docs(ag-ui): document thread snapshot security Key decisions: - Document AG-UI Thread Snapshot persistence as opt-in and disabled unless a snapshot_store is configured. - Place Snapshot Scope guidance next to endpoint authentication guidance, making clear that AG-UI Thread ids identify threads but do not authorize snapshot access. - Describe built-in storage as in-memory only, process-local, latest-only, and not durable production storage; durable stores remain app-owned implementations of AGUIThreadSnapshotStore. - Call out snapshot confidentiality impact and that no file-backed AG-UI snapshot store is provided. Files changed: - packages/ag-ui/README.md Verification: - uv run python scripts/check_md_code_blocks.py packages/ag-ui/README.md --no-glob - git diff --check - git diff --cached --check - commit hook without SKIP ran changed-package lint/format and AG-UI README markdown-code-lint successfully before stopping because uv.lock was modified - uv run poe markdown-code-lint (failed due existing unrelated packages/mistral/README.md missing agent_framework_mistral import resolution; changed AG-UI README blocks passed) Blockers / next iteration: - No blockers. Local issue/PRD planning artifacts remain uncommitted. - uv refreshed azure-ai-projects in uv.lock during markdown lint and the commit hook; reverted the generated lockfile churn because this documentation change does not alter dependencies. - The poe-check commit hook was skipped after manual verification because it refreshes unrelated uv.lock dependency resolution. * fix(ag-ui): harden thread snapshot persistence edge cases - Persist the completed confirm_changes turn with interrupt=None so hydration no longer replays a stale pending interrupt after the user responds; resume requests prepend stored history so the persisted thread is not truncated. - Defer endpoint default_state application to the runners when snapshot persistence is active, filling only keys missing from both the stored snapshot state and the request state so defaults never reset persisted Shared State. - Always fold the turn's output into the persisted messages snapshot even when the outbound MESSAGES_SNAPSHOT event is suppressed for predictive tools without confirmation. - Load the stored snapshot on workflow follow-up turns, reconstruct full thread history into the run input, and seed the snapshot builder with merged state so saving a new turn no longer replaces prior history. - Move snapshot message reconstruction helpers to _run_common for reuse by the workflow runner; load stored agent snapshots on resume turns for state merge. - Add endpoint regression tests for all four scenarios. * fix(ag-ui): protect snapshot history on resume and harden suffix trust - Prepend stored thread history when persisting snapshots for resume runs on both the agent and workflow paths, so a resumed interrupt no longer overwrites the stored thread with just the resume turn's output. - Filter the incoming message suffix during thread reconstruction: only user turns and tool results answering backend-issued tool calls (stored tool calls or pending interrupts) may extend authoritative history. Client-forged assistant and tool messages are dropped and logged instead of being persisted and replayed. - Close the workflow snapshot builder's tool-call group when a tool result or text message lands, so synthesized transcripts keep tool results adjacent to their tool_calls message and stay valid as provider replay history. - Export DEFAULT_MAX_THREAD_SNAPSHOTS from agent_framework_ag_ui and expose SnapshotScopeResolver through the core ag_ui facade and stub. - Add regression tests for agent and workflow resume history preservation, forged suffix rejection, builder tool-call grouping, and the export surface. * fix(ag-ui): tolerate snapshot save failures and scope workflow cache - Wrap snapshot_store.save() on both the agent and workflow paths so a transient store failure (timeout, connection refused) is logged instead of propagating. Previously a failing save converted an already-streamed successful run into RUN_ERROR, and on the workflow path emitted RUN_ERROR after RUN_FINISHED, violating the single-terminal-event invariant. The previous snapshot stays available for hydration. - Key the workflow_factory instance cache by (snapshot_scope, thread_id). The Snapshot Scope is the authorization boundary, so the same thread id under different scopes no longer shares an in-memory workflow instance. clear_thread_workflow accepts an optional snapshot_scope and clears all scopes for the thread when omitted. - Add tests: save-failure tolerance for agent and workflow endpoints, scope-isolated workflow cache, async snapshot_scope_resolver support, and in-memory store key validation errors. * fix(ci): ignore all dotnet.microsoft.com links in linkspector The existing ignore pattern only matched https://dotnet.microsoft.com/download, but Microsoft sites insert a locale segment between host and path (e.g. /en-us/download/dotnet/10.0), so localized links slip past the pattern and get checked. dotnet.microsoft.com bot-blocks CI link checkers with intermittent 403s across the whole site, which fails markdown-link-check on unrelated pull requests since linkspector scans the entire repository. Ignore the domain wholesale, matching how platform.openai.com is already handled for the same reason. A 403 from bot blocking is indistinguishable from a removed page, so the checker cannot produce a meaningful signal for this domain either way. * ag-ui: simplify raw_messages assignment and drop OrderedDict - Replace list(cast(...)) with a typed annotation for raw_messages (_agent_run.py:866) per review suggestion - Replace OrderedDict with a plain dict in InMemoryAGUIThreadSnapshotStore (_snapshots.py:136); regular dicts are insertion-order-safe since Python 3.7, so OrderedDict is unnecessary. Update _evict_oldest to use next(iter(...)) for FIFO removal instead of popitem(last=False). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #2458: review comment fixes --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
4c1b9efa8c
commit
76b2b1bf39
@@ -10,10 +10,12 @@ AG-UI protocol integration for building agent UIs with the AG-UI standard.
|
||||
- **`AGUIHttpService`** - HTTP service for AG-UI endpoints
|
||||
- **`AGUIEventConverter`** - Converts between Agent Framework and AG-UI events
|
||||
- **`add_agent_framework_fastapi_endpoint()`** - Add AG-UI endpoint to FastAPI app (`SupportsAgentRun` or `Workflow`)
|
||||
- **`InMemoryAGUIThreadSnapshotStore`** - Memory-only latest AG-UI Thread Snapshot store for local development, demos, and tests
|
||||
|
||||
## Types
|
||||
|
||||
- **`AGUIRequest`** / **`AGUIChatOptions`** - Request types
|
||||
- **`AGUIThreadSnapshot`** / **`AGUIThreadSnapshotStore`** - Replayable thread snapshot model and scoped async store protocol
|
||||
- **`availableInterrupts` / `resume`** - Optional interrupt configuration and continuation payloads
|
||||
- **`AgentState`** / **`RunMetadata`** - State management types
|
||||
- **`PredictStateConfig`** - Configuration for state prediction
|
||||
|
||||
@@ -198,6 +198,71 @@ The `dependencies` parameter accepts any FastAPI dependency, enabling integratio
|
||||
|
||||
For a complete authentication example, see [getting_started/server.py](getting_started/server.py).
|
||||
|
||||
## AG-UI Thread Snapshots
|
||||
|
||||
AG-UI Thread Snapshot persistence is opt-in and disabled by default. Existing endpoints keep their current behavior
|
||||
unless you provide a `snapshot_store`.
|
||||
|
||||
Thread snapshots let an AG-UI frontend recover replayable UI state after a refresh. When snapshot persistence is
|
||||
enabled, the endpoint stores the latest replayable snapshot for an AG-UI Thread within an application-defined
|
||||
Snapshot Scope. A Hydrate Request is an AG-UI request with a known `threadId`, `messages: []`, and no `resume`
|
||||
payload. Hydration replays the stored Shared State, message snapshot, and interruption metadata when available,
|
||||
then finishes without invoking the wrapped agent or workflow.
|
||||
|
||||
Use the built-in in-memory store for local development, demos, and tests:
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
|
||||
from agent_framework.ag_ui import InMemoryAGUIThreadSnapshotStore, add_agent_framework_fastapi_endpoint
|
||||
|
||||
app = FastAPI()
|
||||
agent = ...
|
||||
snapshot_store = InMemoryAGUIThreadSnapshotStore(max_snapshots=500)
|
||||
|
||||
|
||||
def resolve_snapshot_scope(request):
|
||||
# Local demo scope. Production apps should derive the scope from authenticated user or tenant context.
|
||||
del request
|
||||
return "local-demo"
|
||||
|
||||
|
||||
add_agent_framework_fastapi_endpoint(
|
||||
app,
|
||||
agent,
|
||||
"/",
|
||||
snapshot_store=snapshot_store,
|
||||
snapshot_scope_resolver=resolve_snapshot_scope,
|
||||
)
|
||||
```
|
||||
|
||||
A frontend can then hydrate the latest stored snapshot for the scoped thread:
|
||||
|
||||
```json
|
||||
{
|
||||
"threadId": "thread-1",
|
||||
"messages": []
|
||||
}
|
||||
```
|
||||
|
||||
Endpoint configuration requires `snapshot_scope_resolver` whenever a snapshot store is configured, including when
|
||||
the store is already set on a pre-wrapped `AgentFrameworkAgent` or `AgentFrameworkWorkflow`. The resolver returns
|
||||
the application-defined Snapshot Scope used with the AG-UI Thread id as the storage key.
|
||||
|
||||
AG-UI Thread ids identify AG-UI Threads; they do not authorize snapshot access. Do not treat a thread id as a bearer
|
||||
credential or tenant boundary. Production applications must authenticate and authorize every AG-UI endpoint request
|
||||
and choose a Snapshot Scope that represents the app's real access boundary, such as an authenticated user, tenant,
|
||||
or workspace. Do not rely on untrusted client-provided fields by themselves to choose that boundary.
|
||||
|
||||
Stored snapshots are untrusted application data with confidentiality impact. They may contain sensitive user text,
|
||||
model output, tool results, function arguments, UI payloads, Shared State, and interruption data. The built-in
|
||||
`InMemoryAGUIThreadSnapshotStore` is in-memory only, process-local, bounded, latest-only, and not durable production
|
||||
storage. It is cleared on process restart and is not shared across workers.
|
||||
|
||||
No file-backed AG-UI snapshot store is provided by the package. Applications that need durable persistence should
|
||||
provide an app-owned implementation of the `AGUIThreadSnapshotStore` protocol and own storage hardening, including
|
||||
encryption, access control, retention, audit, data residency, and deletion behavior.
|
||||
|
||||
## Architecture
|
||||
|
||||
The package uses a clean, orchestrator-based architecture:
|
||||
|
||||
@@ -9,6 +9,15 @@ from ._client import AGUIChatClient
|
||||
from ._endpoint import add_agent_framework_fastapi_endpoint
|
||||
from ._event_converters import AGUIEventConverter
|
||||
from ._http_service import AGUIHttpService
|
||||
from ._snapshots import (
|
||||
DEFAULT_MAX_THREAD_SNAPSHOTS,
|
||||
AGUIThreadID,
|
||||
AGUIThreadSnapshot,
|
||||
AGUIThreadSnapshotStore,
|
||||
InMemoryAGUIThreadSnapshotStore,
|
||||
SnapshotScope,
|
||||
SnapshotScopeResolver,
|
||||
)
|
||||
from ._state import state_update
|
||||
from ._types import AgentState, AGUIChatOptions, AGUIRequest, PredictStateConfig, RunMetadata
|
||||
from ._workflow import AgentFrameworkWorkflow, WorkflowFactory
|
||||
@@ -31,9 +40,16 @@ __all__ = [
|
||||
"AGUIEventConverter",
|
||||
"AGUIHttpService",
|
||||
"AGUIRequest",
|
||||
"AGUIThreadID",
|
||||
"AGUIThreadSnapshot",
|
||||
"AGUIThreadSnapshotStore",
|
||||
"AgentState",
|
||||
"InMemoryAGUIThreadSnapshotStore",
|
||||
"PredictStateConfig",
|
||||
"RunMetadata",
|
||||
"SnapshotScope",
|
||||
"SnapshotScopeResolver",
|
||||
"DEFAULT_MAX_THREAD_SNAPSHOTS",
|
||||
"DEFAULT_TAGS",
|
||||
"state_update",
|
||||
"__version__",
|
||||
|
||||
@@ -10,6 +10,7 @@ from ag_ui.core import BaseEvent
|
||||
from agent_framework import SupportsAgentRun
|
||||
|
||||
from ._agent_run import PendingApprovalEntry, run_agent_stream
|
||||
from ._snapshots import AGUIThreadSnapshotStore
|
||||
|
||||
|
||||
class AgentConfig:
|
||||
@@ -21,6 +22,7 @@ class AgentConfig:
|
||||
predict_state_config: dict[str, dict[str, str]] | None = None,
|
||||
use_service_session: bool = False,
|
||||
require_confirmation: bool = True,
|
||||
snapshot_store: AGUIThreadSnapshotStore | None = None,
|
||||
):
|
||||
"""Initialize agent configuration.
|
||||
|
||||
@@ -29,11 +31,14 @@ class AgentConfig:
|
||||
predict_state_config: Configuration for predictive state updates
|
||||
use_service_session: Whether the agent session is service-managed
|
||||
require_confirmation: Whether predictive updates require user confirmation before applying
|
||||
snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence remains inactive unless
|
||||
endpoint setup also provides an explicit Snapshot Scope resolver.
|
||||
"""
|
||||
self.state_schema = self._normalize_state_schema(state_schema)
|
||||
self.predict_state_config = predict_state_config or {}
|
||||
self.use_service_session = use_service_session
|
||||
self.require_confirmation = require_confirmation
|
||||
self.snapshot_store = snapshot_store
|
||||
|
||||
@staticmethod
|
||||
def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]:
|
||||
@@ -79,6 +84,7 @@ class AgentFrameworkAgent:
|
||||
predict_state_config: dict[str, dict[str, str]] | None = None,
|
||||
require_confirmation: bool = True,
|
||||
use_service_session: bool = False,
|
||||
snapshot_store: AGUIThreadSnapshotStore | None = None,
|
||||
):
|
||||
"""Initialize the AG-UI compatible agent wrapper.
|
||||
|
||||
@@ -90,6 +96,8 @@ class AgentFrameworkAgent:
|
||||
predict_state_config: Configuration for predictive state updates
|
||||
require_confirmation: Whether predictive updates require user confirmation before applying
|
||||
use_service_session: Whether the agent session is service-managed
|
||||
snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence remains inactive unless
|
||||
endpoint setup also provides an explicit Snapshot Scope resolver.
|
||||
"""
|
||||
self.agent = agent
|
||||
self.name = name or getattr(agent, "name", "agent")
|
||||
@@ -100,6 +108,7 @@ class AgentFrameworkAgent:
|
||||
predict_state_config=predict_state_config,
|
||||
use_service_session=use_service_session,
|
||||
require_confirmation=require_confirmation,
|
||||
snapshot_store=snapshot_store,
|
||||
)
|
||||
|
||||
# Server-side registry of pending approval requests.
|
||||
@@ -110,6 +119,11 @@ class AgentFrameworkAgent:
|
||||
self._pending_approvals: OrderedDict[str, PendingApprovalEntry] = OrderedDict()
|
||||
self._pending_approvals_max_size: int = 10_000
|
||||
|
||||
@property
|
||||
def snapshot_store(self) -> AGUIThreadSnapshotStore | None:
|
||||
"""Configured AG-UI Thread Snapshot store, if any."""
|
||||
return self.config.snapshot_store
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: dict[str, Any],
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from __future__ import annotations # noqa: I001
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
@@ -52,9 +53,11 @@ from ._run_common import (
|
||||
_extract_tool_result_display, # type: ignore
|
||||
_has_only_tool_calls, # type: ignore
|
||||
_normalize_resume_interrupts, # type: ignore
|
||||
_reconstruct_messages_from_thread_snapshot, # type: ignore
|
||||
_resolve_ui_payload, # type: ignore
|
||||
_stringify_tool_result, # type: ignore
|
||||
)
|
||||
from ._snapshots import AGUIThreadSnapshot, _DEFAULT_STATE_INPUT_KEY, _SNAPSHOT_SCOPE_INPUT_KEY
|
||||
from ._utils import (
|
||||
canonical_function_arguments,
|
||||
convert_agui_tools_to_agent_framework,
|
||||
@@ -748,6 +751,85 @@ def _build_messages_snapshot(
|
||||
return MessagesSnapshotEvent(messages=all_messages) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _event_messages_to_snapshot_dicts(messages: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert AG-UI message event models back to plain snapshot dictionaries."""
|
||||
safe_messages = make_json_safe(messages)
|
||||
if not isinstance(safe_messages, list):
|
||||
return []
|
||||
return [cast(dict[str, Any], message) for message in safe_messages if isinstance(message, dict)]
|
||||
|
||||
|
||||
def _text_events_to_snapshot_messages(events: list[BaseEvent]) -> list[dict[str, Any]]:
|
||||
"""Convert streamed text-message events into snapshot message dictionaries."""
|
||||
messages: list[dict[str, Any]] = []
|
||||
messages_by_id: dict[str, dict[str, Any]] = {}
|
||||
for event in events:
|
||||
if isinstance(event, TextMessageStartEvent):
|
||||
message: dict[str, Any] = {"id": event.message_id, "role": event.role, "content": ""}
|
||||
messages.append(message)
|
||||
messages_by_id[event.message_id] = message
|
||||
elif isinstance(event, TextMessageContentEvent):
|
||||
open_message = messages_by_id.get(event.message_id)
|
||||
if open_message is not None:
|
||||
open_message["content"] = f"{open_message['content']}{event.delta}"
|
||||
return [message for message in messages if message.get("content")]
|
||||
|
||||
|
||||
async def _hydrate_thread_snapshot(
|
||||
*,
|
||||
config: AgentConfig,
|
||||
scope: str,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
) -> AsyncGenerator[BaseEvent]:
|
||||
"""Replay the latest stored AG-UI Thread Snapshot without invoking the agent."""
|
||||
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
|
||||
if config.snapshot_store is None:
|
||||
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id)
|
||||
return
|
||||
|
||||
snapshot = await config.snapshot_store.get(scope=scope, thread_id=thread_id)
|
||||
if snapshot is None:
|
||||
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id)
|
||||
return
|
||||
|
||||
if snapshot.state is not None:
|
||||
yield StateSnapshotEvent(snapshot=snapshot.state)
|
||||
if snapshot.messages:
|
||||
yield MessagesSnapshotEvent(messages=snapshot.messages) # type: ignore[arg-type]
|
||||
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=snapshot.interrupt)
|
||||
|
||||
|
||||
async def _save_thread_snapshot(
|
||||
*,
|
||||
config: AgentConfig,
|
||||
scope: str | None,
|
||||
thread_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
state: dict[str, Any] | None,
|
||||
interrupt: list[dict[str, Any]] | None,
|
||||
) -> None:
|
||||
"""Save the latest replayable AG-UI Thread Snapshot when persistence is configured."""
|
||||
if config.snapshot_store is None or scope is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await config.snapshot_store.save(
|
||||
scope=scope,
|
||||
thread_id=thread_id,
|
||||
snapshot=AGUIThreadSnapshot(messages=messages, state=state, interrupt=interrupt),
|
||||
)
|
||||
except Exception:
|
||||
# The run itself already streamed successfully; a transient store failure
|
||||
# must not surface as RUN_ERROR for a completed run. The previous snapshot
|
||||
# stays available for hydration.
|
||||
logger.exception(
|
||||
"Failed to save AG-UI Thread Snapshot for scope=%s thread_id=%s; keeping previous snapshot.",
|
||||
scope,
|
||||
thread_id,
|
||||
)
|
||||
|
||||
|
||||
async def run_agent_stream(
|
||||
input_data: dict[str, Any],
|
||||
agent: SupportsAgentRun,
|
||||
@@ -774,15 +856,53 @@ async def run_agent_stream(
|
||||
# Parse IDs
|
||||
thread_id = input_data.get("thread_id") or input_data.get("threadId") or str(uuid.uuid4())
|
||||
run_id = input_data.get("run_id") or input_data.get("runId") or str(uuid.uuid4())
|
||||
|
||||
# Initialize flow state with schema defaults
|
||||
flow = FlowState()
|
||||
if input_data.get("state"):
|
||||
flow.current_state = dict(input_data["state"])
|
||||
snapshot_scope = cast(str | None, input_data.get(_SNAPSHOT_SCOPE_INPUT_KEY))
|
||||
|
||||
state_schema = cast(dict[str, Any], getattr(config, "state_schema", {}) or {})
|
||||
predict_state_config = cast(dict[str, dict[str, str]], getattr(config, "predict_state_config", {}) or {})
|
||||
|
||||
# Normalize messages
|
||||
available_interrupts = input_data.get("available_interrupts") or input_data.get("availableInterrupts")
|
||||
raw_messages: list[dict[str, Any]] = input_data.get("messages", []) or []
|
||||
resume_payload = _extract_resume_payload(input_data)
|
||||
if config.snapshot_store is not None and snapshot_scope is not None and not raw_messages and resume_payload is None:
|
||||
async for event in _hydrate_thread_snapshot(
|
||||
config=config,
|
||||
scope=snapshot_scope,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
|
||||
stored_snapshot: AGUIThreadSnapshot | None = None
|
||||
if config.snapshot_store is not None and snapshot_scope is not None:
|
||||
stored_snapshot = await config.snapshot_store.get(scope=snapshot_scope, thread_id=thread_id)
|
||||
if stored_snapshot is not None and resume_payload is None:
|
||||
raw_messages = _reconstruct_messages_from_thread_snapshot(
|
||||
stored_messages=stored_snapshot.messages,
|
||||
incoming_messages=raw_messages,
|
||||
stored_interrupt=stored_snapshot.interrupt,
|
||||
)
|
||||
|
||||
# Initialize flow state with stored state plus request-provided overrides.
|
||||
flow = FlowState()
|
||||
request_state = input_data.get("state")
|
||||
if stored_snapshot is not None and stored_snapshot.state is not None:
|
||||
flow.current_state = dict(stored_snapshot.state)
|
||||
if isinstance(request_state, dict):
|
||||
flow.current_state.update(request_state)
|
||||
elif isinstance(request_state, dict):
|
||||
flow.current_state = dict(request_state)
|
||||
|
||||
# Apply endpoint-deferred defaults only for keys missing from both the stored
|
||||
# snapshot state and the request state, so defaults never reset persisted state.
|
||||
deferred_default_state = cast(dict[str, Any] | None, input_data.get(_DEFAULT_STATE_INPUT_KEY))
|
||||
if deferred_default_state:
|
||||
for key, value in deferred_default_state.items():
|
||||
if key not in flow.current_state:
|
||||
flow.current_state[key] = copy.deepcopy(value)
|
||||
|
||||
# Apply schema defaults for missing state keys
|
||||
if state_schema:
|
||||
for key, schema in state_schema.items():
|
||||
@@ -801,10 +921,7 @@ async def run_agent_stream(
|
||||
current_state=flow.current_state,
|
||||
)
|
||||
|
||||
# Normalize messages
|
||||
available_interrupts = input_data.get("available_interrupts") or input_data.get("availableInterrupts")
|
||||
raw_messages = list(cast(list[dict[str, Any]], input_data.get("messages", []) or []))
|
||||
resume_messages = _resume_to_tool_messages(_extract_resume_payload(input_data))
|
||||
resume_messages = _resume_to_tool_messages(resume_payload)
|
||||
if available_interrupts:
|
||||
logger.debug("Received available interrupts metadata: %s", available_interrupts)
|
||||
if resume_messages:
|
||||
@@ -892,8 +1009,24 @@ async def run_agent_stream(
|
||||
# Emit approved state snapshot before confirmation message
|
||||
if approved_state_snapshot_emitted:
|
||||
yield StateSnapshotEvent(snapshot=flow.current_state)
|
||||
for event in _handle_step_based_approval(messages):
|
||||
confirmation_events = _handle_step_based_approval(messages)
|
||||
for event in confirmation_events:
|
||||
yield event
|
||||
# Persist the completed confirmation turn with interrupt=None so hydration
|
||||
# does not replay the stale pending interrupt after the user responded.
|
||||
persisted_messages = snapshot_messages + _text_events_to_snapshot_messages(confirmation_events)
|
||||
if resume_payload is not None and stored_snapshot is not None:
|
||||
# Resume requests carry only the synthesized interrupt response, so prepend
|
||||
# the stored thread history to avoid persisting a truncated thread.
|
||||
persisted_messages = [copy.deepcopy(message) for message in stored_snapshot.messages] + persisted_messages
|
||||
await _save_thread_snapshot(
|
||||
config=config,
|
||||
scope=snapshot_scope,
|
||||
thread_id=thread_id,
|
||||
messages=persisted_messages,
|
||||
state=cast(dict[str, Any], make_json_safe(flow.current_state)) if flow.current_state else None,
|
||||
interrupt=None,
|
||||
)
|
||||
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id)
|
||||
return
|
||||
|
||||
@@ -905,6 +1038,9 @@ async def run_agent_stream(
|
||||
# Stream from agent - emit RunStarted after first update to get service IDs
|
||||
run_started_emitted = False
|
||||
all_updates: list[Any] = [] # Collect for structured output processing
|
||||
latest_state_snapshot: dict[str, Any] | None = (
|
||||
cast(dict[str, Any], make_json_safe(flow.current_state)) if flow.current_state else None
|
||||
)
|
||||
response_stream = agent.run(messages, stream=True, **run_kwargs)
|
||||
stream = await _normalize_response_stream(response_stream)
|
||||
async for update in stream:
|
||||
@@ -934,6 +1070,7 @@ async def run_agent_stream(
|
||||
yield CustomEvent(name="PredictState", value=predict_state_value)
|
||||
# Emit initial state snapshot only if we have both state_schema and state
|
||||
if state_schema and flow.current_state:
|
||||
latest_state_snapshot = cast(dict[str, Any], make_json_safe(flow.current_state))
|
||||
yield StateSnapshotEvent(snapshot=flow.current_state)
|
||||
run_started_emitted = True
|
||||
|
||||
@@ -975,6 +1112,8 @@ async def run_agent_stream(
|
||||
skip_text,
|
||||
config.require_confirmation,
|
||||
):
|
||||
if isinstance(event, StateSnapshotEvent):
|
||||
latest_state_snapshot = cast(dict[str, Any], make_json_safe(event.snapshot))
|
||||
yield event
|
||||
|
||||
# Stop if waiting for approval
|
||||
@@ -1019,6 +1158,7 @@ async def run_agent_stream(
|
||||
|
||||
if state_updates:
|
||||
flow.current_state.update(state_updates)
|
||||
latest_state_snapshot = cast(dict[str, Any], make_json_safe(flow.current_state))
|
||||
yield StateSnapshotEvent(snapshot=flow.current_state)
|
||||
logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}")
|
||||
|
||||
@@ -1056,6 +1196,7 @@ async def run_agent_stream(
|
||||
if result:
|
||||
state_key, state_value = result
|
||||
flow.current_state[state_key] = state_value
|
||||
latest_state_snapshot = cast(dict[str, Any], make_json_safe(flow.current_state))
|
||||
yield StateSnapshotEvent(snapshot=flow.current_state)
|
||||
except json.JSONDecodeError:
|
||||
# Ignore malformed JSON in tool arguments for predictive state;
|
||||
@@ -1136,7 +1277,12 @@ async def run_agent_stream(
|
||||
should_emit_snapshot = (
|
||||
flow.pending_tool_calls or flow.tool_results or flow.accumulated_text or flow.reasoning_messages
|
||||
)
|
||||
latest_messages_snapshot = snapshot_messages
|
||||
if should_emit_snapshot:
|
||||
# Always fold this turn's output into the persisted snapshot, even when the
|
||||
# outbound MESSAGES_SNAPSHOT event is suppressed for predictive tools.
|
||||
snapshot_event = _build_messages_snapshot(flow, snapshot_messages)
|
||||
latest_messages_snapshot = _event_messages_to_snapshot_dicts(list(snapshot_event.messages))
|
||||
# Check if we should suppress for predictive tool
|
||||
last_tool_name = None
|
||||
if flow.tool_results:
|
||||
@@ -1146,8 +1292,21 @@ async def run_agent_stream(
|
||||
if not _should_suppress_intermediate_snapshot(
|
||||
last_tool_name, predict_state_config, config.require_confirmation
|
||||
):
|
||||
yield _build_messages_snapshot(flow, snapshot_messages)
|
||||
yield snapshot_event
|
||||
|
||||
# Always emit RunFinished - confirm_changes tool call is complete (Start -> Args -> End)
|
||||
# The UI will show confirmation dialog and send a new request when user responds
|
||||
persisted_messages = latest_messages_snapshot
|
||||
if resume_payload is not None and stored_snapshot is not None:
|
||||
# Resume requests carry only the synthesized interrupt response, so prepend
|
||||
# the stored thread history to avoid persisting a truncated thread.
|
||||
persisted_messages = [copy.deepcopy(message) for message in stored_snapshot.messages] + persisted_messages
|
||||
await _save_thread_snapshot(
|
||||
config=config,
|
||||
scope=snapshot_scope,
|
||||
thread_id=thread_id,
|
||||
messages=persisted_messages,
|
||||
state=latest_state_snapshot,
|
||||
interrupt=flow.interrupts or None,
|
||||
)
|
||||
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=flow.interrupts)
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
import copy
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from inspect import isawaitable
|
||||
from typing import Any
|
||||
|
||||
from ag_ui.core import RunErrorEvent
|
||||
@@ -17,12 +18,58 @@ from fastapi.params import Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from ._agent import AgentFrameworkAgent
|
||||
from ._snapshots import (
|
||||
_DEFAULT_STATE_INPUT_KEY,
|
||||
_SNAPSHOT_SCOPE_INPUT_KEY,
|
||||
AGUIThreadSnapshotStore,
|
||||
SnapshotScopeResolver,
|
||||
)
|
||||
from ._types import AGUIRequest
|
||||
from ._workflow import AgentFrameworkWorkflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_snapshot_store(
|
||||
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow,
|
||||
) -> AGUIThreadSnapshotStore | None:
|
||||
if isinstance(protocol_runner, AgentFrameworkAgent):
|
||||
return protocol_runner.config.snapshot_store
|
||||
return protocol_runner.snapshot_store
|
||||
|
||||
|
||||
def _set_snapshot_store(
|
||||
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow,
|
||||
snapshot_store: AGUIThreadSnapshotStore,
|
||||
) -> None:
|
||||
if isinstance(protocol_runner, AgentFrameworkAgent):
|
||||
protocol_runner.config.snapshot_store = snapshot_store
|
||||
return
|
||||
protocol_runner.snapshot_store = snapshot_store
|
||||
|
||||
|
||||
def _configure_snapshot_persistence(
|
||||
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow,
|
||||
*,
|
||||
snapshot_store: AGUIThreadSnapshotStore | None,
|
||||
snapshot_scope_resolver: SnapshotScopeResolver | None,
|
||||
) -> None:
|
||||
existing_snapshot_store = _get_snapshot_store(protocol_runner)
|
||||
if snapshot_store is not None:
|
||||
if existing_snapshot_store is not None and existing_snapshot_store is not snapshot_store:
|
||||
raise ValueError("snapshot_store is already configured on the AG-UI runner.")
|
||||
if existing_snapshot_store is None:
|
||||
_set_snapshot_store(protocol_runner, snapshot_store)
|
||||
existing_snapshot_store = snapshot_store
|
||||
|
||||
if existing_snapshot_store is not None and snapshot_scope_resolver is None:
|
||||
raise ValueError(
|
||||
"snapshot_scope_resolver is required when snapshot_store is configured. "
|
||||
"AG-UI Thread ids identify threads but do not authorize snapshot access; "
|
||||
"provide a resolver that returns an explicit Snapshot Scope."
|
||||
)
|
||||
|
||||
|
||||
def add_agent_framework_fastapi_endpoint(
|
||||
app: FastAPI,
|
||||
agent: SupportsAgentRun | AgentFrameworkAgent | Workflow | AgentFrameworkWorkflow,
|
||||
@@ -33,6 +80,8 @@ def add_agent_framework_fastapi_endpoint(
|
||||
default_state: dict[str, Any] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
dependencies: Sequence[Depends] | None = None,
|
||||
snapshot_store: AGUIThreadSnapshotStore | None = None,
|
||||
snapshot_scope_resolver: SnapshotScopeResolver | None = None,
|
||||
) -> None:
|
||||
"""Add an AG-UI endpoint to a FastAPI app.
|
||||
|
||||
@@ -50,6 +99,10 @@ def add_agent_framework_fastapi_endpoint(
|
||||
These dependencies run before the endpoint handler. Use this to add
|
||||
authentication checks, rate limiting, or other middleware-like behavior.
|
||||
Example: `dependencies=[Depends(verify_api_key)]`
|
||||
snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence is opt-in and requires an
|
||||
explicit Snapshot Scope resolver.
|
||||
snapshot_scope_resolver: Optional resolver for the application-defined Snapshot Scope. Required whenever
|
||||
a snapshot store is configured because an AG-UI Thread id is not an authorization boundary.
|
||||
"""
|
||||
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow
|
||||
if isinstance(agent, AgentFrameworkWorkflow):
|
||||
@@ -63,10 +116,17 @@ def add_agent_framework_fastapi_endpoint(
|
||||
agent=agent,
|
||||
state_schema=state_schema,
|
||||
predict_state_config=predict_state_config,
|
||||
snapshot_store=snapshot_store,
|
||||
)
|
||||
else:
|
||||
raise TypeError("agent must be SupportsAgentRun, Workflow, AgentFrameworkAgent, or AgentFrameworkWorkflow.")
|
||||
|
||||
_configure_snapshot_persistence(
|
||||
protocol_runner,
|
||||
snapshot_store=snapshot_store,
|
||||
snapshot_scope_resolver=snapshot_scope_resolver,
|
||||
)
|
||||
|
||||
@app.post(path, tags=tags or ["AG-UI"], dependencies=dependencies, response_model=None) # type: ignore[arg-type]
|
||||
async def agent_endpoint(request_body: AGUIRequest) -> StreamingResponse:
|
||||
"""Handle AG-UI agent requests.
|
||||
@@ -76,11 +136,23 @@ def add_agent_framework_fastapi_endpoint(
|
||||
"""
|
||||
try:
|
||||
input_data = request_body.model_dump(exclude_none=True)
|
||||
snapshot_persistence_active = False
|
||||
if snapshot_scope_resolver is not None and _get_snapshot_store(protocol_runner) is not None:
|
||||
snapshot_scope = snapshot_scope_resolver(request_body)
|
||||
if isawaitable(snapshot_scope):
|
||||
snapshot_scope = await snapshot_scope
|
||||
input_data[_SNAPSHOT_SCOPE_INPUT_KEY] = snapshot_scope
|
||||
snapshot_persistence_active = True
|
||||
if default_state:
|
||||
state = input_data.setdefault("state", {})
|
||||
for key, value in default_state.items():
|
||||
if key not in state:
|
||||
state[key] = copy.deepcopy(value)
|
||||
if snapshot_persistence_active:
|
||||
# Defer default application to the runner so defaults only fill keys
|
||||
# missing from both the stored snapshot state and the request state.
|
||||
input_data[_DEFAULT_STATE_INPUT_KEY] = copy.deepcopy(default_state)
|
||||
else:
|
||||
state = input_data.setdefault("state", {})
|
||||
for key, value in default_state.items():
|
||||
if key not in state:
|
||||
state[key] = copy.deepcopy(value)
|
||||
logger.debug(
|
||||
f"[{path}] Received request - Run ID: {input_data.get('run_id', 'no-run-id')}, "
|
||||
f"Thread ID: {input_data.get('thread_id', 'no-thread-id')}, "
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
@@ -33,7 +34,7 @@ from agent_framework import Content
|
||||
|
||||
from ._orchestration._predictive_state import PredictiveStateHandler
|
||||
from ._state import TOOL_RESULT_DISPLAY_KEY, TOOL_RESULT_STATE_KEY
|
||||
from ._utils import generate_event_id, make_json_safe
|
||||
from ._utils import generate_event_id, make_json_safe, normalize_agui_role
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -733,3 +734,117 @@ def _emit_content(
|
||||
return _emit_text_reasoning(content, flow)
|
||||
logger.debug("Skipping unsupported content type in AG-UI emitter: %s", content_type)
|
||||
return events
|
||||
|
||||
|
||||
def _canonical_snapshot_message(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Normalize an AG-UI message for identity comparison without generated ids."""
|
||||
from ._message_adapters import agui_messages_to_snapshot_format
|
||||
|
||||
normalized_message = agui_messages_to_snapshot_format([copy.deepcopy(message)])[0]
|
||||
normalized_message.pop("id", None)
|
||||
return cast(dict[str, Any], make_json_safe(normalized_message))
|
||||
|
||||
|
||||
def _snapshot_messages_match(stored_message: dict[str, Any], incoming_message: dict[str, Any]) -> bool:
|
||||
"""Return whether an incoming message already represents the stored snapshot message."""
|
||||
stored_id = stored_message.get("id")
|
||||
incoming_id = incoming_message.get("id")
|
||||
if stored_id and incoming_id:
|
||||
return str(stored_id) == str(incoming_id)
|
||||
return _canonical_snapshot_message(stored_message) == _canonical_snapshot_message(incoming_message)
|
||||
|
||||
|
||||
def _latest_user_message_index(messages: list[dict[str, Any]]) -> int | None:
|
||||
"""Find the newest incoming user message index."""
|
||||
for index in range(len(messages) - 1, -1, -1):
|
||||
if normalize_agui_role(messages[index].get("role", "user")) == "user":
|
||||
return index
|
||||
return None
|
||||
|
||||
|
||||
def _known_tool_call_ids(
|
||||
stored_messages: list[dict[str, Any]],
|
||||
stored_interrupt: list[dict[str, Any]] | None,
|
||||
) -> set[str]:
|
||||
"""Collect tool call ids the backend previously issued for this thread."""
|
||||
known_ids: set[str] = set()
|
||||
for message in stored_messages:
|
||||
tool_calls = message.get("tool_calls") or message.get("toolCalls") or []
|
||||
if not isinstance(tool_calls, list):
|
||||
continue
|
||||
for tool_call in cast(list[Any], tool_calls):
|
||||
if isinstance(tool_call, dict):
|
||||
tool_call_id = cast(dict[str, Any], tool_call).get("id")
|
||||
if tool_call_id:
|
||||
known_ids.add(str(tool_call_id))
|
||||
for interrupt in stored_interrupt or []:
|
||||
interrupt_id = interrupt.get("id")
|
||||
if interrupt_id:
|
||||
known_ids.add(str(interrupt_id))
|
||||
return known_ids
|
||||
|
||||
|
||||
def _filter_untrusted_suffix(
|
||||
incoming_suffix: list[dict[str, Any]],
|
||||
*,
|
||||
stored_messages: list[dict[str, Any]],
|
||||
stored_interrupt: list[dict[str, Any]] | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Drop client-forged non-user messages before promoting them to stored history.
|
||||
|
||||
Only the user's own turns and tool results answering backend-issued tool calls
|
||||
(including pending interrupts) may extend the authoritative thread history.
|
||||
"""
|
||||
known_ids: set[str] | None = None
|
||||
filtered: list[dict[str, Any]] = []
|
||||
for message in incoming_suffix:
|
||||
raw_role = str(message.get("role", "")).lower()
|
||||
if raw_role == "user":
|
||||
filtered.append(message)
|
||||
continue
|
||||
if raw_role == "tool":
|
||||
tool_call_id = message.get("toolCallId") or message.get("tool_call_id") or message.get("actionExecutionId")
|
||||
if known_ids is None:
|
||||
known_ids = _known_tool_call_ids(stored_messages, stored_interrupt)
|
||||
if tool_call_id and str(tool_call_id) in known_ids:
|
||||
filtered.append(message)
|
||||
continue
|
||||
logger.warning(
|
||||
"Dropping client-supplied %r message from the incoming thread suffix; "
|
||||
"only user turns and tool results for backend-issued tool calls extend stored history.",
|
||||
raw_role or "unknown",
|
||||
)
|
||||
return filtered
|
||||
|
||||
|
||||
def _reconstruct_messages_from_thread_snapshot(
|
||||
*,
|
||||
stored_messages: list[dict[str, Any]],
|
||||
incoming_messages: list[dict[str, Any]],
|
||||
stored_interrupt: list[dict[str, Any]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Combine backend-owned prior history with the request-owned new user turn."""
|
||||
if not stored_messages or not incoming_messages:
|
||||
return incoming_messages
|
||||
|
||||
incoming_suffix: list[dict[str, Any]]
|
||||
if len(incoming_messages) >= len(stored_messages) and all(
|
||||
_snapshot_messages_match(stored_message, incoming_message)
|
||||
for stored_message, incoming_message in zip(stored_messages, incoming_messages)
|
||||
):
|
||||
incoming_suffix = incoming_messages[len(stored_messages) :]
|
||||
else:
|
||||
latest_user_index = _latest_user_message_index(incoming_messages)
|
||||
if latest_user_index is None:
|
||||
return incoming_messages
|
||||
incoming_suffix = incoming_messages[latest_user_index:]
|
||||
|
||||
incoming_suffix = _filter_untrusted_suffix(
|
||||
incoming_suffix,
|
||||
stored_messages=stored_messages,
|
||||
stored_interrupt=stored_interrupt,
|
||||
)
|
||||
|
||||
return [copy.deepcopy(message) for message in stored_messages] + [
|
||||
copy.deepcopy(message) for message in incoming_suffix
|
||||
]
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""AG-UI Thread Snapshot storage primitives."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._types import AGUIRequest
|
||||
|
||||
SnapshotScope: TypeAlias = str
|
||||
"""Application-defined scope for authorizing access to AG-UI Thread Snapshots."""
|
||||
|
||||
AGUIThreadID: TypeAlias = str
|
||||
"""AG-UI Thread identifier within a Snapshot Scope."""
|
||||
|
||||
SnapshotScopeResolver: TypeAlias = Callable[["AGUIRequest"], str | Awaitable[str]]
|
||||
"""Callable that resolves the Snapshot Scope for an AG-UI endpoint request."""
|
||||
|
||||
_SnapshotKey: TypeAlias = tuple[SnapshotScope, AGUIThreadID]
|
||||
|
||||
DEFAULT_MAX_THREAD_SNAPSHOTS = 1_000
|
||||
_SNAPSHOT_SCOPE_INPUT_KEY = "__ag_ui_snapshot_scope"
|
||||
_DEFAULT_STATE_INPUT_KEY = "__ag_ui_default_state"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AGUIThreadSnapshot:
|
||||
"""Replayable AG-UI Thread state.
|
||||
|
||||
AG-UI Thread Snapshots intentionally contain only data that can be replayed
|
||||
to a UI: message snapshots, optional Shared State, and optional interruption
|
||||
state. They do not include raw events, request metadata, auth claims,
|
||||
diagnostics, traces, or provider responses.
|
||||
|
||||
Attributes:
|
||||
messages: Replayable AG-UI message snapshots.
|
||||
state: Optional AG-UI Shared State snapshot.
|
||||
interrupt: Optional interruption state from ``RUN_FINISHED.interrupt``.
|
||||
"""
|
||||
|
||||
messages: list[dict[str, Any]] = field(default_factory=list)
|
||||
state: dict[str, Any] | None = None
|
||||
interrupt: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AGUIThreadSnapshotStore(Protocol):
|
||||
"""Async store for latest AG-UI Thread Snapshots keyed by scope and thread id."""
|
||||
|
||||
async def save(
|
||||
self,
|
||||
*,
|
||||
scope: SnapshotScope,
|
||||
thread_id: AGUIThreadID,
|
||||
snapshot: AGUIThreadSnapshot,
|
||||
) -> None:
|
||||
"""Save the latest snapshot for an AG-UI Thread within a Snapshot Scope.
|
||||
|
||||
Args:
|
||||
scope: Application-defined Snapshot Scope. This is part of the
|
||||
storage key and must represent the app's authorization boundary.
|
||||
thread_id: AG-UI Thread id within the scope.
|
||||
snapshot: Snapshot to save.
|
||||
"""
|
||||
...
|
||||
|
||||
async def get(
|
||||
self,
|
||||
*,
|
||||
scope: SnapshotScope,
|
||||
thread_id: AGUIThreadID,
|
||||
) -> AGUIThreadSnapshot | None:
|
||||
"""Get the latest snapshot for an AG-UI Thread within a Snapshot Scope.
|
||||
|
||||
Args:
|
||||
scope: Application-defined Snapshot Scope.
|
||||
thread_id: AG-UI Thread id within the scope.
|
||||
|
||||
Returns:
|
||||
The latest snapshot, or ``None`` when no snapshot exists for the key.
|
||||
"""
|
||||
...
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
*,
|
||||
scope: SnapshotScope,
|
||||
thread_id: AGUIThreadID,
|
||||
) -> bool:
|
||||
"""Delete the latest snapshot for an AG-UI Thread within a Snapshot Scope.
|
||||
|
||||
Args:
|
||||
scope: Application-defined Snapshot Scope.
|
||||
thread_id: AG-UI Thread id within the scope.
|
||||
|
||||
Returns:
|
||||
``True`` when a snapshot was deleted, otherwise ``False``.
|
||||
"""
|
||||
...
|
||||
|
||||
async def clear(self, *, scope: SnapshotScope | None = None) -> None:
|
||||
"""Clear saved snapshots.
|
||||
|
||||
Args:
|
||||
scope: Optional Snapshot Scope to clear. When omitted, all in-memory
|
||||
snapshots are cleared.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class InMemoryAGUIThreadSnapshotStore:
|
||||
"""Bounded memory-only latest snapshot store for local development, demos, and tests.
|
||||
|
||||
This store keeps at most one snapshot per ``(scope, thread_id)`` key. It is
|
||||
process-local and not durable production storage.
|
||||
"""
|
||||
|
||||
def __init__(self, *, max_snapshots: int = DEFAULT_MAX_THREAD_SNAPSHOTS) -> None:
|
||||
"""Initialize the in-memory snapshot store.
|
||||
|
||||
Keyword Args:
|
||||
max_snapshots: Maximum number of scoped thread snapshots to retain.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``max_snapshots`` is less than 1.
|
||||
"""
|
||||
if max_snapshots < 1:
|
||||
raise ValueError("max_snapshots must be greater than 0.")
|
||||
self._max_snapshots = max_snapshots
|
||||
self._snapshots: dict[_SnapshotKey, AGUIThreadSnapshot] = {}
|
||||
|
||||
async def save(
|
||||
self,
|
||||
*,
|
||||
scope: SnapshotScope,
|
||||
thread_id: AGUIThreadID,
|
||||
snapshot: AGUIThreadSnapshot,
|
||||
) -> None:
|
||||
"""Save the latest snapshot for an AG-UI Thread within a Snapshot Scope."""
|
||||
key = self._key(scope=scope, thread_id=thread_id)
|
||||
if key in self._snapshots:
|
||||
del self._snapshots[key]
|
||||
self._snapshots[key] = copy.deepcopy(snapshot)
|
||||
self._evict_oldest()
|
||||
|
||||
async def get(
|
||||
self,
|
||||
*,
|
||||
scope: SnapshotScope,
|
||||
thread_id: AGUIThreadID,
|
||||
) -> AGUIThreadSnapshot | None:
|
||||
"""Get the latest snapshot for an AG-UI Thread within a Snapshot Scope."""
|
||||
snapshot = self._snapshots.get(self._key(scope=scope, thread_id=thread_id))
|
||||
return copy.deepcopy(snapshot) if snapshot is not None else None
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
*,
|
||||
scope: SnapshotScope,
|
||||
thread_id: AGUIThreadID,
|
||||
) -> bool:
|
||||
"""Delete the latest snapshot for an AG-UI Thread within a Snapshot Scope."""
|
||||
key = self._key(scope=scope, thread_id=thread_id)
|
||||
if key not in self._snapshots:
|
||||
return False
|
||||
del self._snapshots[key]
|
||||
return True
|
||||
|
||||
async def clear(self, *, scope: SnapshotScope | None = None) -> None:
|
||||
"""Clear saved snapshots, optionally limited to one Snapshot Scope."""
|
||||
if scope is None:
|
||||
self._snapshots.clear()
|
||||
return
|
||||
|
||||
normalized_scope = self._normalize_key_part(scope, "scope")
|
||||
for key in list(self._snapshots):
|
||||
if key[0] == normalized_scope:
|
||||
del self._snapshots[key]
|
||||
|
||||
@classmethod
|
||||
def _key(cls, *, scope: SnapshotScope, thread_id: AGUIThreadID) -> _SnapshotKey:
|
||||
return (
|
||||
cls._normalize_key_part(scope, "scope"),
|
||||
cls._normalize_key_part(thread_id, "thread_id"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_key_part(value: str, name: str) -> str:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(f"{name} must be a string.")
|
||||
if not value:
|
||||
raise ValueError(f"{name} must be a non-empty string.")
|
||||
return value
|
||||
|
||||
def _evict_oldest(self) -> None:
|
||||
while len(self._snapshots) > self._max_snapshots:
|
||||
del self._snapshots[next(iter(self._snapshots))]
|
||||
@@ -4,18 +4,203 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from ag_ui.core import BaseEvent
|
||||
from ag_ui.core import (
|
||||
BaseEvent,
|
||||
MessagesSnapshotEvent,
|
||||
RunErrorEvent,
|
||||
RunFinishedEvent,
|
||||
RunStartedEvent,
|
||||
StateSnapshotEvent,
|
||||
TextMessageContentEvent,
|
||||
TextMessageEndEvent,
|
||||
TextMessageStartEvent,
|
||||
ToolCallArgsEvent,
|
||||
ToolCallResultEvent,
|
||||
ToolCallStartEvent,
|
||||
)
|
||||
from agent_framework import Workflow
|
||||
|
||||
from ._message_adapters import agui_messages_to_snapshot_format
|
||||
from ._run_common import (
|
||||
_build_run_finished_event,
|
||||
_extract_resume_payload,
|
||||
_reconstruct_messages_from_thread_snapshot,
|
||||
)
|
||||
from ._snapshots import (
|
||||
_DEFAULT_STATE_INPUT_KEY,
|
||||
_SNAPSHOT_SCOPE_INPUT_KEY,
|
||||
AGUIThreadSnapshot,
|
||||
AGUIThreadSnapshotStore,
|
||||
)
|
||||
from ._utils import generate_event_id, make_json_safe
|
||||
from ._workflow_run import run_workflow_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WorkflowFactory = Callable[[str], Workflow]
|
||||
|
||||
|
||||
def _event_messages_to_snapshot_dicts(messages: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert AG-UI message event models to plain snapshot dictionaries."""
|
||||
safe_messages = make_json_safe(messages)
|
||||
if not isinstance(safe_messages, list):
|
||||
return []
|
||||
return [cast(dict[str, Any], message) for message in safe_messages if isinstance(message, dict)]
|
||||
|
||||
|
||||
class _WorkflowSnapshotBuilder:
|
||||
"""Capture replayable workflow protocol output without retaining raw events."""
|
||||
|
||||
def __init__(self, raw_messages: list[dict[str, Any]]) -> None:
|
||||
self._synthesized_messages = agui_messages_to_snapshot_format(raw_messages)
|
||||
self._emitted_messages: list[dict[str, Any]] | None = None
|
||||
self._open_text_message: dict[str, Any] | None = None
|
||||
self._tool_call_message: dict[str, Any] | None = None
|
||||
self._tool_calls_by_id: dict[str, dict[str, Any]] = {}
|
||||
self.state: dict[str, Any] | None = None
|
||||
self.interrupt: list[dict[str, Any]] | None = None
|
||||
|
||||
def observe(self, event: BaseEvent) -> None:
|
||||
"""Fold one replayable AG-UI event into the latest snapshot state."""
|
||||
if isinstance(event, StateSnapshotEvent):
|
||||
state = make_json_safe(event.snapshot)
|
||||
if isinstance(state, dict):
|
||||
self.state = cast(dict[str, Any], state)
|
||||
return
|
||||
|
||||
if isinstance(event, MessagesSnapshotEvent):
|
||||
self._emitted_messages = _event_messages_to_snapshot_dicts(list(event.messages))
|
||||
return
|
||||
|
||||
if isinstance(event, RunFinishedEvent):
|
||||
interrupt = make_json_safe(getattr(event, "interrupt", None))
|
||||
if isinstance(interrupt, list):
|
||||
self.interrupt = [cast(dict[str, Any], item) for item in interrupt if isinstance(item, dict)]
|
||||
return
|
||||
|
||||
if self._emitted_messages is not None:
|
||||
return
|
||||
|
||||
if isinstance(event, TextMessageStartEvent):
|
||||
self._observe_text_start(event)
|
||||
elif isinstance(event, TextMessageContentEvent):
|
||||
self._observe_text_content(event)
|
||||
elif isinstance(event, TextMessageEndEvent):
|
||||
self._observe_text_end(event)
|
||||
elif isinstance(event, ToolCallStartEvent):
|
||||
self._observe_tool_call_start(event)
|
||||
elif isinstance(event, ToolCallArgsEvent):
|
||||
self._observe_tool_call_args(event)
|
||||
elif isinstance(event, ToolCallResultEvent):
|
||||
self._observe_tool_call_result(event)
|
||||
|
||||
def build(self) -> AGUIThreadSnapshot:
|
||||
"""Return the replayable thread snapshot."""
|
||||
self._flush_open_text_message()
|
||||
messages = self._emitted_messages if self._emitted_messages is not None else self._synthesized_messages
|
||||
return AGUIThreadSnapshot(messages=messages, state=self.state, interrupt=self.interrupt)
|
||||
|
||||
def _observe_text_start(self, event: TextMessageStartEvent) -> None:
|
||||
if self._open_text_message is not None and self._open_text_message.get("id") != event.message_id:
|
||||
self._flush_open_text_message()
|
||||
self._open_text_message = {"id": event.message_id, "role": event.role, "content": ""}
|
||||
|
||||
def _observe_text_content(self, event: TextMessageContentEvent) -> None:
|
||||
if self._open_text_message is None or self._open_text_message.get("id") != event.message_id:
|
||||
self._open_text_message = {"id": event.message_id, "role": "assistant", "content": ""}
|
||||
self._open_text_message["content"] = f"{self._open_text_message.get('content', '')}{event.delta}"
|
||||
|
||||
def _observe_text_end(self, event: TextMessageEndEvent) -> None:
|
||||
if self._open_text_message is None or self._open_text_message.get("id") != event.message_id:
|
||||
return
|
||||
self._flush_open_text_message()
|
||||
|
||||
def _observe_tool_call_start(self, event: ToolCallStartEvent) -> None:
|
||||
parent_message_id = event.parent_message_id
|
||||
if (
|
||||
self._open_text_message is not None
|
||||
and parent_message_id is not None
|
||||
and self._open_text_message.get("id") == parent_message_id
|
||||
and self._open_text_message.get("content")
|
||||
):
|
||||
self._open_text_message["id"] = generate_event_id()
|
||||
self._flush_open_text_message()
|
||||
if self._tool_call_message is None or (
|
||||
parent_message_id is not None and self._tool_call_message.get("id") != parent_message_id
|
||||
):
|
||||
self._tool_call_message = {
|
||||
"id": parent_message_id or generate_event_id(),
|
||||
"role": "assistant",
|
||||
"tool_calls": [],
|
||||
}
|
||||
self._synthesized_messages.append(self._tool_call_message)
|
||||
|
||||
tool_call = {
|
||||
"id": event.tool_call_id,
|
||||
"type": "function",
|
||||
"function": {"name": event.tool_call_name, "arguments": ""},
|
||||
}
|
||||
cast(list[dict[str, Any]], self._tool_call_message["tool_calls"]).append(tool_call)
|
||||
self._tool_calls_by_id[event.tool_call_id] = tool_call
|
||||
|
||||
def _observe_tool_call_args(self, event: ToolCallArgsEvent) -> None:
|
||||
tool_call = self._tool_calls_by_id.get(event.tool_call_id)
|
||||
if tool_call is None:
|
||||
return
|
||||
function_payload = cast(dict[str, Any], tool_call["function"])
|
||||
function_payload["arguments"] = f"{function_payload.get('arguments', '')}{event.delta}"
|
||||
|
||||
def _observe_tool_call_result(self, event: ToolCallResultEvent) -> None:
|
||||
self._synthesized_messages.append(
|
||||
{
|
||||
"id": event.message_id,
|
||||
"role": "tool",
|
||||
"toolCallId": event.tool_call_id,
|
||||
"content": event.content,
|
||||
}
|
||||
)
|
||||
# A result closes the current tool-call group; later tool calls start a new
|
||||
# assistant message so replayed transcripts keep results adjacent to their
|
||||
# tool_calls message, which provider APIs require.
|
||||
self._tool_call_message = None
|
||||
|
||||
def _flush_open_text_message(self) -> None:
|
||||
if self._open_text_message is None:
|
||||
return
|
||||
if self._open_text_message.get("content"):
|
||||
self._synthesized_messages.append(self._open_text_message)
|
||||
# Text between tool calls closes the current tool-call group as well.
|
||||
self._tool_call_message = None
|
||||
self._open_text_message = None
|
||||
|
||||
|
||||
async def _hydrate_workflow_thread_snapshot(
|
||||
*,
|
||||
snapshot_store: AGUIThreadSnapshotStore,
|
||||
scope: str,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
) -> AsyncGenerator[BaseEvent]:
|
||||
"""Replay the latest stored workflow AG-UI Thread Snapshot without invoking the workflow."""
|
||||
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
|
||||
snapshot = await snapshot_store.get(scope=scope, thread_id=thread_id)
|
||||
if snapshot is None:
|
||||
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id)
|
||||
return
|
||||
|
||||
if snapshot.state is not None:
|
||||
yield StateSnapshotEvent(snapshot=snapshot.state)
|
||||
if snapshot.messages:
|
||||
yield MessagesSnapshotEvent(messages=snapshot.messages) # type: ignore[arg-type]
|
||||
yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=snapshot.interrupt)
|
||||
|
||||
|
||||
class AgentFrameworkWorkflow:
|
||||
"""Base AG-UI workflow wrapper.
|
||||
|
||||
@@ -29,15 +214,30 @@ class AgentFrameworkWorkflow:
|
||||
workflow_factory: WorkflowFactory | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
snapshot_store: AGUIThreadSnapshotStore | None = None,
|
||||
) -> None:
|
||||
"""Initialize the AG-UI workflow wrapper.
|
||||
|
||||
Args:
|
||||
workflow: Optional workflow instance to expose.
|
||||
workflow_factory: Optional factory for thread-scoped workflow instances.
|
||||
name: Optional workflow name.
|
||||
description: Optional workflow description.
|
||||
snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence remains inactive unless
|
||||
endpoint setup also provides an explicit Snapshot Scope resolver.
|
||||
"""
|
||||
if workflow is not None and workflow_factory is not None:
|
||||
raise ValueError("Pass either workflow= or workflow_factory=, not both.")
|
||||
|
||||
self.workflow = workflow
|
||||
self._workflow_factory = workflow_factory
|
||||
self._workflow_by_thread: dict[str, Workflow] = {}
|
||||
# Cache keyed by (snapshot_scope, thread_id): the Snapshot Scope is the
|
||||
# authorization boundary, so the same thread id under different scopes
|
||||
# must never share an in-memory workflow instance.
|
||||
self._workflow_by_thread: dict[tuple[str | None, str], Workflow] = {}
|
||||
self.name = name if name is not None else getattr(workflow, "name", "workflow")
|
||||
self.description = description if description is not None else getattr(workflow, "description", "")
|
||||
self.snapshot_store = snapshot_store
|
||||
|
||||
@staticmethod
|
||||
def _thread_id_from_input(input_data: dict[str, Any]) -> str:
|
||||
@@ -47,7 +247,7 @@ class AgentFrameworkWorkflow:
|
||||
return str(thread_id)
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def _resolve_workflow(self, thread_id: str) -> Workflow:
|
||||
def _resolve_workflow(self, thread_id: str, snapshot_scope: str | None = None) -> Workflow:
|
||||
"""Get the workflow instance for the current run."""
|
||||
if self.workflow is not None:
|
||||
return self.workflow
|
||||
@@ -55,17 +255,22 @@ class AgentFrameworkWorkflow:
|
||||
if self._workflow_factory is None:
|
||||
raise NotImplementedError("No workflow is attached. Override run or pass workflow=/workflow_factory=.")
|
||||
|
||||
workflow = self._workflow_by_thread.get(thread_id)
|
||||
cache_key = (snapshot_scope, thread_id)
|
||||
workflow = self._workflow_by_thread.get(cache_key)
|
||||
if workflow is None:
|
||||
workflow = self._workflow_factory(thread_id)
|
||||
if not isinstance(workflow, Workflow):
|
||||
raise TypeError("workflow_factory must return a Workflow instance.")
|
||||
self._workflow_by_thread[thread_id] = workflow
|
||||
self._workflow_by_thread[cache_key] = workflow
|
||||
return workflow
|
||||
|
||||
def clear_thread_workflow(self, thread_id: str) -> None:
|
||||
"""Drop a single cached thread workflow instance."""
|
||||
self._workflow_by_thread.pop(thread_id, None)
|
||||
def clear_thread_workflow(self, thread_id: str, snapshot_scope: str | None = None) -> None:
|
||||
"""Drop cached workflow instances for a thread, optionally limited to one Snapshot Scope."""
|
||||
if snapshot_scope is not None:
|
||||
self._workflow_by_thread.pop((snapshot_scope, thread_id), None)
|
||||
return
|
||||
for key in [key for key in self._workflow_by_thread if key[1] == thread_id]:
|
||||
del self._workflow_by_thread[key]
|
||||
|
||||
def clear_workflow_cache(self) -> None:
|
||||
"""Drop all cached thread workflow instances."""
|
||||
@@ -77,6 +282,96 @@ class AgentFrameworkWorkflow:
|
||||
Subclasses may override this to provide custom AG-UI streams.
|
||||
"""
|
||||
thread_id = self._thread_id_from_input(input_data)
|
||||
workflow = self._resolve_workflow(thread_id)
|
||||
run_id = str(input_data.get("run_id") or input_data.get("runId") or uuid.uuid4())
|
||||
snapshot_scope = cast(str | None, input_data.get(_SNAPSHOT_SCOPE_INPUT_KEY))
|
||||
raw_messages = list(cast(list[dict[str, Any]], input_data.get("messages", []) or []))
|
||||
resume_payload = _extract_resume_payload(input_data)
|
||||
snapshot_store = self.snapshot_store
|
||||
|
||||
if snapshot_store is not None and snapshot_scope is not None and not raw_messages and resume_payload is None:
|
||||
async for event in _hydrate_workflow_thread_snapshot(
|
||||
snapshot_store=snapshot_store,
|
||||
scope=snapshot_scope,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
|
||||
# Load the stored snapshot for follow-up turns so the workflow runs with the
|
||||
# full persisted thread history instead of just the latest request messages.
|
||||
stored_snapshot: AGUIThreadSnapshot | None = None
|
||||
if snapshot_store is not None and snapshot_scope is not None:
|
||||
stored_snapshot = await snapshot_store.get(scope=snapshot_scope, thread_id=thread_id)
|
||||
if stored_snapshot is not None and resume_payload is None:
|
||||
raw_messages = _reconstruct_messages_from_thread_snapshot(
|
||||
stored_messages=stored_snapshot.messages,
|
||||
incoming_messages=raw_messages,
|
||||
stored_interrupt=stored_snapshot.interrupt,
|
||||
)
|
||||
input_data["messages"] = raw_messages
|
||||
|
||||
# Merge stored state with request overrides, then fill endpoint-deferred
|
||||
# defaults only for keys missing from both.
|
||||
request_state = input_data.get("state")
|
||||
deferred_default_state = cast(dict[str, Any] | None, input_data.get(_DEFAULT_STATE_INPUT_KEY))
|
||||
effective_state: dict[str, Any] = {}
|
||||
if stored_snapshot is not None and stored_snapshot.state is not None:
|
||||
effective_state.update(stored_snapshot.state)
|
||||
if isinstance(request_state, dict):
|
||||
effective_state.update(cast(dict[str, Any], request_state))
|
||||
if deferred_default_state:
|
||||
for key, value in deferred_default_state.items():
|
||||
if key not in effective_state:
|
||||
effective_state[key] = copy.deepcopy(value)
|
||||
if effective_state:
|
||||
input_data["state"] = effective_state
|
||||
|
||||
workflow = self._resolve_workflow(thread_id, snapshot_scope)
|
||||
builder_seed_messages = raw_messages
|
||||
if resume_payload is not None and stored_snapshot is not None:
|
||||
# Resume requests carry only the synthesized interrupt response, so seed
|
||||
# the builder with stored history to avoid persisting a truncated thread.
|
||||
builder_seed_messages = [
|
||||
copy.deepcopy(message) for message in stored_snapshot.messages
|
||||
] + builder_seed_messages
|
||||
snapshot_builder = (
|
||||
_WorkflowSnapshotBuilder(builder_seed_messages)
|
||||
if snapshot_store is not None and snapshot_scope is not None
|
||||
else None
|
||||
)
|
||||
if snapshot_builder is not None and effective_state:
|
||||
# Seed builder state so a run that emits no StateSnapshotEvent still
|
||||
# persists the latest known Shared State instead of dropping it.
|
||||
state_snapshot = make_json_safe(effective_state)
|
||||
if isinstance(state_snapshot, dict):
|
||||
snapshot_builder.state = cast(dict[str, Any], state_snapshot)
|
||||
run_error_emitted = False
|
||||
async for event in run_workflow_stream(input_data, workflow):
|
||||
if snapshot_builder is not None:
|
||||
snapshot_builder.observe(event)
|
||||
if isinstance(event, RunErrorEvent):
|
||||
run_error_emitted = True
|
||||
yield event
|
||||
|
||||
if (
|
||||
snapshot_builder is not None
|
||||
and not run_error_emitted
|
||||
and snapshot_store is not None
|
||||
and snapshot_scope is not None
|
||||
):
|
||||
try:
|
||||
await snapshot_store.save(
|
||||
scope=snapshot_scope,
|
||||
thread_id=thread_id,
|
||||
snapshot=snapshot_builder.build(),
|
||||
)
|
||||
except Exception:
|
||||
# RUN_FINISHED has already been yielded; a store failure must not
|
||||
# surface as a second terminal RUN_ERROR event. The previous
|
||||
# snapshot stays available for hydration.
|
||||
logger.exception(
|
||||
"Failed to save AG-UI Thread Snapshot for scope=%s thread_id=%s; keeping previous snapshot.",
|
||||
snapshot_scope,
|
||||
thread_id,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -32,6 +32,21 @@ def test_agent_framework_ag_ui_exports_state_update() -> None:
|
||||
assert callable(state_update)
|
||||
|
||||
|
||||
def test_agent_framework_ag_ui_exports_snapshot_primitives() -> None:
|
||||
"""Runtime package should export AG-UI Thread Snapshot primitives."""
|
||||
from agent_framework_ag_ui import (
|
||||
DEFAULT_MAX_THREAD_SNAPSHOTS,
|
||||
AGUIThreadSnapshot,
|
||||
AGUIThreadSnapshotStore,
|
||||
InMemoryAGUIThreadSnapshotStore,
|
||||
)
|
||||
|
||||
assert AGUIThreadSnapshot.__name__ == "AGUIThreadSnapshot"
|
||||
assert AGUIThreadSnapshotStore.__name__ == "AGUIThreadSnapshotStore"
|
||||
assert InMemoryAGUIThreadSnapshotStore.__name__ == "InMemoryAGUIThreadSnapshotStore"
|
||||
assert DEFAULT_MAX_THREAD_SNAPSHOTS >= 1
|
||||
|
||||
|
||||
def test_core_ag_ui_lazy_exports_include_event_converter_and_http_service() -> None:
|
||||
"""Core facade must expose AGUIEventConverter, AGUIHttpService, and __version__."""
|
||||
from agent_framework import ag_ui
|
||||
@@ -39,3 +54,13 @@ def test_core_ag_ui_lazy_exports_include_event_converter_and_http_service() -> N
|
||||
assert hasattr(ag_ui, "AGUIEventConverter")
|
||||
assert hasattr(ag_ui, "AGUIHttpService")
|
||||
assert hasattr(ag_ui, "__version__")
|
||||
|
||||
|
||||
def test_core_ag_ui_lazy_exports_include_snapshot_primitives() -> None:
|
||||
"""Core facade must expose snapshot primitives needed for endpoint configuration."""
|
||||
from agent_framework import ag_ui
|
||||
|
||||
assert hasattr(ag_ui, "AGUIThreadSnapshot")
|
||||
assert hasattr(ag_ui, "AGUIThreadSnapshotStore")
|
||||
assert hasattr(ag_ui, "InMemoryAGUIThreadSnapshotStore")
|
||||
assert hasattr(ag_ui, "SnapshotScopeResolver")
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for AG-UI thread snapshot storage primitives."""
|
||||
|
||||
from dataclasses import fields
|
||||
|
||||
from agent_framework_ag_ui import AGUIThreadSnapshot, AGUIThreadSnapshotStore, InMemoryAGUIThreadSnapshotStore
|
||||
|
||||
|
||||
def test_thread_snapshot_model_contains_only_replayable_snapshot_fields() -> None:
|
||||
"""The public snapshot model is limited to messages, Shared State, and interruption state."""
|
||||
assert [field.name for field in fields(AGUIThreadSnapshot)] == ["messages", "state", "interrupt"]
|
||||
|
||||
|
||||
def test_in_memory_snapshot_store_satisfies_snapshot_store_protocol() -> None:
|
||||
"""The built-in store conforms to the public async store protocol."""
|
||||
assert isinstance(InMemoryAGUIThreadSnapshotStore(), AGUIThreadSnapshotStore)
|
||||
|
||||
|
||||
async def test_in_memory_snapshot_store_replaces_latest_snapshot() -> None:
|
||||
"""Saving the same scoped thread key replaces the previous snapshot."""
|
||||
store = InMemoryAGUIThreadSnapshotStore()
|
||||
|
||||
await store.save(
|
||||
scope="tenant-a",
|
||||
thread_id="thread-1",
|
||||
snapshot=AGUIThreadSnapshot(messages=[{"id": "first"}], state={"count": 1}),
|
||||
)
|
||||
await store.save(
|
||||
scope="tenant-a",
|
||||
thread_id="thread-1",
|
||||
snapshot=AGUIThreadSnapshot(messages=[{"id": "second"}], state={"count": 2}),
|
||||
)
|
||||
|
||||
snapshot = await store.get(scope="tenant-a", thread_id="thread-1")
|
||||
|
||||
assert snapshot is not None
|
||||
assert snapshot.messages == [{"id": "second"}]
|
||||
assert snapshot.state == {"count": 2}
|
||||
|
||||
|
||||
async def test_in_memory_snapshot_store_keeps_scopes_separate() -> None:
|
||||
"""The same AG-UI Thread id in different Snapshot Scopes addresses different snapshots."""
|
||||
store = InMemoryAGUIThreadSnapshotStore()
|
||||
|
||||
await store.save(
|
||||
scope="tenant-a",
|
||||
thread_id="thread-1",
|
||||
snapshot=AGUIThreadSnapshot(messages=[{"id": "a", "role": "user", "content": "from a"}]),
|
||||
)
|
||||
await store.save(
|
||||
scope="tenant-b",
|
||||
thread_id="thread-1",
|
||||
snapshot=AGUIThreadSnapshot(messages=[{"id": "b", "role": "user", "content": "from b"}]),
|
||||
)
|
||||
|
||||
tenant_a_snapshot = await store.get(scope="tenant-a", thread_id="thread-1")
|
||||
tenant_b_snapshot = await store.get(scope="tenant-b", thread_id="thread-1")
|
||||
|
||||
assert tenant_a_snapshot is not None
|
||||
assert tenant_b_snapshot is not None
|
||||
assert tenant_a_snapshot.messages == [{"id": "a", "role": "user", "content": "from a"}]
|
||||
assert tenant_b_snapshot.messages == [{"id": "b", "role": "user", "content": "from b"}]
|
||||
|
||||
|
||||
async def test_in_memory_snapshot_store_deletes_and_clears_snapshots() -> None:
|
||||
"""Delete removes one scoped thread key, while clear can remove a scope or the whole store."""
|
||||
store = InMemoryAGUIThreadSnapshotStore()
|
||||
|
||||
await store.save(scope="tenant-a", thread_id="thread-1", snapshot=AGUIThreadSnapshot(messages=[{"id": "a1"}]))
|
||||
await store.save(scope="tenant-a", thread_id="thread-2", snapshot=AGUIThreadSnapshot(messages=[{"id": "a2"}]))
|
||||
await store.save(scope="tenant-b", thread_id="thread-1", snapshot=AGUIThreadSnapshot(messages=[{"id": "b1"}]))
|
||||
|
||||
assert await store.delete(scope="tenant-a", thread_id="thread-1") is True
|
||||
assert await store.delete(scope="tenant-a", thread_id="thread-1") is False
|
||||
assert await store.get(scope="tenant-a", thread_id="thread-1") is None
|
||||
assert await store.get(scope="tenant-a", thread_id="thread-2") is not None
|
||||
|
||||
await store.clear(scope="tenant-a")
|
||||
|
||||
assert await store.get(scope="tenant-a", thread_id="thread-2") is None
|
||||
assert await store.get(scope="tenant-b", thread_id="thread-1") is not None
|
||||
|
||||
await store.clear()
|
||||
|
||||
assert await store.get(scope="tenant-b", thread_id="thread-1") is None
|
||||
|
||||
|
||||
async def test_in_memory_snapshot_store_evicts_oldest_snapshot_when_bounded() -> None:
|
||||
"""The memory store bounds retained scoped thread snapshots."""
|
||||
store = InMemoryAGUIThreadSnapshotStore(max_snapshots=2)
|
||||
|
||||
await store.save(scope="tenant-a", thread_id="thread-1", snapshot=AGUIThreadSnapshot(messages=[{"id": "first"}]))
|
||||
await store.save(scope="tenant-a", thread_id="thread-2", snapshot=AGUIThreadSnapshot(messages=[{"id": "second"}]))
|
||||
await store.save(scope="tenant-a", thread_id="thread-3", snapshot=AGUIThreadSnapshot(messages=[{"id": "third"}]))
|
||||
|
||||
assert await store.get(scope="tenant-a", thread_id="thread-1") is None
|
||||
assert await store.get(scope="tenant-a", thread_id="thread-2") is not None
|
||||
assert await store.get(scope="tenant-a", thread_id="thread-3") is not None
|
||||
|
||||
|
||||
def test_workflow_snapshot_builder_splits_tool_call_groups() -> None:
|
||||
"""Tool calls separated by results or text synthesize provider-valid message groups."""
|
||||
from ag_ui.core import (
|
||||
TextMessageContentEvent,
|
||||
TextMessageEndEvent,
|
||||
TextMessageStartEvent,
|
||||
ToolCallArgsEvent,
|
||||
ToolCallResultEvent,
|
||||
ToolCallStartEvent,
|
||||
)
|
||||
|
||||
from agent_framework_ag_ui._workflow import _WorkflowSnapshotBuilder
|
||||
|
||||
builder = _WorkflowSnapshotBuilder([])
|
||||
builder.observe(ToolCallStartEvent(tool_call_id="call-a", tool_call_name="toolA"))
|
||||
builder.observe(ToolCallArgsEvent(tool_call_id="call-a", delta='{"x": 1}'))
|
||||
builder.observe(ToolCallResultEvent(message_id="result-a", tool_call_id="call-a", content="resA"))
|
||||
builder.observe(TextMessageStartEvent(message_id="text-1", role="assistant"))
|
||||
builder.observe(TextMessageContentEvent(message_id="text-1", delta="thinking"))
|
||||
builder.observe(TextMessageEndEvent(message_id="text-1"))
|
||||
builder.observe(ToolCallStartEvent(tool_call_id="call-b", tool_call_name="toolB"))
|
||||
builder.observe(ToolCallResultEvent(message_id="result-b", tool_call_id="call-b", content="resB"))
|
||||
|
||||
messages = builder.build().messages
|
||||
shapes = [
|
||||
(
|
||||
message.get("role"),
|
||||
[tool_call["id"] for tool_call in message.get("tool_calls", [])] or message.get("toolCallId"),
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
assert shapes == [
|
||||
("assistant", ["call-a"]),
|
||||
("tool", "call-a"),
|
||||
("assistant", None),
|
||||
("assistant", ["call-b"]),
|
||||
("tool", "call-b"),
|
||||
]
|
||||
|
||||
|
||||
async def test_in_memory_snapshot_store_rejects_invalid_keys() -> None:
|
||||
"""Key parts must be non-empty strings for every store operation."""
|
||||
import pytest
|
||||
|
||||
store = InMemoryAGUIThreadSnapshotStore()
|
||||
snapshot = AGUIThreadSnapshot()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await store.save(scope="", thread_id="thread-1", snapshot=snapshot)
|
||||
with pytest.raises(ValueError):
|
||||
await store.save(scope="tenant-a", thread_id="", snapshot=snapshot)
|
||||
with pytest.raises(TypeError):
|
||||
await store.save(scope=123, thread_id="thread-1", snapshot=snapshot) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError):
|
||||
await store.get(scope="tenant-a", thread_id="")
|
||||
with pytest.raises(TypeError):
|
||||
await store.delete(scope=None, thread_id="thread-1") # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError):
|
||||
await store.clear(scope="")
|
||||
@@ -125,7 +125,6 @@ from ._harness._todo import (
|
||||
TodoSessionStore,
|
||||
TodoStore,
|
||||
)
|
||||
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback
|
||||
from ._harness._tool_approval import (
|
||||
DEFAULT_TOOL_APPROVAL_SOURCE_ID,
|
||||
ToolApprovalMiddleware,
|
||||
@@ -135,6 +134,7 @@ from ._harness._tool_approval import (
|
||||
create_always_approve_tool_response,
|
||||
create_always_approve_tool_with_arguments_response,
|
||||
)
|
||||
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback
|
||||
from ._middleware import (
|
||||
AgentContext,
|
||||
AgentMiddleware,
|
||||
|
||||
@@ -1989,9 +1989,7 @@ def _store_already_approved_approval_requests(
|
||||
return
|
||||
|
||||
existing_groups = state.get(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY)
|
||||
pending_groups: list[Any] = (
|
||||
list(cast(Iterable[Any], existing_groups)) if isinstance(existing_groups, list) else []
|
||||
)
|
||||
pending_groups: list[Any] = list(cast(Iterable[Any], existing_groups)) if isinstance(existing_groups, list) else []
|
||||
pending_groups.append({
|
||||
"approval_request_ids": visible_ids,
|
||||
"approval_requests": [request.to_dict() for request in already_approved_requests],
|
||||
|
||||
@@ -11,6 +11,10 @@ Supported classes and functions:
|
||||
- AGUIChatClient
|
||||
- AGUIEventConverter
|
||||
- AGUIHttpService
|
||||
- AGUIThreadSnapshot
|
||||
- AGUIThreadSnapshotStore
|
||||
- InMemoryAGUIThreadSnapshotStore
|
||||
- SnapshotScopeResolver
|
||||
- add_agent_framework_fastapi_endpoint
|
||||
- state_update
|
||||
- __version__
|
||||
@@ -28,6 +32,10 @@ _IMPORTS = [
|
||||
"AGUIChatClient",
|
||||
"AGUIEventConverter",
|
||||
"AGUIHttpService",
|
||||
"AGUIThreadSnapshot",
|
||||
"AGUIThreadSnapshotStore",
|
||||
"InMemoryAGUIThreadSnapshotStore",
|
||||
"SnapshotScopeResolver",
|
||||
"state_update",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
@@ -6,6 +6,10 @@ from agent_framework_ag_ui import (
|
||||
AGUIChatClient,
|
||||
AGUIEventConverter,
|
||||
AGUIHttpService,
|
||||
AGUIThreadSnapshot,
|
||||
AGUIThreadSnapshotStore,
|
||||
InMemoryAGUIThreadSnapshotStore,
|
||||
SnapshotScopeResolver,
|
||||
__version__,
|
||||
add_agent_framework_fastapi_endpoint,
|
||||
state_update,
|
||||
@@ -15,8 +19,12 @@ __all__ = [
|
||||
"AGUIChatClient",
|
||||
"AGUIEventConverter",
|
||||
"AGUIHttpService",
|
||||
"AGUIThreadSnapshot",
|
||||
"AGUIThreadSnapshotStore",
|
||||
"AgentFrameworkAgent",
|
||||
"AgentFrameworkWorkflow",
|
||||
"InMemoryAGUIThreadSnapshotStore",
|
||||
"SnapshotScopeResolver",
|
||||
"__version__",
|
||||
"add_agent_framework_fastapi_endpoint",
|
||||
"state_update",
|
||||
|
||||
Reference in New Issue
Block a user