# Copyright (c) Microsoft. All rights reserved.

"""Shared test fixtures and stubs for AG-UI tests."""

import sys
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping, MutableSequence, Sequence
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Generic, Literal, cast, overload

import pytest
from agent_framework import (
    AgentResponse,
    AgentResponseUpdate,
    AgentSession,
    BaseChatClient,
    ChatOptions,
    ChatResponse,
    ChatResponseUpdate,
    Content,
    Message,
    SupportsAgentRun,
    SupportsChatGetResponse,
)
from agent_framework._clients import OptionsCoT
from agent_framework._middleware import ChatMiddlewareLayer
from agent_framework._tools import FunctionInvocationLayer
from agent_framework._types import ResponseStream
from agent_framework.observability import ChatTelemetryLayer

if sys.version_info >= (3, 12):
    from typing import override  # type: ignore # pragma: no cover
else:
    from typing_extensions import override  # type: ignore[import] # pragma: no cover

# Signature of the injectable streaming callback: called as fn(messages, options, **kwargs).
StreamFn = Callable[..., AsyncIterable[ChatResponseUpdate]]
# Signature of the optional non-streaming callback: called as fn(messages, options, **kwargs).
ResponseFn = Callable[..., Awaitable[ChatResponse]]


def pytest_configure() -> None:
    """Ensure this test directory is on sys.path so helper modules can be imported by name."""
    test_dir = str(Path(__file__).resolve().parent)
    if test_dir not in sys.path:
        sys.path.insert(0, test_dir)


