Files
Eduard van Valkenburg 36ce0950e4 Simplify Python hosting core (#6492)
Remove linking, multicast, durable delivery, and host push machinery from the v1 hosting core. Keep those scenarios in a proposed follow-up ADR and update channel packages, samples, docs, tests, and workspace metadata around the smaller host/channel contract.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-06-12 08:34:08 +02:00

1307 lines
52 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Tests for :class:`AgentFrameworkHost` invocation, session, and delivery routing."""
from __future__ import annotations
from collections.abc import AsyncIterator, Sequence
from dataclasses import dataclass, field
from typing import Any
import pytest
from agent_framework import AgentResponse, AgentResponseUpdate, Content, Message, ResponseStream
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute, Route
from starlette.testclient import TestClient
from agent_framework_hosting import (
AgentFrameworkHost,
Channel,
ChannelContext,
ChannelContribution,
ChannelIdentity,
ChannelRequest,
ChannelSession,
HostedRunResult,
)
async def _ping(_request: Request) -> JSONResponse:
    """Trivial endpoint handler: ignore the request and reply with ``{"ok": true}``."""
    body = {"ok": True}
    return JSONResponse(body)
# --------------------------------------------------------------------------- #
# Fakes #
# --------------------------------------------------------------------------- #
@dataclass
class _FakeAgentSession:
    """Bare stand-in for an agent session: two optional identifiers, nothing else."""

    # Populated from the ``session_id`` keyword that the fake agents in this
    # file pass through from ``create_session``.
    session_id: str | None = None
    # Not assigned anywhere in the visible part of this file; stays ``None``
    # unless a caller sets it explicitly.
    service_session_id: str | None = None
@dataclass
class _FakeAgentResponse:
    """Lightweight response fake exposing ``text`` plus a derived ``messages`` list."""

    text: str

    @property
    def messages(self) -> list[Message]:
        # A real ``AgentResponse`` exposes a list of messages, and the host's
        # ``_invoke`` forwards them on the ``HostedRunResult``. Build one
        # assistant text message from ``text`` so tests asserting on
        # ``payload.text`` continue to work without changes.
        text_content = Content.from_text(text=self.text)
        assistant_message = Message(role="assistant", contents=[text_content])
        return [assistant_message]
class _FakeAgent:
    """Recording fake of a :class:`SupportsAgentRun` target.

    Every ``run`` call is appended to ``calls`` and every session handed out
    by ``create_session`` is appended to ``created_sessions`` so tests can
    inspect exactly what the host did.
    """

    def __init__(self, reply: str = "ok") -> None:
        self._reply = reply
        self.calls: list[dict[str, Any]] = []
        self.created_sessions: list[_FakeAgentSession] = []

    def create_session(self, *, session_id: str | None = None) -> _FakeAgentSession:
        """Create, remember and return a new fake session."""
        created = _FakeAgentSession(session_id=session_id)
        self.created_sessions.append(created)
        return created

    def run(self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any) -> Any:
        """Record the call, then return a coroutine or a ``ResponseStream``.

        With ``stream=False`` the result is an awaitable resolving to a
        ``_FakeAgentResponse``; with ``stream=True`` it is a ``ResponseStream``
        yielding a single update that carries the configured reply.
        """
        call_record = {"messages": messages, "stream": stream, "session": session, "kwargs": kwargs}
        self.calls.append(call_record)

        if not stream:

            async def _respond() -> _FakeAgentResponse:
                return _FakeAgentResponse(text=self._reply)

            return _respond()

        # The update list is built eagerly, at call time, not at iteration time.
        reply_content = Content.from_text(text=self._reply)
        pending = [AgentResponseUpdate(contents=[reply_content], role="assistant")]

        async def _emit() -> AsyncIterator[AgentResponseUpdate]:
            for item in pending:
                yield item

        async def _collect(items: Sequence[AgentResponseUpdate]) -> AgentResponse:  # noqa: RUF029
            return AgentResponse.from_updates(items)

        return ResponseStream[AgentResponseUpdate, AgentResponse](_emit(), finalizer=_collect)
class _RecordingChannel:
    """Small :class:`Channel` fake that remembers the context it was given."""

    def __init__(self, name: str = "fake", path: str = "/fake") -> None:
        self.name = name
        self.path = path
        self.context: ChannelContext | None = None
        # One throwaway route is enough for contribute() to exercise the endpoint path.
        ping_route = Route("/ping", _ping)
        self._routes: Sequence[BaseRoute] = (ping_route,)

    def contribute(self, context: ChannelContext) -> ChannelContribution:
        """Store ``context`` and contribute the configured routes."""
        self.context = context
        contribution = ChannelContribution(routes=self._routes)
        return contribution
def _assistant_response(text: str) -> AgentResponse:
    """Return an ``AgentResponse`` holding one assistant text message, for use as ``HostedRunResult.result``."""
    text_content = Content.from_text(text=text)
    assistant_message = Message(role="assistant", contents=[text_content])
    return AgentResponse(messages=[assistant_message])
def _make_reply(text: str = "reply") -> HostedRunResult[AgentResponse]:
    """Wrap a single assistant text message in a ``HostedRunResult[AgentResponse]``.

    Convenience for tests: it mirrors the envelope the host's ``_invoke``
    produces for an agent target, so channels (and the delivery tests) get a
    typed envelope whose ``result`` is a real :class:`AgentResponse`.
    """
    response = _assistant_response(text)
    return HostedRunResult(response)
@dataclass
class _LifecycleChannel:
name: str = "lifecycle"
path: str = ""
started: list[str] = field(default_factory=list)
stopped: list[str] = field(default_factory=list)
def contribute(self, context: ChannelContext) -> ChannelContribution:
async def on_start() -> None:
self.started.append("up")
async def on_stop() -> None:
self.stopped.append("down")
return ChannelContribution(on_startup=[on_start], on_shutdown=[on_stop])
# --------------------------------------------------------------------------- #
# Host wiring #
# --------------------------------------------------------------------------- #
class TestHostWiring:
    """How the host assembles its ASGI app from channel contributions."""

    def test_channel_is_recognized(self) -> None:
        # The structural fake has to pass the ``Channel`` isinstance check.
        channel = _RecordingChannel()
        assert isinstance(channel, Channel)

    def test_app_mounts_channel_routes_under_path(self) -> None:
        fake_agent = _FakeAgent()
        channel = _RecordingChannel(path="/fake")
        host = AgentFrameworkHost(target=fake_agent, channels=[channel])
        with TestClient(host.app) as http:
            response = http.get("/fake/ping")
            assert response.status_code == 200
            assert response.json() == {"ok": True}

    def test_app_mounts_root_route_at_exact_channel_path(self) -> None:
        fake_agent = _FakeAgent()
        channel = _RecordingChannel(path="/fake")
        channel._routes = (Route("/", _ping),)
        host = AgentFrameworkHost(target=fake_agent, channels=[channel])
        # Redirect following is off, so each spelling must answer 200 by itself.
        with TestClient(host.app, follow_redirects=False) as http:
            bare = http.get("/fake")
            assert bare.status_code == 200
            assert bare.json() == {"ok": True}
            slashed = http.get("/fake/")
            assert slashed.status_code == 200

    def test_app_mounts_at_root_when_path_is_empty(self) -> None:
        fake_agent = _FakeAgent()
        channel = _RecordingChannel(path="")
        host = AgentFrameworkHost(target=fake_agent, channels=[channel])
        with TestClient(host.app) as http:
            response = http.get("/ping")
            assert response.status_code == 200

    def test_app_is_cached(self) -> None:
        host = AgentFrameworkHost(target=_FakeAgent(), channels=[_RecordingChannel()])
        first_app = host.app
        second_app = host.app
        assert first_app is second_app

    def test_lifespan_invokes_startup_and_shutdown(self) -> None:
        fake_agent = _FakeAgent()
        channel = _LifecycleChannel()
        host = AgentFrameworkHost(target=fake_agent, channels=[channel])
        with TestClient(host.app):
            # Inside the client context the startup hook has already run.
            assert channel.started == ["up"]
        # Leaving the context runs the shutdown hook.
        assert channel.stopped == ["down"]

    def test_app_exposes_readiness_probe(self) -> None:
        host = AgentFrameworkHost(target=_FakeAgent(), channels=[_RecordingChannel()])
        with TestClient(host.app) as http:
            probe = http.get("/readiness")
            assert probe.status_code == 200
            assert probe.text == "ok"
# --------------------------------------------------------------------------- #
# Invoke + sessions #
# --------------------------------------------------------------------------- #
class TestHostInvoke:
    """Non-streaming ``context.run``: metadata wrapping, session handling, option forwarding."""

    @staticmethod
    def _wire(agent: _FakeAgent, channel: _RecordingChannel) -> tuple[AgentFrameworkHost, ChannelContext]:
        """Build a host around ``agent``/``channel`` and return it with the channel's context."""
        host = AgentFrameworkHost(target=agent, channels=[channel])
        # Reading ``app`` forces the build, which is what triggers ``contribute``.
        _ = host.app
        context = channel.context
        assert context is not None
        return host, context

    async def test_invoke_wraps_input_with_hosting_metadata(self) -> None:
        agent = _FakeAgent(reply="hello")
        _host, context = self._wire(agent, _RecordingChannel(name="responses"))
        request = ChannelRequest(
            channel="responses",
            operation="message.create",
            input="hi",
            session=ChannelSession(isolation_key="user:1"),
            identity=ChannelIdentity(channel="responses", native_id="user:1"),
        )

        outcome = await context.run(request)

        assert outcome.result.text == "hello"
        assert len(agent.calls) == 1
        sent = agent.calls[0]["messages"]
        assert sent.role == "user"
        hosting = sent.additional_properties["hosting"]
        assert hosting["channel"] == "responses"
        assert hosting["identity"] == {
            "channel": "responses",
            "native_id": "user:1",
            "attributes": {},
        }

    async def test_invoke_caches_session_per_isolation_key(self) -> None:
        agent = _FakeAgent()
        channel = _RecordingChannel()
        _host, context = self._wire(agent, channel)
        alice_first = ChannelRequest(
            channel=channel.name, operation="op", input="1", session=ChannelSession(isolation_key="alice")
        )
        alice_second = ChannelRequest(
            channel=channel.name, operation="op", input="2", session=ChannelSession(isolation_key="alice")
        )
        bob_only = ChannelRequest(
            channel=channel.name, operation="op", input="3", session=ChannelSession(isolation_key="bob")
        )

        await context.run(alice_first)
        await context.run(alice_second)
        await context.run(bob_only)

        # Exactly two sessions exist (alice, bob); neither was created twice.
        assert len(agent.created_sessions) == 2
        sessions_used = [call["session"] for call in agent.calls]
        assert sessions_used[0] is sessions_used[1]
        assert sessions_used[0] is not sessions_used[2]

    async def test_session_disabled_does_not_create_session(self) -> None:
        agent = _FakeAgent()
        channel = _RecordingChannel()
        _host, context = self._wire(agent, channel)
        request = ChannelRequest(
            channel=channel.name,
            operation="op",
            input="x",
            session=ChannelSession(isolation_key="alice"),
            session_mode="disabled",
        )

        await context.run(request)

        assert agent.created_sessions == []
        assert agent.calls[0]["session"] is None

    async def test_reset_session_rotates_id_and_drops_cache(self) -> None:
        agent = _FakeAgent()
        channel = _RecordingChannel()
        host, context = self._wire(agent, channel)
        request = ChannelRequest(
            channel=channel.name, operation="op", input="x", session=ChannelSession(isolation_key="alice")
        )

        await context.run(request)
        before_reset = agent.calls[-1]["session"]
        assert before_reset.session_id == "alice"

        host.reset_session("alice")
        await context.run(request)
        after_reset = agent.calls[-1]["session"]

        # A different session object with a rotated id (alias rotation).
        assert after_reset is not before_reset
        assert after_reset.session_id != "alice"
        assert after_reset.session_id.startswith("alice#")

    async def test_options_propagates_to_target_run(self) -> None:
        agent = _FakeAgent()
        channel = _RecordingChannel()
        _host, context = self._wire(agent, channel)
        request = ChannelRequest(
            channel=channel.name,
            operation="op",
            input="x",
            session=ChannelSession(isolation_key="alice"),
            options={"temperature": 0.4},
        )

        await context.run(request)

        forwarded = agent.calls[0]["kwargs"]
        assert forwarded["options"] == {"temperature": 0.4}
class TestHostOwnedHooks:
    """Hooks passed to ``context.run`` / ``context.run_stream`` are applied by the host."""

    async def test_context_run_applies_run_hook_before_invocation(self) -> None:
        agent = _FakeAgent()
        channel = _RecordingChannel()
        host = AgentFrameworkHost(target=agent, channels=[channel])
        _ = host.app
        context = channel.context
        assert context is not None
        seen: dict[str, Any] = {}

        async def rewrite(request: ChannelRequest, **kwargs: Any) -> ChannelRequest:
            # Capture what the host passes along, then swap the input text.
            seen["target"] = kwargs["target"]
            seen["protocol_request"] = kwargs["protocol_request"]
            return ChannelRequest(
                channel=request.channel,
                operation=request.operation,
                input="rewritten",
                session=request.session,
            )

        original = ChannelRequest(
            channel=channel.name, operation="op", input="original", session=ChannelSession("alice")
        )
        await context.run(original, run_hook=rewrite, protocol_request={"raw": True})

        assert seen["target"] is agent
        assert seen["protocol_request"] == {"raw": True}
        # The agent received the hook's output, not the original input.
        assert agent.calls[0]["messages"].text == "rewritten"

    async def test_context_run_stream_applies_run_hook_before_opening_stream(self) -> None:
        agent = _FakeAgent()
        channel = _RecordingChannel()
        host = AgentFrameworkHost(target=agent, channels=[channel])
        _ = host.app
        context = channel.context
        assert context is not None

        def rewrite(request: ChannelRequest, **_ignored: Any) -> ChannelRequest:
            return ChannelRequest(channel=request.channel, operation=request.operation, input="streamed")

        def shout(update: Any) -> AgentResponseUpdate:
            # Upper-case each update's text so the transformation is observable.
            return AgentResponseUpdate(
                contents=[Content.from_text(text=update.text.upper())],
                role="assistant",
            )

        stream = await context.run_stream(
            ChannelRequest(channel=channel.name, operation="op", input="original"),
            run_hook=rewrite,
            stream_update_hook=shout,
        )
        received = [update.text async for update in stream]

        assert received == ["OK"]
        assert agent.calls[0]["messages"].text == "streamed"
# --------------------------------------------------------------------------- #
# Workflow target #
# --------------------------------------------------------------------------- #
class TestHostWorkflowTarget:
    """A ``Workflow`` target is accepted and dispatched through ``workflow.run(...)``."""

    async def test_invoke_workflow_collapses_outputs_to_hosted_run_result(self) -> None:
        from tests._workflow_fixtures import build_upper_workflow

        channel = _RecordingChannel()
        host = AgentFrameworkHost(target=build_upper_workflow(), channels=[channel])
        _ = host.app
        context = channel.context
        assert context is not None

        # A channel's run_hook is the canonical adapter from free-form input to
        # the workflow's typed input; the start executor here already takes
        # ``str``, so ``input`` is forwarded verbatim.
        request = ChannelRequest(channel="fake", operation="message.create", input="hello")
        outcome = await context.run(request)

        outputs = list(outcome.result.get_outputs())
        assert outputs == ["HELLO"]
        # Workflow targets get no session caching: ``Workflow`` has no
        # ``create_session`` and the host must not invent one.
        assert host._sessions == {}

    async def test_stream_workflow_yields_updates_and_finalizes(self) -> None:
        from tests._workflow_fixtures import build_echo_workflow

        channel = _RecordingChannel()
        host = AgentFrameworkHost(target=build_echo_workflow(), channels=[channel])
        _ = host.app
        context = channel.context
        assert context is not None

        request = ChannelRequest(channel="fake", operation="message.create", input="hi")
        stream = await context.run_stream(request)
        received: list[AgentResponseUpdate] = [update async for update in stream]

        # The echo workflow emits one ``output`` event whose payload is the
        # original string; the host wraps non-update payloads into a one-shot
        # ``AgentResponseUpdate`` carrying the text.
        texts = [update.text for update in received]
        assert texts == ["hi"]
        # ``raw_representation`` keeps the source ``WorkflowEvent`` so advanced
        # consumers (telemetry, debug UIs) can recover the full timeline.
        assert all(update.raw_representation is not None for update in received)
        final = await stream.get_final_response()
        assert final.text == "hi"

    async def test_stream_workflow_yields_one_update_per_output_event(self) -> None:
        from tests._workflow_fixtures import build_multi_chunk_workflow

        channel = _RecordingChannel()
        host = AgentFrameworkHost(target=build_multi_chunk_workflow(), channels=[channel])
        _ = host.app
        context = channel.context
        assert context is not None

        request = ChannelRequest(channel="fake", operation="message.create", input="x")
        stream = await context.run_stream(request)

        texts: list[str] = []
        async for update in stream:
            texts.append(update.text)
            # The originating ``executor_id`` travels in ``author_name`` so
            # multi-agent workflows can route per-author rendering downstream.
            assert update.author_name == "multi"

        assert texts == ["x-1", "x-2", "x-3"]
        final = await stream.get_final_response()
        assert final.text == "x-1x-2x-3"
class TestHostWorkflowCheckpointing:
    """With ``checkpoint_location`` set, the host scopes checkpoints per conversation."""

    def test_rejects_workflow_with_existing_checkpoint_storage(self, tmp_path: Any) -> None:
        from agent_framework import InMemoryCheckpointStorage, WorkflowBuilder
        from tests._workflow_fixtures import _UpperExecutor

        builder = WorkflowBuilder(
            start_executor=_UpperExecutor(id="upper"),
            checkpoint_storage=InMemoryCheckpointStorage(),
        )
        workflow = builder.build()

        with pytest.raises(RuntimeError, match="already has checkpoint storage"):
            AgentFrameworkHost(
                target=workflow,
                channels=[_RecordingChannel()],
                checkpoint_location=tmp_path,
            )

    def test_warns_when_target_is_agent(self, tmp_path: Any, caplog: Any) -> None:
        import logging as _logging

        agent = _FakeAgent()
        with caplog.at_level(_logging.WARNING, logger="agent_framework.hosting"):
            host = AgentFrameworkHost(target=agent, channels=[_RecordingChannel()], checkpoint_location=tmp_path)

        assert host._checkpoint_location is None
        logged = [record.message for record in caplog.records]
        assert any("checkpoint_location" in message for message in logged)

    async def test_invoke_skips_checkpointing_when_no_isolation_key(self, tmp_path: Any) -> None:
        from tests._workflow_fixtures import build_upper_workflow

        channel = _RecordingChannel()
        host = AgentFrameworkHost(target=build_upper_workflow(), channels=[channel], checkpoint_location=tmp_path)
        _ = host.app
        context = channel.context
        assert context is not None

        # No session means no scoping key, so no checkpoint storage is written.
        request = ChannelRequest(channel="fake", operation="message.create", input="hi")
        outcome = await context.run(request)

        assert list(outcome.result.get_outputs()) == ["HI"]
        assert list(tmp_path.iterdir()) == []

    async def test_invoke_writes_checkpoint_under_isolation_key(self, tmp_path: Any) -> None:
        from tests._workflow_fixtures import build_upper_workflow

        channel = _RecordingChannel()
        host = AgentFrameworkHost(target=build_upper_workflow(), channels=[channel], checkpoint_location=tmp_path)
        _ = host.app
        context = channel.context
        assert context is not None

        request = ChannelRequest(
            channel="fake",
            operation="message.create",
            input="hi",
            session=ChannelSession(isolation_key="alice"),
        )
        outcome = await context.run(request)
        assert list(outcome.result.get_outputs()) == ["HI"]

        # FileCheckpointStorage rooted at <tmp_path>/<isolation_key> should
        # have written at least one checkpoint file scoped to that user.
        user_dir = tmp_path / "alice"
        assert user_dir.exists()
        assert any(user_dir.iterdir()), "expected at least one checkpoint to be written under the per-user dir"

    async def test_stream_writes_checkpoint_under_isolation_key(self, tmp_path: Any) -> None:
        from tests._workflow_fixtures import build_echo_workflow

        channel = _RecordingChannel()
        host = AgentFrameworkHost(target=build_echo_workflow(), channels=[channel], checkpoint_location=tmp_path)
        _ = host.app
        context = channel.context
        assert context is not None

        request = ChannelRequest(
            channel="fake",
            operation="message.create",
            input="hi",
            session=ChannelSession(isolation_key="bob"),
        )
        stream = await context.run_stream(request)
        # Drain the stream; the individual updates are irrelevant here.
        async for _ in stream:
            pass
        await stream.get_final_response()

        user_dir = tmp_path / "bob"
        assert user_dir.exists()
        assert any(user_dir.iterdir())

    async def test_caller_supplied_checkpoint_storage_used_as_is(self, tmp_path: Any) -> None:
        from agent_framework import InMemoryCheckpointStorage
        from tests._workflow_fixtures import build_upper_workflow

        storage = InMemoryCheckpointStorage()
        workflow = build_upper_workflow()
        channel = _RecordingChannel()
        host = AgentFrameworkHost(target=workflow, channels=[channel], checkpoint_location=storage)
        _ = host.app
        context = channel.context
        assert context is not None
        assert host._checkpoint_location is storage

        request = ChannelRequest(
            channel="fake",
            operation="message.create",
            input="hi",
            session=ChannelSession(isolation_key="carol"),
        )
        await context.run(request)

        # The caller-owned storage is used directly (the host applies no
        # per-user scoping), so a checkpoint should show up in it.
        saved = await storage.list_checkpoints(workflow_name=workflow.name)
        assert saved, "expected the caller-supplied storage to receive a checkpoint"
        # Nothing may have been written into the tmp_path tree.
        assert list(tmp_path.iterdir()) == []
class TestCheckpointPathForIsolationKey:
    """Isolation keys joined into checkpoint paths must be hardened against path traversal."""

    # Keys that must map to ``<root>/<key>``.
    _LEGITIMATE_KEYS = [
        "alice",
        "telegram:42",
        "entra:abc-def_0123",
        "responses:user.name",
        "x" * 200,
    ]

    # Keys that must be refused with a ``ValueError``.
    _TRAVERSAL_KEYS = [
        "",
        ".",
        "..",
        "...",
        "../etc",
        "../../etc/passwd",
        "a/b",
        "a\\b",
        "with\x00nul",
        "/abs/path",
        "C:/foo",
        "C:foo",
    ]

    @pytest.mark.parametrize("isolation_key", _LEGITIMATE_KEYS)
    def test_accepts_legitimate_keys(self, tmp_path: Any, isolation_key: str) -> None:
        from agent_framework_hosting._host import _checkpoint_path_for_isolation_key

        resolved = _checkpoint_path_for_isolation_key(tmp_path, isolation_key)

        expected = (tmp_path / isolation_key).resolve()
        assert resolved == expected
        assert resolved.is_relative_to(tmp_path.resolve())

    @pytest.mark.parametrize("isolation_key", _TRAVERSAL_KEYS)
    def test_rejects_traversal_patterns(self, tmp_path: Any, isolation_key: str) -> None:
        from agent_framework_hosting._host import _checkpoint_path_for_isolation_key

        with pytest.raises(ValueError, match="isolation_key"):
            _checkpoint_path_for_isolation_key(tmp_path, isolation_key)

    def test_rejects_non_string(self, tmp_path: Any) -> None:
        from agent_framework_hosting._host import _checkpoint_path_for_isolation_key

        not_a_string = None
        with pytest.raises(ValueError, match="non-empty string"):
            _checkpoint_path_for_isolation_key(tmp_path, not_a_string)  # type: ignore[arg-type]
class TestHostWorkflowCheckpointingPathTraversal:
    """End-to-end check that hostile isolation keys cannot escape ``checkpoint_location``."""

    async def test_traversal_key_skips_checkpointing_with_warning(self, tmp_path: Any, caplog: Any) -> None:
        import logging as _logging

        from tests._workflow_fixtures import build_upper_workflow

        channel = _RecordingChannel()
        host = AgentFrameworkHost(target=build_upper_workflow(), channels=[channel], checkpoint_location=tmp_path)
        _ = host.app
        context = channel.context
        assert context is not None

        request = ChannelRequest(
            channel="fake",
            operation="message.create",
            input="hi",
            session=ChannelSession(isolation_key="../escape"),
        )
        with caplog.at_level(_logging.WARNING, logger="agent_framework.hosting"):
            outcome = await context.run(request)

        # The run itself still succeeds.
        assert list(outcome.result.get_outputs()) == ["HI"]
        # Nothing may have been written under tmp_path.
        assert list(tmp_path.iterdir()) == []
        logged = [record.message for record in caplog.records]
        assert any("Skipping checkpoint storage" in message and "isolation_key" in message for message in logged)

    async def test_separator_in_key_skips_checkpointing(self, tmp_path: Any) -> None:
        from tests._workflow_fixtures import build_upper_workflow

        channel = _RecordingChannel()
        host = AgentFrameworkHost(target=build_upper_workflow(), channels=[channel], checkpoint_location=tmp_path)
        _ = host.app
        context = channel.context
        assert context is not None

        # A literal separator in the key is a configuration smell at best and
        # an attack at worst; in both cases it must not create a sub-path.
        request = ChannelRequest(
            channel="fake",
            operation="message.create",
            input="hi",
            session=ChannelSession(isolation_key="evil/sub"),
        )
        outcome = await context.run(request)

        assert list(outcome.result.get_outputs()) == ["HI"]
        assert list(tmp_path.iterdir()) == []
# --------------------------------------------------------------------------- #
# HostedRunResult — generic typed envelope #
# --------------------------------------------------------------------------- #
class TestHostedRunResult:
    """``HostedRunResult`` is a thin generic envelope.

    It wraps the target's full-fidelity ``result`` together with an optional
    session reference. The host does NOT pre-shape or flatten
    ``result.messages`` / ``result.get_outputs()``; channels read the
    canonical accessor on the underlying result type themselves.
    """

    def test_result_field_carries_full_fidelity_payload(self) -> None:
        response = AgentResponse(
            messages=[Message(role="assistant", contents=[Content.from_text("hello")])],
            response_id="r-1",
        )
        envelope: HostedRunResult[AgentResponse] = HostedRunResult(response)

        # ``result`` is the canonical accessor; metadata such as
        # ``response_id`` survives untouched because the host never
        # re-shapes the payload.
        assert envelope.result is response
        assert envelope.result.text == "hello"
        assert envelope.result.response_id == "r-1"
        assert envelope.session is None

    def test_session_field_attached_and_optional(self) -> None:
        response = _assistant_response("ok")
        attached = _FakeAgentSession(session_id="sess-1")
        envelope = HostedRunResult(response, session=attached)
        assert envelope.session is attached

    def test_replace_clones_envelope_without_touching_result_by_default(self) -> None:
        response = _assistant_response("orig")
        source = HostedRunResult(response, session=_FakeAgentSession(session_id="s"))

        copy = source.replace()

        # The copy is a separate envelope, yet the inner ``result`` is the
        # very same object; channels wanting a deep copy of ``result`` make
        # that copy themselves.
        assert copy is not source
        assert copy.result is source.result
        assert copy.session is source.session

    def test_replace_rebinds_result_without_perturbing_original(self) -> None:
        source = HostedRunResult(_assistant_response("orig"))
        copy = source.replace(result=_assistant_response("shaped"))
        assert source.result.text == "orig"
        assert copy.result.text == "shaped"

    def test_replace_supports_explicit_none_session(self) -> None:
        source = HostedRunResult(_assistant_response("x"), session=_FakeAgentSession(session_id="s"))
        copy = source.replace(session=None)
        assert copy.session is None
        # The source envelope is left as it was.
        assert source.session is not None

    async def test_invoke_preserves_full_agent_response_on_result(self) -> None:
        """``_invoke`` passes the agent's response through unchanged on ``result``.

        Channels see image / tool / structured content next to text, plus
        metadata such as ``response_id``, without the host pre-shaping
        anything.
        """

        class _MultiModalResponse:
            def __init__(self) -> None:
                self.text = "summary"
                self.response_id = "resp-xyz"
                self.messages = [
                    Message(
                        role="assistant",
                        contents=[
                            Content.from_text("summary"),
                            # Non-text content that the host must NOT drop.
                            Content.from_data(data=b"\x89PNG", media_type="image/png"),
                        ],
                    ),
                ]

        class _MultiModalAgent:
            def create_session(self, *, session_id: str | None = None) -> _FakeAgentSession:
                return _FakeAgentSession(session_id=session_id)

            async def run(self, *_args: Any, **_kwargs: Any) -> Any:
                return _MultiModalResponse()

        channel = _RecordingChannel(name="responses")
        host = AgentFrameworkHost(target=_MultiModalAgent(), channels=[channel])
        _ = host.app
        context = channel.context
        assert context is not None

        request = ChannelRequest(channel="responses", operation="op", input="hi")
        envelope = await context.run(request)

        # The full agent response arrives verbatim, with no flattening.
        carried = envelope.result
        assert carried.text == "summary"
        assert carried.response_id == "resp-xyz"
        assert len(carried.messages) == 1
        content_types = [content.type for content in carried.messages[0].contents]
        assert "text" in content_types
        assert "data" in content_types
# --------------------------------------------------------------------------- #
# Bind request context — duck-typed hook on context providers #
# --------------------------------------------------------------------------- #
from contextlib import contextmanager # noqa: E402
class _RecordingContextProvider:
"""Stand-in for a ``HistoryProvider`` that exposes the duck-typed
``bind_request_context(response_id=..., previous_response_id=..., **_)``
seam the host calls. Records (event, payload) pairs so tests can
assert call ordering relative to the agent run + stream lifecycle.
"""
def __init__(self, *, name: str = "rec") -> None:
self.name = name
# (event, payload) tuples — events: "enter", "exit", "agent_start",
# "agent_end", "stream_yield", "stream_done".
self.events: list[tuple[str, Any]] = []
@contextmanager
def bind_request_context(self, **kwargs: Any) -> Any:
# Snapshot the call kwargs on enter (so tests can assert
# response_id / previous_response_id forwarding) and the same
# snapshot on exit so we can verify the SAME payload bracketed
# the agent run.
snapshot = dict(kwargs)
self.events.append(("enter", snapshot))
try:
yield
finally:
self.events.append(("exit", snapshot))
class _ProvidersAgent:
    """Agent stand-in that exposes ``context_providers`` so the host's
    ``_flat_context_providers`` finds the recording provider.

    Mirrors the real :class:`agent_framework.Agent.run` shape: a sync
    ``def`` that returns either an ``Awaitable[AgentResponse]`` (for
    ``stream=False``) or a :class:`ResponseStream` synchronously (for
    ``stream=True``). The host's ``_invoke_stream`` relies on the sync
    return so it can wrap the stream in ``_BoundResponseStream`` and
    hand it to channels for later iteration.

    Lifecycle events are appended only to providers that are
    ``_RecordingContextProvider`` instances; any other provider object is
    skipped by the ``isinstance`` checks below.
    """

    def __init__(self, providers: Sequence[Any], *, reply: str = "ok") -> None:
        # Copy into a list so the agent owns its own provider sequence.
        self.context_providers = list(providers)
        self._reply = reply
        # One dict per ``run`` call: messages, stream flag, session, extra kwargs.
        self.calls: list[dict[str, Any]] = []

    def create_session(self, *, session_id: str | None = None) -> _FakeAgentSession:
        """Return a fresh fake session carrying ``session_id``."""
        return _FakeAgentSession(session_id=session_id)

    def run(
        self,
        messages: Any = None,
        *,
        stream: bool = False,
        session: Any = None,
        **kwargs: Any,
    ) -> Any:
        """Record the call and return a coroutine or a two-chunk ``ResponseStream``."""
        self.calls.append({"messages": messages, "stream": stream, "session": session, "kwargs": kwargs})
        if stream:
            # Bind the provider list now; the generator and finalizer below
            # close over this reference and run later, during iteration.
            providers = self.context_providers
            updates = [
                AgentResponseUpdate(contents=[Content.from_text("chunk-1")], role="assistant"),
                AgentResponseUpdate(contents=[Content.from_text("chunk-2")], role="assistant"),
            ]

            async def _gen() -> AsyncIterator[AgentResponseUpdate]:
                # ``agent_start`` is only recorded once iteration begins;
                # if the channel abandons the stream without iterating
                # we expect to see neither ``agent_start`` nor any
                # ``stream_yield`` events.
                for prov in providers:
                    if isinstance(prov, _RecordingContextProvider):
                        prov.events.append(("agent_start", None))
                for u in updates:
                    # ``stream_yield`` is logged just before the chunk is
                    # handed to the consumer.
                    for prov in providers:
                        if isinstance(prov, _RecordingContextProvider):
                            prov.events.append(("stream_yield", u.text))
                    yield u

            async def _finalize(items: Sequence[AgentResponseUpdate]) -> AgentResponse:  # noqa: RUF029
                # ``stream_done`` carries the number of updates the finalizer received.
                for prov in providers:
                    if isinstance(prov, _RecordingContextProvider):
                        prov.events.append(("stream_done", len(items)))
                return AgentResponse.from_updates(items)

            return ResponseStream[AgentResponseUpdate, AgentResponse](_gen(), finalizer=_finalize)

        async def _coro() -> _FakeAgentResponse:
            # Non-streaming path: each recording provider gets ``agent_start``
            # immediately followed by ``agent_end``, once the coroutine is awaited.
            for prov in self.context_providers:
                if isinstance(prov, _RecordingContextProvider):
                    prov.events.append(("agent_start", None))
                    prov.events.append(("agent_end", None))
            return _FakeAgentResponse(text=self._reply)

        return _coro()
class _ProviderWrapper:
"""Wrap children in a ``providers`` attribute (mirrors the
``ContextProviderBase`` aggregation shape)."""
def __init__(self, providers: Sequence[Any]) -> None:
self.providers = list(providers)
class TestBindRequestContext:
"""The host walks ``target.context_providers``, descends one level
when a provider exposes a ``providers`` attribute, and calls
``bind_request_context(response_id=..., previous_response_id=...)``
on every provider that supports it. Foundry response-id chaining
plugs into this exact seam — a regression that mistypes the kwarg
name, drops the descent, or fails to keep the binding open across
the agent run silently breaks chained writes."""
async def test_bind_called_with_request_attributes(self) -> None:
prov = _RecordingContextProvider()
agent = _ProvidersAgent([prov])
ch = _RecordingChannel(name="responses")
host = AgentFrameworkHost(target=agent, channels=[ch])
_ = host.app
assert ch.context is not None
req = ChannelRequest(
channel="responses",
operation="op",
input="hi",
session=ChannelSession(isolation_key="alice"),
attributes={"response_id": "resp_abc", "previous_response_id": "resp_prev"},
)
result = await ch.context.run(req)
assert result.result.text == "ok"
# Bind ↔ unbind brackets the agent run.
events = [name for name, _ in prov.events]
assert events == ["enter", "agent_start", "agent_end", "exit"]
# Both response_id and previous_response_id forwarded by name.
_, enter_payload = prov.events[0]
assert enter_payload["response_id"] == "resp_abc"
assert enter_payload["previous_response_id"] == "resp_prev"
async def test_bind_skipped_when_no_response_id_attribute(self) -> None:
"""Without a ``response_id`` attribute on the request, the host
skips the binding entirely — the contract requires one to anchor
the chain."""
prov = _RecordingContextProvider()
agent = _ProvidersAgent([prov])
ch = _RecordingChannel(name="responses")
host = AgentFrameworkHost(target=agent, channels=[ch])
_ = host.app
assert ch.context is not None
req = ChannelRequest(channel="responses", operation="op", input="hi")
await ch.context.run(req)
assert prov.events == [("agent_start", None), ("agent_end", None)]
async def test_bind_does_not_descend_into_providers_attribute(self) -> None:
"""The host does not introspect ``ContextProviderBase`` aggregator
wrappers. Aggregator providers are responsible for forwarding the
bind to their children themselves (``AggregateContextProvider``
already does this). The host treats whatever ``agent.context_providers``
exposes as the final, flat list."""
prov = _RecordingContextProvider(name="inner")
wrapper = _ProviderWrapper([prov])
agent = _ProvidersAgent([wrapper])
ch = _RecordingChannel(name="responses")
host = AgentFrameworkHost(target=agent, channels=[ch])
_ = host.app
assert ch.context is not None
req = ChannelRequest(
channel="responses",
operation="op",
input="hi",
attributes={"response_id": "resp_xyz"},
)
await ch.context.run(req)
# The wrapper does not implement ``response_context``, so the
# inner provider must NOT have been entered by the host.
assert ("enter", {"response_id": "resp_xyz", "previous_response_id": None}) not in prov.events
async def test_bind_held_open_until_stream_exhaustion(self) -> None:
    """The binding outlives ``run_stream`` and closes only after draining.

    A streaming run hands back its stream before anything is consumed.
    The bind context must already be entered at that point, stay open
    while the chunks are iterated, and be released exactly once after
    the iterator is exhausted, so every yielded chunk is seen under the
    bound context.
    """
    provider = _RecordingContextProvider()
    channel = _RecordingChannel(name="responses")
    host = AgentFrameworkHost(target=_ProvidersAgent([provider]), channels=[channel])
    _ = host.app
    assert channel.context is not None

    request = ChannelRequest(
        channel="responses",
        operation="op",
        input="hi",
        stream=True,
        attributes={"response_id": "resp_stream"},
    )
    stream = await channel.context.run_stream(request)

    # Right after creation: entered once, not yet exited.
    seen_on_create = [label for label, _ in provider.events]
    assert seen_on_create.count("enter") == 1
    assert "exit" not in seen_on_create

    texts = [update.text async for update in stream]
    assert texts == ["chunk-1", "chunk-2"]

    # After the drain: still a single enter, and now a single exit.
    seen_on_drain = [label for label, _ in provider.events]
    assert seen_on_drain.count("enter") == 1
    assert seen_on_drain.count("exit") == 1

    # Every ``stream_yield`` event sits strictly between enter and exit.
    # NOTE(review): this holds vacuously if no ``stream_yield`` event was
    # recorded — confirm the fakes emit one per chunk.
    opened_at = seen_on_drain.index("enter")
    closed_at = seen_on_drain.index("exit")
    assert all(
        opened_at < position < closed_at
        for position, label in enumerate(seen_on_drain)
        if label == "stream_yield"
    )
# --------------------------------------------------------------------------- #
# Agent-target streaming — `_BoundResponseStream` adapter behaviour #
# --------------------------------------------------------------------------- #
class TestBoundResponseStream:
    """Release paths of the bind context held by a streaming run.

    The ``_BoundResponseStream`` adapter keeps the bind-context
    ``ExitStack`` open across iteration. These tests cover closing via
    ``get_final_response``, idempotent double close, ``aclose()`` on an
    abandoned stream, ``__getattr__`` forwarding, and the awaitable path
    (which routes through ``get_final_response`` so the binding is not
    leaked). Outcomes are read from the events recorded on the context
    provider.
    """

    async def test_get_final_response_closes_binding(self) -> None:
        """``get_final_response`` without prior iteration releases the binding."""
        provider = _RecordingContextProvider()
        channel = _RecordingChannel(name="responses")
        host = AgentFrameworkHost(target=_ProvidersAgent([provider]), channels=[channel])
        _ = host.app
        assert channel.context is not None

        request = ChannelRequest(
            channel="responses",
            operation="op",
            input="hi",
            stream=True,
            attributes={"response_id": "resp_get_final"},
        )
        stream = await channel.context.run_stream(request)

        # No iteration on purpose: the adapter has to drain the inner
        # stream on its own and close the binding in ``finally``.
        final = await stream.get_final_response()
        assert final.text == "chunk-1chunk-2"

        recorded = [label for label, _ in provider.events]
        assert recorded.count("enter") == 1
        assert recorded.count("exit") == 1

    async def test_double_close_is_idempotent(self) -> None:
        """Explicit ``aclose()`` calls after a full drain add no second exit."""
        provider = _RecordingContextProvider()
        channel = _RecordingChannel(name="responses")
        host = AgentFrameworkHost(target=_ProvidersAgent([provider]), channels=[channel])
        _ = host.app
        assert channel.context is not None

        request = ChannelRequest(
            channel="responses",
            operation="op",
            input="hi",
            stream=True,
            attributes={"response_id": "resp_idem"},
        )
        stream = await channel.context.run_stream(request)

        async for _update in stream:
            pass

        # The iteration's ``finally`` has already closed the binding, so
        # both of these must be no-ops.
        await stream.aclose()  # type: ignore[attr-defined]
        await stream.aclose()  # type: ignore[attr-defined]

        recorded = [label for label, _ in provider.events]
        assert recorded.count("exit") == 1

    async def test_aclose_releases_binding_when_stream_abandoned(self) -> None:
        """``aclose()`` on a never-iterated stream still releases the binding.

        A channel that drops the stream without consuming it needs a way
        to release the host-bound contextvars; otherwise they would leak
        for the host's lifetime.
        """
        provider = _RecordingContextProvider()
        channel = _RecordingChannel(name="responses")
        host = AgentFrameworkHost(target=_ProvidersAgent([provider]), channels=[channel])
        _ = host.app
        assert channel.context is not None

        request = ChannelRequest(
            channel="responses",
            operation="op",
            input="hi",
            stream=True,
            attributes={"response_id": "resp_abandon"},
        )
        stream = await channel.context.run_stream(request)
        await stream.aclose()  # type: ignore[attr-defined]

        recorded = [label for label, _ in provider.events]
        # Entered and exited once each, with no iteration in between.
        assert recorded.count("enter") == 1
        assert recorded.count("exit") == 1
        # The stream was abandoned before the agent ever started.
        assert "agent_start" not in recorded

    async def test_getattr_forwards_to_inner_stream(self) -> None:
        """Unknown attributes resolve on the wrapped ``ResponseStream``.

        Channels that reach for stream methods such as
        ``stream.with_result_hook`` must keep working through the
        adapter's ``__getattr__`` forwarding.
        """
        provider = _RecordingContextProvider()
        channel = _RecordingChannel(name="responses")
        host = AgentFrameworkHost(target=_ProvidersAgent([provider]), channels=[channel])
        _ = host.app
        assert channel.context is not None

        request = ChannelRequest(
            channel="responses",
            operation="op",
            input="hi",
            stream=True,
            attributes={"response_id": "resp_getattr"},
        )
        stream = await channel.context.run_stream(request)

        # ``with_result_hook`` is a real ``ResponseStream`` method; broken
        # forwarding would surface here as an ``AttributeError``.
        try:
            assert callable(stream.with_result_hook)  # type: ignore[attr-defined]
        finally:
            # Always release the binding, even if the assertion fails.
            await stream.aclose()  # type: ignore[attr-defined]

    async def test_await_path_routes_through_get_final_response(self) -> None:
        """``await stream`` yields the final response and closes the binding.

        Awaiting the stream is a convenience for awaiting
        ``get_final_response()``. A direct delegation used to leak the
        binding for the host's lifetime; routing through
        ``get_final_response`` closes the stack in the same ``finally``.
        """
        provider = _RecordingContextProvider()
        channel = _RecordingChannel(name="responses")
        host = AgentFrameworkHost(target=_ProvidersAgent([provider]), channels=[channel])
        _ = host.app
        assert channel.context is not None

        request = ChannelRequest(
            channel="responses",
            operation="op",
            input="hi",
            stream=True,
            attributes={"response_id": "resp_await"},
        )
        stream = await channel.context.run_stream(request)

        final = await stream  # exercises __await__
        assert final.text == "chunk-1chunk-2"

        recorded = [label for label, _ in provider.events]
        assert recorded.count("enter") == 1
        assert recorded.count("exit") == 1
# --------------------------------------------------------------------------- #
# `_wrap_input` — list[Message] LAST-message metadata stamping #
# --------------------------------------------------------------------------- #
class TestWrapInputListMessages:
    """The ``hosting`` block lands on the LAST message of a list — the
    contract is load-bearing: the user turn (typically last) must
    carry the channel provenance + identity for history correlation;
    a regression stamping ``messages[0]`` instead silently breaks
    every multi-message payload."""

    async def test_metadata_lands_on_last_message_only(self) -> None:
        """Only the final message of a list input carries ``hosting``.

        Sends a system + user pair through the channel and inspects what
        the fake agent received: two messages, the first without a
        ``hosting`` entry and the last stamped with channel and identity.
        """
        agent = _FakeAgent()
        ch = _RecordingChannel(name="responses")
        host = AgentFrameworkHost(target=agent, channels=[ch])
        _ = host.app
        assert ch.context is not None
        # Responses-API style: a system instruction followed by a user
        # turn. Only the user turn (LAST) gets stamped.
        system = Message(role="system", contents=[Content.from_text("be concise")])
        user = Message(role="user", contents=[Content.from_text("hi")])
        req = ChannelRequest(
            channel="responses",
            operation="op",
            input=[system, user],
            identity=ChannelIdentity(channel="responses", native_id="user:1"),
        )
        await ch.context.run(req)
        forwarded = agent.calls[0]["messages"]
        assert isinstance(forwarded, list)
        assert len(forwarded) == 2
        # System stays clean — on the caller's object ...
        assert (system.additional_properties or {}).get("hosting") is None
        # ... and on the message the agent actually received. Checking
        # only ``system`` would miss a stamp applied to a copied message.
        assert (forwarded[0].additional_properties or {}).get("hosting") is None
        # User turn carries the metadata.
        hosting = forwarded[-1].additional_properties["hosting"]
        assert hosting["channel"] == "responses"
        assert hosting["identity"]["native_id"] == "user:1"

    async def test_single_message_payload_still_works(self) -> None:
        """Regression guard: the single-``Message`` branch must be
        unchanged by the LAST-of-list logic above."""
        agent = _FakeAgent()
        ch = _RecordingChannel(name="responses")
        host = AgentFrameworkHost(target=agent, channels=[ch])
        _ = host.app
        assert ch.context is not None
        only = Message(role="user", contents=[Content.from_text("hi")])
        req = ChannelRequest(channel="responses", operation="op", input=only)
        await ch.context.run(req)
        forwarded = agent.calls[0]["messages"]
        # A lone ``Message`` is forwarded as a ``Message``, not wrapped
        # in a list, and is stamped directly.
        assert isinstance(forwarded, Message)
        assert forwarded.additional_properties["hosting"]["channel"] == "responses"
# --------------------------------------------------------------------------- #
# Lifespan callback aggregation #
# --------------------------------------------------------------------------- #
class _RaisingLifecycleChannel:
    """Fake channel whose startup or shutdown callback raises on demand.

    ``fail_on`` picks the phase that fails (``"startup"`` or
    ``"shutdown"``). Both callbacks record their invocation before any
    error is raised, so tests can tell a callback ran even when it failed.
    """

    def __init__(self, name: str, *, fail_on: str) -> None:
        self._fail_on = fail_on  # "startup" | "shutdown"
        self.start_calls: list[str] = []
        self.stop_calls: list[str] = []
        self.name = name
        self.path = ""

    def contribute(self, _context: ChannelContext) -> ChannelContribution:
        """Return a contribution carrying one startup and one shutdown callback."""

        async def _stop() -> None:
            self.stop_calls.append("down")
            if self._fail_on != "shutdown":
                return
            raise RuntimeError(f"shutdown-boom-{self.name}")

        async def _start() -> None:
            self.start_calls.append("up")
            if self._fail_on != "startup":
                return
            raise RuntimeError(f"startup-boom-{self.name}")

        return ChannelContribution(on_startup=[_start], on_shutdown=[_stop])
class _OkLifecycleChannel:
    """Fake channel whose lifecycle callbacks only record that they ran."""

    def __init__(self, name: str) -> None:
        self.start_calls: list[str] = []
        self.stop_calls: list[str] = []
        self.name = name
        self.path = ""

    def contribute(self, _context: ChannelContext) -> ChannelContribution:
        """Return a contribution carrying one startup and one shutdown callback."""

        async def _stop() -> None:
            self.stop_calls.append("down")

        async def _start() -> None:
            self.start_calls.append("up")

        return ChannelContribution(on_startup=[_start], on_shutdown=[_stop])
class TestLifespanAggregation:
    """One bad startup / shutdown callback must NOT abort the rest —
    every channel gets a chance to wire / unwire so half-initialised
    state doesn't leak. The first error is still raised so the
    process exits with a failure; remaining errors are logged so
    operators see them all in one log scrape."""

    def test_shutdown_failure_does_not_skip_peer_shutdowns(self, caplog: Any) -> None:
        """A raising shutdown callback does not stop later channels' shutdowns."""
        import logging as _logging

        agent = _FakeAgent()
        bad = _RaisingLifecycleChannel("bad", fail_on="shutdown")
        ok1 = _OkLifecycleChannel("ok1")
        ok2 = _OkLifecycleChannel("ok2")
        # Order: bad first so that without aggregation, ok1+ok2 would
        # never get to run their shutdown callbacks.
        host = AgentFrameworkHost(target=agent, channels=[bad, ok1, ok2])
        with caplog.at_level(_logging.ERROR, logger="agent_framework.hosting"):  # noqa: SIM117
            # Entering and leaving ``TestClient`` drives the lifespan;
            # the shutdown error is expected to surface on exit.
            with pytest.raises(RuntimeError, match="shutdown-boom-bad"), TestClient(host.app):
                pass
        # Every channel had its shutdown attempted, even though `bad` raised.
        assert bad.stop_calls == ["down"]
        assert ok1.stop_calls == ["down"]
        assert ok2.stop_calls == ["down"]

    def test_startup_failure_aggregates_logs_and_raises_first(self, caplog: Any) -> None:
        """All startup callbacks run; the FIRST failure is the one raised."""
        import logging as _logging

        agent = _FakeAgent()
        ok1 = _OkLifecycleChannel("ok1")
        bad = _RaisingLifecycleChannel("bad", fail_on="startup")
        ok2 = _OkLifecycleChannel("ok2")
        another_bad = _RaisingLifecycleChannel("bad2", fail_on="startup")
        host = AgentFrameworkHost(
            target=agent,
            channels=[ok1, bad, ok2, another_bad],
        )
        with caplog.at_level(_logging.ERROR, logger="agent_framework.hosting"):  # noqa: SIM117
            # The first failing callback's error is the one that
            # propagates; remaining failures are logged. ``match`` is a
            # regex search and "startup-boom-bad" is a prefix of
            # "startup-boom-bad2", so the trailing ``\b`` is required to
            # reject the second failure's message.
            with pytest.raises(RuntimeError, match=r"startup-boom-bad\b"), TestClient(host.app):
                pass
        # Every startup callback ran (even ok2 / another_bad after the
        # first failure) so we get a complete picture in the logs.
        assert ok1.start_calls == ["up"]
        assert bad.start_calls == ["up"]
        assert ok2.start_calls == ["up"]
        assert another_bad.start_calls == ["up"]
        # Both failures show up in operator logs. ``logger.exception`` puts
        # the exception payload in ``record.exc_text``; the formatted summary
        # of the second failure goes into ``record.message`` via the
        # aggregate "N callback(s) failed" line.
        log_messages = [rec.getMessage() for rec in caplog.records]
        log_exc_texts = [rec.exc_text or "" for rec in caplog.records]
        log_text = "\n".join(log_messages + log_exc_texts)
        assert "startup-boom-bad" in log_text
        assert "startup-boom-bad2" in log_text or "callback(s) failed" in log_text