mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
0cd40f8354
* [BREAKING] Refactor middleware layering and raw clients Reorder chat client layers so function invocation wraps chat middleware, and chat middleware stays outside telemetry while still running for each inner model call. Add middleware pipeline caching, refresh docs and samples, and split Anthropic into raw and public clients to match the standard layering model. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Tighten typing ignores in ancillary modules Add targeted typing ignores in workflow visualization and lab modules so pyright stays clean alongside the middleware refactor work. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix categorize_middleware to unpack tuple/Sequence and use relative MRO assertions - Broaden isinstance check in categorize_middleware from list to Sequence so tuples and other Sequence types are properly unpacked instead of being appended as a single item. - Replace fragile hardcoded MRO index assertions in anthropic test with relative ordering via mro.index(). - Add regression tests for categorize_middleware with tuple, list, and None inputs. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix middleware string decomposition, add middleware param to FunctionInvocationLayer, and add tests (#4710) - Guard categorize_middleware Sequence check against str/bytes to prevent character-by-character decomposition of accidentally passed strings - Add explicit middleware parameter to FunctionInvocationLayer.get_response and merge it into client_kwargs before categorization, fixing the inconsistency where only OpenAIChatClient supported this parameter - Add assertions that RawAnthropicClient does not inherit convenience layers - Add chat middleware cache test with non-empty base middleware - Add tests for single unwrapped middleware item and string input Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes * Apply pre-commit auto-fixes * Address review feedback for #4710: review comment fixes --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Copilot <copilot@github.com>
336 lines
11 KiB
Python
336 lines
11 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Shared test fixtures and stubs for AG-UI tests."""
|
|
|
|
import sys
|
|
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping, MutableSequence, Sequence
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
from typing import Any, Generic, Literal, cast, overload
|
|
|
|
import pytest
|
|
from agent_framework import (
|
|
AgentResponse,
|
|
AgentResponseUpdate,
|
|
AgentSession,
|
|
BaseChatClient,
|
|
ChatOptions,
|
|
ChatResponse,
|
|
ChatResponseUpdate,
|
|
Content,
|
|
Message,
|
|
SupportsAgentRun,
|
|
SupportsChatGetResponse,
|
|
)
|
|
from agent_framework._clients import OptionsCoT
|
|
from agent_framework._middleware import ChatMiddlewareLayer
|
|
from agent_framework._tools import FunctionInvocationLayer
|
|
from agent_framework._types import ResponseStream
|
|
from agent_framework.observability import ChatTelemetryLayer
|
|
|
|
if sys.version_info >= (3, 12):
|
|
from typing import override # type: ignore # pragma: no cover
|
|
else:
|
|
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
|
|
|
# Non-streaming backend: awaited as ``fn(messages, options, **kwargs)`` and
# expected to resolve to a ChatResponse.
ResponseFn = Callable[..., Awaitable[ChatResponse]]
# Streaming backend: called as ``fn(messages, options, **kwargs)`` and consumed
# with ``async for``, producing ChatResponseUpdate objects.
StreamFn = Callable[..., AsyncIterable[ChatResponseUpdate]]
|
|
|
|
|
|
def pytest_configure() -> None:
    """Make helper modules in this test directory importable by bare module name.

    The directory containing this file is inserted at the front of ``sys.path``
    unless it is already listed, so repeated calls never add duplicates.
    """
    here = str(Path(__file__).resolve().parent)
    if here in sys.path:
        return
    sys.path.insert(0, here)
|
|
|
|
|
|
class StreamingChatClientStub(
    FunctionInvocationLayer[OptionsCoT],
    ChatMiddlewareLayer[OptionsCoT],
    ChatTelemetryLayer[OptionsCoT],
    BaseChatClient[OptionsCoT],
    Generic[OptionsCoT],
):
    """Typed streaming stub that satisfies SupportsChatGetResponse.

    The framework's function-invocation, chat-middleware and telemetry layers
    are mixed in on top of ``BaseChatClient``. The innermost model call is
    delegated to caller-supplied callables.

    Args:
        stream_fn: Called as ``stream_fn(messages, options, **kwargs)`` to produce
            the ``ChatResponseUpdate`` objects for a streaming call. It is also
            drained to build the non-streaming response when ``response_fn`` is
            not given.
        response_fn: Optional coroutine function, awaited as
            ``response_fn(messages, options, **kwargs)`` for non-streaming calls.

    Attributes:
        last_session: The ``session`` entry of the ``client_kwargs`` mapping
            passed to the most recent ``get_response`` call, or ``None`` when no
            such mapping was passed.
        last_service_session_id: ``service_session_id`` of ``last_session``, or
            ``None`` when no session was recorded.
    """

    def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) -> None:
        # Start the middleware layer with an explicitly empty middleware list.
        super().__init__(middleware=[])
        self._stream_fn = stream_fn
        self._response_fn = response_fn
        self.last_session: AgentSession | None = None
        self.last_service_session_id: str | None = None

    @overload
    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: Literal[False] = ...,
        options: ChatOptions[Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]]: ...

    @overload
    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: Literal[False] = ...,
        options: OptionsCoT | ChatOptions[None] | None = ...,
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]]: ...

    @overload
    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: Literal[True],
        options: OptionsCoT | ChatOptions[Any] | None = ...,
        **kwargs: Any,
    ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...

    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: bool = False,
        options: OptionsCoT | ChatOptions[Any] | None = None,
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
        """Record the session from ``client_kwargs``, then delegate to the layered client.

        Returns whatever the layered ``get_response`` returns: an awaitable
        response when ``stream`` is false, or a ``ResponseStream`` otherwise.
        """
        client_kwargs = kwargs.get("client_kwargs")
        if isinstance(client_kwargs, Mapping):
            self.last_session = cast(AgentSession | None, client_kwargs.get("session"))
        else:
            self.last_session = None
        # Compare against None explicitly so a session object is never treated as
        # missing because of how it evaluates in a boolean context.
        self.last_service_session_id = (
            self.last_session.service_session_id if self.last_session is not None else None
        )
        return cast(
            Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
            super().get_response(
                messages=messages,
                stream=cast(Literal[True, False], stream),
                options=options,
                **kwargs,
            ),
        )

    @override
    def _inner_get_response(
        self,
        *,
        messages: Sequence[Message],
        stream: bool = False,
        options: Mapping[str, Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
        """Innermost model call: route to ``stream_fn`` or the non-streaming implementation.

        Streaming wraps the ``stream_fn`` output in a ``ResponseStream`` whose
        final response is built with ``ChatResponse.from_updates``.
        Non-streaming returns the coroutine from ``_get_response_impl``.
        """
        if stream:

            def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
                return ChatResponse.from_updates(updates)

            # NOTE(review): this path forwards messages/options as received, while the
            # non-streaming fallback passes list/dict copies. Confirm whether that
            # difference is intentional.
            return ResponseStream(self._stream_fn(messages, options, **kwargs), finalizer=_finalize)

        return self._get_response_impl(messages, options, **kwargs)

    async def _get_response_impl(
        self, messages: Sequence[Message], options: Mapping[str, Any], **kwargs: Any
    ) -> ChatResponse:
        """Non-streaming implementation.

        Awaits ``response_fn`` when one was supplied. Otherwise it drains
        ``stream_fn`` and concatenates every update's contents into a single
        assistant message with ``response_id="stub-response"``.
        """
        if self._response_fn is not None:
            return await self._response_fn(messages, options, **kwargs)

        contents: list[Any] = []
        async for update in self._stream_fn(list(messages), dict(options), **kwargs):
            contents.extend(update.contents)

        return ChatResponse(
            messages=[Message(role="assistant", contents=contents)],
            response_id="stub-response",
        )
|
|
|
|
|
|
def stream_from_updates(updates: list[ChatResponseUpdate]) -> StreamFn:
    """Build a ``StreamFn`` that replays ``updates`` in order.

    The returned coroutine function ignores its ``messages``, ``options`` and
    keyword arguments. The captured list is read only when the stream is
    iterated, and the same list object is replayed on every call.
    """

    async def _replay(
        messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        for queued in updates:
            yield queued

    return _replay
|
|
|
|
|
|
class StubAgent(SupportsAgentRun):
    """Minimal SupportsAgentRun stub for orchestrator tests.

    Streaming runs replay ``updates`` and record the messages and ``tools``
    keyword they received. Non-streaming runs resolve to an empty
    ``AgentResponse``.

    Args:
        updates: Updates yielded by streaming runs. ``None`` or an empty list
            falls back to a single assistant ``"response"`` text update.
        agent_id: Value exposed as ``id``.
        agent_name: Value exposed as ``name``.
        default_options: Used as ``default_options`` when it is a ``dict``. Any
            other value is replaced by ``{"tools": None, "response_format": None}``.
        client: Object exposed as ``client``. A falsy value is replaced by a
            ``SimpleNamespace`` whose ``function_invocation_configuration`` is ``None``.
    """

    def __init__(
        self,
        updates: list[AgentResponseUpdate] | None = None,
        *,
        agent_id: str = "stub-agent",
        agent_name: str | None = "stub-agent",
        default_options: Any | None = None,
        client: Any | None = None,
    ) -> None:
        self.id = agent_id
        self.name = agent_name
        self.description = "stub agent"
        # ``or`` means both None and an empty list fall back to the default update.
        self.updates = updates or [AgentResponseUpdate(contents=[Content.from_text(text="response")], role="assistant")]
        self.default_options: dict[str, Any] = (
            default_options if isinstance(default_options, dict) else {"tools": None, "response_format": None}
        )
        self.client = client or SimpleNamespace(function_invocation_configuration=None)
        # Inputs captured by the most recent streaming run.
        self.messages_received: list[Any] = []
        self.tools_received: list[Any] | None = None

    @staticmethod
    def _normalize_messages(
        messages: str | Content | Message | Sequence[str | Content | Message] | None,
    ) -> list[Any]:
        """Return ``messages`` as a list, treating a single str/Content/Message as one item."""
        if messages is None:
            return []
        # ``list()`` on a bare string would split it into characters. A lone
        # Content/Message is a single input, not a collection to unpack.
        if isinstance(messages, (str, Content, Message)):
            return [messages]
        return list(messages)

    @overload
    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        stream: Literal[False] = ...,
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]]: ...

    @overload
    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        stream: Literal[True],
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        stream: bool = False,
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
        """Run the stub agent.

        Streaming returns a ``ResponseStream`` that records ``messages`` and the
        ``tools`` keyword when it is first iterated. It then yields each of
        ``self.updates``, and the final response is built with
        ``AgentResponse.from_updates``. Non-streaming returns a coroutine that
        resolves to an empty ``AgentResponse`` with ``response_id="stub-response"``.
        ``session`` is accepted but not used.
        """
        if stream:

            async def _stream() -> AsyncIterator[AgentResponseUpdate]:
                self.messages_received = self._normalize_messages(messages)
                self.tools_received = kwargs.get("tools")
                for update in self.updates:
                    yield update

            def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
                return AgentResponse.from_updates(updates)

            return ResponseStream(_stream(), finalizer=_finalize)

        async def _get_response() -> AgentResponse[Any]:
            return AgentResponse(messages=[], response_id="stub-response")

        return _get_response()

    def create_session(self, **kwargs: Any) -> AgentSession:
        """Return a new default ``AgentSession``; keyword arguments are ignored."""
        return AgentSession()
|
|
|
|
|
|
# Fixtures
|
|
|
|
|
|
@pytest.fixture
def streaming_chat_client_stub() -> type[SupportsChatGetResponse]:
    """Provide the ``StreamingChatClientStub`` class itself, not an instance.

    Each test can then construct stubs with its own stream function.
    """
    return cast(type[SupportsChatGetResponse], StreamingChatClientStub)
|
|
|
|
|
|
@pytest.fixture
def stream_from_updates_fixture() -> Callable[[list[ChatResponseUpdate]], StreamFn]:
    """Expose the module-level ``stream_from_updates`` helper to tests as a fixture."""
    helper = stream_from_updates
    return helper
|
|
|
|
|
|
@pytest.fixture
def stub_agent() -> type[SupportsAgentRun]:
    """Provide the ``StubAgent`` class itself, not an instance.

    Each test can then construct agents with its own arguments.
    """
    return cast(type[SupportsAgentRun], StubAgent)
|
|
|
|
|
|
# ── Fixtures for golden / integration tests ──
|
|
|
|
|
|
@pytest.fixture
def collect_events() -> Callable[..., Any]:
    """Provide a coroutine function that drains an async iterable into a list.

    Awaiting the returned function iterates its argument to exhaustion and
    returns every produced item, in order.
    """

    async def _drain(async_gen: AsyncIterable[Any]) -> list[Any]:
        return [item async for item in async_gen]

    return _drain
|
|
|
|
|
|
@pytest.fixture
def make_agent_wrapper() -> Callable[..., Any]:
    """Factory that builds an AgentFrameworkAgent from a stream function.

    Each call wraps ``stream_fn`` in a ``StreamingChatClientStub``, attaches that
    client to a ``StubAgent``, and wraps the agent in ``AgentFrameworkAgent``
    with the given state options.

    Usage::

        agent = make_agent_wrapper(
            stream_fn=stream_from_updates(updates),
            state_schema=...,
        )
        events = [e async for e in agent.run(payload)]
    """
    # Deferred import: resolved only when a test requests this fixture.
    from agent_framework_ag_ui import AgentFrameworkAgent

    def _factory(
        stream_fn: StreamFn,
        *,
        state_schema: Any | None = None,
        predict_state_config: dict[str, dict[str, str]] | None = None,
        require_confirmation: bool = True,
    ) -> Any:
        wrapper_options: dict[str, Any] = {
            "state_schema": state_schema,
            "predict_state_config": predict_state_config,
            "require_confirmation": require_confirmation,
        }
        wrapped_agent = StubAgent(client=StreamingChatClientStub(stream_fn))
        return AgentFrameworkAgent(agent=wrapped_agent, **wrapper_options)

    return _factory
|
|
|
|
|
|
@pytest.fixture
def make_app() -> Callable[..., Any]:
    """Factory that builds a FastAPI app with an AG-UI endpoint.

    Usage::

        app = make_app(agent_or_wrapper, path="/test")
    """
    from fastapi import FastAPI

    from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint

    def _factory(
        agent: Any,
        *,
        path: str = "/",
        state_schema: Any | None = None,
        predict_state_config: dict[str, dict[str, str]] | None = None,
        default_state: dict[str, Any] | None = None,
    ) -> FastAPI:
        """Create a fresh app and register ``agent`` at ``path`` with the given state options."""
        application = FastAPI()
        endpoint_options: dict[str, Any] = {
            "path": path,
            "state_schema": state_schema,
            "predict_state_config": predict_state_config,
            "default_state": default_state,
        }
        add_agent_framework_fastapi_endpoint(application, agent, **endpoint_options)
        return application

    return _factory
|