class StreamingChatClientStub(
    FunctionInvocationLayer[OptionsCoT],
    ChatMiddlewareLayer[OptionsCoT],
    ChatTelemetryLayer[OptionsCoT],
    BaseChatClient[OptionsCoT],
    Generic[OptionsCoT],
):
    """Typed streaming stub that satisfies SupportsChatGetResponse.

    The stub runs the real function-invocation, middleware and telemetry layers on top of
    ``BaseChatClient``. Only the innermost model call is replaced by the injected callbacks.
    """

    def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) -> None:
        """Create the stub.

        Args:
            stream_fn: Called as ``stream_fn(messages, options, **kwargs)``; yields the updates
                the fake model produces.
            response_fn: Optional coroutine function used for non-streaming calls. When omitted,
                the non-streaming response is assembled from the updates of ``stream_fn``.
        """
        super().__init__(middleware=[])
        self._stream_fn = stream_fn
        self._response_fn = response_fn
        # Session seen on the most recent get_response call (read from kwargs["client_kwargs"]).
        self.last_session: AgentSession | None = None
        # service_session_id of last_session, or None when no session was passed.
        self.last_service_session_id: str | None = None

    @overload
    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: Literal[False] = ...,
        options: ChatOptions[Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]]: ...

    @overload
    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: Literal[False] = ...,
        options: OptionsCoT | ChatOptions[None] | None = ...,
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]]: ...

    @overload
    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: Literal[True],
        options: OptionsCoT | ChatOptions[Any] | None = ...,
        **kwargs: Any,
    ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...

    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: bool = False,
        options: OptionsCoT | ChatOptions[Any] | None = None,
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
        """Record the session from ``client_kwargs``, then delegate to the layered client."""
        client_kwargs = kwargs.get("client_kwargs")
        if isinstance(client_kwargs, Mapping):
            self.last_session = cast(AgentSession | None, client_kwargs.get("session"))
        else:
            self.last_session = None
        self.last_service_session_id = self.last_session.service_session_id if self.last_session else None
        return cast(
            Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
            super().get_response(
                messages=messages,
                stream=cast(Literal[True, False], stream),
                options=options,
                **kwargs,
            ),
        )

    @override
    def _inner_get_response(
        self,
        *,
        messages: Sequence[Message],
        stream: bool = False,
        options: Mapping[str, Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
        """Replace the real model call with the injected stream/response callbacks."""
        if stream:

            def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
                return ChatResponse.from_updates(updates)

            return ResponseStream(self._stream_fn(messages, options, **kwargs), finalizer=_finalize)

        return self._get_response_impl(messages, options, **kwargs)

    async def _get_response_impl(
        self, messages: Sequence[Message], options: Mapping[str, Any], **kwargs: Any
    ) -> ChatResponse:
        """Non-streaming implementation.

        Uses ``response_fn`` when it was provided. Otherwise, drains ``stream_fn`` and merges the
        contents of every update into a single assistant message.
        """
        if self._response_fn is not None:
            return await self._response_fn(messages, options, **kwargs)

        contents: list[Any] = []
        async for update in self._stream_fn(list(messages), dict(options), **kwargs):
            contents.extend(update.contents)

        return ChatResponse(
            messages=[Message(role="assistant", contents=contents)],
            response_id="stub-response",
        )


def stream_from_updates(updates: list[ChatResponseUpdate]) -> StreamFn:
    """Create a stream function that yields from a static list of updates.

    The returned function ignores its ``messages``/``options``/``kwargs`` arguments. It replays
    ``updates`` in order on every call.
    """

    async def _stream(
        messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        for update in updates:
            yield update

    return _stream


class StubAgent(SupportsAgentRun):
    """Minimal SupportsAgentRun stub for orchestrator tests.

    Streaming runs replay ``updates``. Non-streaming runs return an empty ``AgentResponse``.
    Both paths record the messages, session and ``tools`` kwarg they were called with so tests
    can assert on them.
    """

    def __init__(
        self,
        updates: list[AgentResponseUpdate] | None = None,
        *,
        agent_id: str = "stub-agent",
        agent_name: str | None = "stub-agent",
        default_options: Any | None = None,
        client: Any | None = None,
    ) -> None:
        """Create the stub agent.

        Args:
            updates: Updates replayed by streaming runs. ``None`` selects a single assistant
                "response" text update. An explicit empty list means "stream nothing".
            agent_id: Value for ``id``.
            agent_name: Value for ``name``.
            default_options: Used only when it is a ``dict``. Otherwise a dict with
                ``tools``/``response_format`` set to ``None`` is used.
            client: Chat client exposed as ``client``. Defaults to a namespace with
                ``function_invocation_configuration=None``.
        """
        self.id = agent_id
        self.name = agent_name
        self.description = "stub agent"
        # `is None` rather than truthiness so an explicit [] is honored instead of replaced.
        self.updates = (
            updates
            if updates is not None
            else [AgentResponseUpdate(contents=[Content.from_text(text="response")], role="assistant")]
        )
        self.default_options: dict[str, Any] = (
            default_options if isinstance(default_options, dict) else {"tools": None, "response_format": None}
        )
        self.client = client or SimpleNamespace(function_invocation_configuration=None)
        # Inputs captured from the most recent run() call.
        self.messages_received: list[Any] = []
        self.tools_received: list[Any] | None = None
        self.last_session: AgentSession | None = None

    @staticmethod
    def _normalize_messages(
        messages: str | Content | Message | Sequence[str | Content | Message] | None,
    ) -> list[Any]:
        """Return ``messages`` as a list, wrapping a single str/Content/Message item."""
        if messages is None:
            return []
        # A bare str is itself a Sequence; list() would split it into characters.
        if isinstance(messages, (str, Content, Message)):
            return [messages]
        return list(messages)

    @overload
    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        stream: Literal[False] = ...,
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]]: ...

    @overload
    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        stream: Literal[True],
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        stream: bool = False,
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
        """Run the stub.

        The inputs are recorded lazily: when the stream is first iterated, or when the
        non-streaming awaitable is awaited.
        """

        def _record_call() -> None:
            self.messages_received = self._normalize_messages(messages)
            self.last_session = session
            self.tools_received = kwargs.get("tools")

        if stream:

            async def _stream() -> AsyncIterator[AgentResponseUpdate]:
                _record_call()
                for update in self.updates:
                    yield update

            def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
                return AgentResponse.from_updates(updates)

            return ResponseStream(_stream(), finalizer=_finalize)

        async def _get_response() -> AgentResponse[Any]:
            _record_call()
            return AgentResponse(messages=[], response_id="stub-response")

        return _get_response()

    def create_session(self, **kwargs: Any) -> AgentSession:
        """Return a fresh, empty AgentSession; kwargs are ignored."""
        return AgentSession()


# Fixtures


@pytest.fixture
def streaming_chat_client_stub() -> type[SupportsChatGetResponse]:
    """Return the StreamingChatClientStub class for creating test instances."""
    return StreamingChatClientStub  # type: ignore[return-value]


@pytest.fixture
def stream_from_updates_fixture() -> Callable[[list[ChatResponseUpdate]], StreamFn]:
    """Return the stream_from_updates helper function."""
    return stream_from_updates


@pytest.fixture
def stub_agent() -> type[SupportsAgentRun]:
    """Return the StubAgent class for creating test instances."""
    return StubAgent  # type: ignore[return-value]


# ── Fixtures for golden / integration tests ──


@pytest.fixture
def collect_events() -> Callable[..., Any]:
    """Return an async helper that collects all events from an async generator."""

    async def _collect(async_gen: AsyncIterable[Any]) -> list[Any]:
        return [event async for event in async_gen]

    return _collect


@pytest.fixture
def make_agent_wrapper() -> Callable[..., Any]:
    """Factory that builds an AgentFrameworkAgent from a stream function.

    The wrapped agent is a ``StubAgent`` whose ``client`` is a ``StreamingChatClientStub``
    driven by ``stream_fn``.

    Usage::

        agent = make_agent_wrapper(
            stream_fn=stream_from_updates(updates),
            state_schema=...,
        )
        events = [e async for e in agent.run(payload)]
    """
    from agent_framework_ag_ui import AgentFrameworkAgent

    def _factory(
        stream_fn: StreamFn,
        *,
        state_schema: Any | None = None,
        predict_state_config: dict[str, dict[str, str]] | None = None,
        require_confirmation: bool = True,
    ) -> Any:
        client = StreamingChatClientStub(stream_fn)
        stub = StubAgent(client=client)
        return AgentFrameworkAgent(
            agent=stub,
            state_schema=state_schema,
            predict_state_config=predict_state_config,
            require_confirmation=require_confirmation,
        )

    return _factory


@pytest.fixture
def make_app() -> Callable[..., Any]:
    """Factory that builds a FastAPI app with an AG-UI endpoint.

    Usage::

        app = make_app(agent_or_wrapper, path="/test")
    """
    from fastapi import FastAPI

    from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint

    def _factory(
        agent: Any,
        *,
        path: str = "/",
        state_schema: Any | None = None,
        predict_state_config: dict[str, dict[str, str]] | None = None,
        default_state: dict[str, Any] | None = None,
    ) -> FastAPI:
        app = FastAPI()
        add_agent_framework_fastapi_endpoint(
            app,
            agent,
            path=path,
            state_schema=state_schema,
            predict_state_config=predict_state_config,
            default_state=default_state,
        )
        return app

    return _factory