Files
Evan Mattson 0b50455e75 Python: Pass client thread_id as session_id when constructing AgentSession in AG-UI (#5384)
* Pass thread_id as session_id when constructing AgentSession in AG-UI

run_agent_stream() was constructing AgentSession without passing the
client's thread_id as session_id, causing every request to receive a
random UUID. This broke session continuity for HistoryProvider
implementations that rely on session_id matching the client's thread_id.

Pass session_id=thread_id in both the service-session and non-service
code paths so the session identity is consistent with the AG-UI client.

Fixes #5357

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Add test for service_session with no thread_id edge case (#5357)

When use_service_session=True but no thread_id/threadId is in the payload,
verify session_id is a generated UUID and service_session_id is None.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-22 17:45:25 +00:00

338 lines
11 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Shared test fixtures and stubs for AG-UI tests."""
import sys
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping, MutableSequence, Sequence
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Generic, Literal, cast, overload
import pytest
from agent_framework import (
AgentResponse,
AgentResponseUpdate,
AgentSession,
BaseChatClient,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
Content,
Message,
SupportsAgentRun,
SupportsChatGetResponse,
)
from agent_framework._clients import OptionsCoT
from agent_framework._middleware import ChatMiddlewareLayer
from agent_framework._tools import FunctionInvocationLayer
from agent_framework._types import ResponseStream
from agent_framework.observability import ChatTelemetryLayer
# ``typing.override`` only exists on Python 3.12+; older interpreters use the
# ``typing_extensions`` backport instead.
if sys.version_info < (3, 12):
    from typing_extensions import override  # type: ignore[import] # pragma: no cover
else:
    from typing import override  # type: ignore # pragma: no cover

# Streaming stub signature: invoked as ``fn(messages, options, **kwargs)`` and
# returning an async iterable of chat updates.
StreamFn = Callable[..., AsyncIterable[ChatResponseUpdate]]
# Non-streaming stub signature: invoked as ``fn(messages, options, **kwargs)`` and
# returning an awaitable chat response.
ResponseFn = Callable[..., Awaitable[ChatResponse]]
def pytest_configure() -> None:
    """Put this conftest's directory at the front of ``sys.path``.

    This is the standard pytest startup hook. Prepending the directory lets test
    modules import sibling helper modules by bare name. The membership check keeps
    repeated calls from adding duplicate entries.
    """
    here_str = str(Path(__file__).resolve().parent)
    if here_str in sys.path:
        return
    sys.path.insert(0, here_str)
class StreamingChatClientStub(
    FunctionInvocationLayer[OptionsCoT],
    ChatMiddlewareLayer[OptionsCoT],
    ChatTelemetryLayer[OptionsCoT],
    BaseChatClient[OptionsCoT],
    Generic[OptionsCoT],
):
    """Typed streaming stub that satisfies SupportsChatGetResponse.

    The stub inherits the FunctionInvocationLayer, ChatMiddlewareLayer and
    ChatTelemetryLayer mixins on top of ``BaseChatClient``. Model output comes from
    the injected ``stream_fn`` / ``response_fn`` callables instead of a real service.

    Attributes:
        last_session: The ``session`` found in ``client_kwargs`` on the most recent
            ``get_response`` call, or ``None`` if none was supplied.
        last_service_session_id: ``service_session_id`` of ``last_session``, or
            ``None`` when no session was supplied.
    """

    def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) -> None:
        """Create the stub.

        Args:
            stream_fn: Called as ``stream_fn(messages, options, **kwargs)``. Its updates
                form the streamed response. It also builds the non-streaming response
                when ``response_fn`` is not given.
            response_fn: Optional coroutine function called as
                ``response_fn(messages, options, **kwargs)`` for non-streaming requests.
        """
        super().__init__(middleware=[])
        self._stream_fn = stream_fn
        self._response_fn = response_fn
        self.last_session: AgentSession | None = None
        self.last_service_session_id: str | None = None

    @overload
    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: Literal[False] = ...,
        options: ChatOptions[Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]]: ...

    @overload
    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: Literal[False] = ...,
        options: OptionsCoT | ChatOptions[None] | None = ...,
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]]: ...

    @overload
    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: Literal[True],
        options: OptionsCoT | ChatOptions[Any] | None = ...,
        **kwargs: Any,
    ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...

    def get_response(
        self,
        messages: Sequence[Message],
        *,
        stream: bool = False,
        options: OptionsCoT | ChatOptions[Any] | None = None,
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
        """Record the session passed in ``client_kwargs``, then delegate to the layers.

        Returns whatever the inherited ``get_response`` returns: an awaitable response
        when ``stream`` is false, a ``ResponseStream`` otherwise.
        """
        # Remember the session threaded through ``client_kwargs`` so tests can inspect it.
        client_kwargs = kwargs.get("client_kwargs")
        if isinstance(client_kwargs, Mapping):
            self.last_session = cast(AgentSession | None, client_kwargs.get("session"))
        else:
            self.last_session = None
        # Use an explicit ``is not None`` check: a session object must not lose its
        # service_session_id just because it evaluates as falsy.
        self.last_service_session_id = (
            self.last_session.service_session_id if self.last_session is not None else None
        )
        return cast(
            Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
            super().get_response(
                messages=messages,
                stream=cast(Literal[True, False], stream),
                options=options,
                **kwargs,
            ),
        )

    @override
    def _inner_get_response(
        self,
        *,
        messages: Sequence[Message],
        stream: bool = False,
        options: Mapping[str, Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
        """Produce the raw response from the injected callables.

        Streaming requests wrap ``stream_fn`` in a ``ResponseStream`` whose finalizer
        merges the updates with ``ChatResponse.from_updates``. Non-streaming requests
        go to ``_get_response_impl``.
        """
        if stream:

            def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
                return ChatResponse.from_updates(updates)

            return ResponseStream(self._stream_fn(messages, options, **kwargs), finalizer=_finalize)
        return self._get_response_impl(messages, options, **kwargs)

    async def _get_response_impl(
        self, messages: Sequence[Message], options: Mapping[str, Any], **kwargs: Any
    ) -> ChatResponse:
        """Non-streaming implementation.

        Uses ``response_fn`` when one was supplied. Otherwise it drains ``stream_fn``
        and packs every update's contents into a single assistant message with
        ``response_id="stub-response"``. The fallback passes ``list(messages)`` and
        ``dict(options)`` copies to ``stream_fn``, while ``response_fn`` receives the
        original objects.
        """
        if self._response_fn is not None:
            return await self._response_fn(messages, options, **kwargs)
        contents: list[Any] = []
        async for update in self._stream_fn(list(messages), dict(options), **kwargs):
            contents.extend(update.contents)
        return ChatResponse(
            messages=[Message(role="assistant", contents=contents)],
            response_id="stub-response",
        )
def stream_from_updates(updates: list[ChatResponseUpdate]) -> StreamFn:
    """Build a ``StreamFn`` that replays ``updates`` in order.

    The returned async generator ignores ``messages``, ``options`` and any extra
    keyword arguments. Each invocation starts a fresh pass over ``updates``, so one
    list can back several calls, and the same update objects are yielded every time.

    Args:
        updates: The chat updates to emit, in order.

    Returns:
        An async-generator function suitable as a ``StreamingChatClientStub`` stream_fn.
    """

    async def _replay(
        messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        for pending in updates:
            yield pending

    return _replay
class StubAgent(SupportsAgentRun):
    """Minimal SupportsAgentRun stub for orchestrator tests.

    Streaming runs replay ``updates``. Non-streaming runs return an empty
    ``AgentResponse``. Both paths record their arguments (``messages_received``,
    ``last_session``, ``tools_received``) when the run actually executes, so tests
    can inspect what the orchestrator passed in.
    """

    def __init__(
        self,
        updates: list[AgentResponseUpdate] | None = None,
        *,
        agent_id: str = "stub-agent",
        agent_name: str | None = "stub-agent",
        default_options: Any | None = None,
        client: Any | None = None,
    ) -> None:
        """Create the stub agent.

        Args:
            updates: Updates to stream. A falsy value (``None`` or an empty list) is
                replaced by a single assistant text update reading "response".
            agent_id: Value for ``self.id``.
            agent_name: Value for ``self.name``.
            default_options: Used as-is when it is a dict. Otherwise
                ``{"tools": None, "response_format": None}`` is used.
            client: Chat client exposed as ``self.client``. Defaults to a namespace
                whose ``function_invocation_configuration`` is ``None``.
        """
        self.id = agent_id
        self.name = agent_name
        self.description = "stub agent"
        self.updates = updates or [AgentResponseUpdate(contents=[Content.from_text(text="response")], role="assistant")]
        self.default_options: dict[str, Any] = (
            default_options if isinstance(default_options, dict) else {"tools": None, "response_format": None}
        )
        self.client = client or SimpleNamespace(function_invocation_configuration=None)
        self.messages_received: list[Any] = []
        self.tools_received: list[Any] | None = None
        self.last_session: AgentSession | None = None

    def _record_call(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None,
        session: AgentSession | None,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Store the arguments of the current run for later assertions."""
        if messages is None:
            self.messages_received = []
        elif isinstance(messages, (str, Content, Message)):
            # A single item must be wrapped, not iterated: list("hi") would split a
            # string into characters, and a lone Message/Content is not a sequence.
            self.messages_received = [messages]
        else:
            self.messages_received = list(messages)
        self.last_session = session
        self.tools_received = kwargs.get("tools")

    @overload
    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        stream: Literal[False] = ...,
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]]: ...

    @overload
    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        stream: Literal[True],
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

    def run(
        self,
        messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
        *,
        stream: bool = False,
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
        """Run the stub.

        With ``stream=True`` this returns a ``ResponseStream`` that yields
        ``self.updates`` and finalizes them with ``AgentResponse.from_updates``.
        Otherwise it returns a coroutine that resolves to an empty ``AgentResponse``
        with ``response_id="stub-response"``. In both cases the call is recorded once
        the stream is iterated or the coroutine is awaited.
        """
        if stream:

            async def _stream() -> AsyncIterator[AgentResponseUpdate]:
                self._record_call(messages, session, kwargs)
                for update in self.updates:
                    yield update

            def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
                return AgentResponse.from_updates(updates)

            return ResponseStream(_stream(), finalizer=_finalize)

        async def _get_response() -> AgentResponse[Any]:
            self._record_call(messages, session, kwargs)
            return AgentResponse(messages=[], response_id="stub-response")

        return _get_response()

    def create_session(self, **kwargs: Any) -> AgentSession:
        """Return a fresh default ``AgentSession``. Any ``kwargs`` are ignored.

        NOTE(review): if tests need a caller-supplied session_id honoured here, the
        kwargs would have to be forwarded; confirm against the SupportsAgentRun protocol.
        """
        return AgentSession()
# Fixtures
@pytest.fixture
def streaming_chat_client_stub() -> type[SupportsChatGetResponse]:
    """Provide the ``StreamingChatClientStub`` class itself, not an instance.

    Tests construct clients from it by passing a stream function.
    """
    stub_cls = StreamingChatClientStub
    return stub_cls  # type: ignore[return-value]
@pytest.fixture
def stream_from_updates_fixture() -> Callable[[list[ChatResponseUpdate]], StreamFn]:
    """Provide the module-level ``stream_from_updates`` helper.

    The helper turns a list of ``ChatResponseUpdate`` objects into a ``StreamFn``.
    """
    helper = stream_from_updates
    return helper
@pytest.fixture
def stub_agent() -> type[SupportsAgentRun]:
    """Provide the ``StubAgent`` class itself so tests can build configured instances."""
    agent_cls = StubAgent
    return agent_cls  # type: ignore[return-value]
# ── Fixtures for golden / integration tests ──
@pytest.fixture
def collect_events() -> Callable[..., Any]:
    """Provide a coroutine function that drains an async iterable into a list.

    Usage: ``events = await collect_events(async_iterable)``.
    """

    async def _drain(async_gen: AsyncIterable[Any]) -> list[Any]:
        drained = [item async for item in async_gen]
        return drained

    return _drain
@pytest.fixture
def make_agent_wrapper() -> Callable[..., Any]:
    """Factory that builds an AgentFrameworkAgent from a stream function.

    The factory passes ``stream_fn`` to a ``StreamingChatClientStub``, gives that
    client to a ``StubAgent``, and wraps the agent in ``AgentFrameworkAgent`` with
    the remaining keyword options.

    Usage::

        agent = make_agent_wrapper(
            stream_fn=stream_from_updates(updates),
            state_schema=...,
        )
        events = [e async for e in agent.run(payload)]
    """
    from agent_framework_ag_ui import AgentFrameworkAgent

    def _factory(
        stream_fn: StreamFn,
        *,
        state_schema: Any | None = None,
        predict_state_config: dict[str, dict[str, str]] | None = None,
        require_confirmation: bool = True,
    ) -> Any:
        wrapper_options: dict[str, Any] = {
            "state_schema": state_schema,
            "predict_state_config": predict_state_config,
            "require_confirmation": require_confirmation,
        }
        wrapped = StubAgent(client=StreamingChatClientStub(stream_fn))
        return AgentFrameworkAgent(agent=wrapped, **wrapper_options)

    return _factory
@pytest.fixture
def make_app() -> Callable[..., Any]:
    """Factory that builds a FastAPI app with an AG-UI endpoint.

    Each call creates a fresh ``FastAPI`` instance and registers ``agent`` on it
    through ``add_agent_framework_fastapi_endpoint``, forwarding the keyword options.

    Usage::

        app = make_app(agent_or_wrapper, path="/test")
    """
    from fastapi import FastAPI

    from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint

    def _factory(
        agent: Any,
        *,
        path: str = "/",
        state_schema: Any | None = None,
        predict_state_config: dict[str, dict[str, str]] | None = None,
        default_state: dict[str, Any] | None = None,
    ) -> FastAPI:
        application = FastAPI()
        endpoint_options: dict[str, Any] = dict(
            path=path,
            state_schema=state_schema,
            predict_state_config=predict_state_config,
            default_state=default_state,
        )
        add_agent_framework_fastapi_endpoint(application, agent, **endpoint_options)
        return application

    return _factory