Python: [BREAKING] update context provider APIs, middleware, and per-service-call history persistence (#4992)

* Rename provider base APIs

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Allow provider-added chat and function middleware

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Simulate service-stored history per model call

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix typing regressions in CI

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix response ID suppression review feedback

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Rename per-service-call history persistence APIs

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address context persistence review feedback

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Stabilize markdown sample docs

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Persist service continuation state per call

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-04-01 18:13:11 +02:00
committed by GitHub
Unverified
parent 38de991481
commit b065a4ce51
37 changed files with 1836 additions and 396 deletions
@@ -35,9 +35,9 @@ from agent_framework import (
AgentResponseUpdate,
AgentSession,
BaseAgent,
BaseHistoryProvider,
Content,
ContinuationToken,
HistoryProvider,
Message,
ResponseStream,
SessionContext,
@@ -353,7 +353,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
# Run before_run providers (forward order)
for provider in self.context_providers:
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
if isinstance(provider, HistoryProvider) and not provider.load_messages:
continue
if session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
+2 -2
View File
@@ -24,8 +24,8 @@ from agent_framework import (
AgentResponse,
AgentResponseUpdate,
AgentSession,
BaseContextProvider,
Content,
ContextProvider,
Message,
SessionContext,
)
@@ -869,7 +869,7 @@ async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2A
# region Context Provider Tests
class TrackingContextProvider(BaseContextProvider):
class TrackingContextProvider(ContextProvider):
"""A context provider that records when before_run and after_run are called."""
def __init__(self) -> None:
@@ -1,9 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
"""New-pattern Azure AI Search context provider using BaseContextProvider.
"""New-pattern Azure AI Search context provider using ContextProvider.
This module provides ``AzureAISearchContextProvider``, built on the new
:class:`BaseContextProvider` hooks pattern.
:class:`ContextProvider` hooks pattern.
"""
from __future__ import annotations
@@ -17,8 +17,8 @@ from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
AgentSession,
Annotation,
BaseContextProvider,
Content,
ContextProvider,
Message,
SecretString,
SessionContext,
@@ -154,8 +154,8 @@ class AzureAISearchSettings(TypedDict, total=False):
api_key: SecretString | None
class AzureAISearchContextProvider(BaseContextProvider):
"""Azure AI Search context provider using the new BaseContextProvider hooks pattern.
class AzureAISearchContextProvider(ContextProvider):
"""Azure AI Search context provider using the new ContextProvider hooks pattern.
Retrieves relevant context from Azure AI Search using semantic or agentic search
modes.
@@ -11,7 +11,7 @@ from collections.abc import Sequence
from typing import Any, ClassVar, TypedDict
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
from agent_framework._sessions import BaseHistoryProvider
from agent_framework._sessions import HistoryProvider
from agent_framework._settings import SecretString, load_settings
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
@@ -32,8 +32,8 @@ class AzureCosmosHistorySettings(TypedDict, total=False):
key: SecretString | None
class CosmosHistoryProvider(BaseHistoryProvider):
"""Azure Cosmos DB-backed history provider using BaseHistoryProvider hooks."""
class CosmosHistoryProvider(HistoryProvider):
"""Azure Cosmos DB-backed history provider using HistoryProvider hooks."""
DEFAULT_SOURCE_ID: ClassVar[str] = "azure_cosmos_history"
_BATCH_OPERATION_LIMIT: ClassVar[int] = 100
@@ -16,8 +16,8 @@ from agent_framework import (
AgentRunInputs,
AgentSession,
BaseAgent,
BaseContextProvider,
Content,
ContextProvider,
FunctionTool,
Message,
ResponseStream,
@@ -223,7 +223,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
id: str | None = None,
name: str | None = None,
description: str | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
context_providers: Sequence[ContextProvider] | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | str | Sequence[ToolTypes | Callable[..., Any] | str] | None = None,
default_options: OptionsT | MutableMapping[str, Any] | None = None,
@@ -11,8 +11,8 @@ from agent_framework import (
AgentResponseUpdate,
AgentSession,
BaseAgent,
BaseContextProvider,
Content,
ContextProvider,
Message,
ResponseStream,
normalize_messages,
@@ -60,7 +60,7 @@ class CopilotStudioAgent(BaseAgent):
id: str | None = None,
name: str | None = None,
description: str | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
context_providers: Sequence[ContextProvider] | None = None,
middleware: list[AgentMiddlewareTypes] | None = None,
environment_id: str | None = None,
agent_identifier: str | None = None,
+3 -3
View File
@@ -61,8 +61,8 @@ agent_framework/
- **`AgentSession`** - Manages conversation state and session metadata
- **`SessionContext`** - Context object for session-scoped data during agent runs
- **`BaseContextProvider`** - Base class for context providers (RAG, memory systems)
- **`BaseHistoryProvider`** - Base class for conversation history storage
- **`ContextProvider`** - Base class for context providers (RAG, memory systems)
- **`HistoryProvider`** - Base class for conversation history storage
### Skills (`_skills.py`)
@@ -70,7 +70,7 @@ agent_framework/
- **`SkillResource`** - Named supplementary content attached to a skill; holds either static `content` or a dynamic `function` (sync or async). Exactly one must be provided.
- **`SkillScript`** - An executable script attached to a skill; holds either an inline `function` (code-defined, runs in-process) or a `path` to a file on disk (file-based, delegated to a runner). Exactly one must be provided.
- **`SkillScriptRunner`** - Protocol for file-based script execution. Any callable matching `(skill, script, args) -> Any` satisfies it. Code-defined scripts do not use a runner.
- **`SkillsProvider`** - Context provider (extends `BaseContextProvider`) that discovers file-based skills from `SKILL.md` files and/or accepts code-defined `Skill` instances. Follows progressive disclosure: advertise → load → read resources / run scripts.
- **`SkillsProvider`** - Context provider (extends `ContextProvider`) that discovers file-based skills from `SKILL.md` files and/or accepts code-defined `Skill` instances. Follows progressive disclosure: advertise → load → read resources / run scripts.
### Workflows (`_workflows/`)
@@ -102,8 +102,10 @@ from ._middleware import (
)
from ._sessions import (
AgentSession,
BaseContextProvider,
BaseHistoryProvider,
BaseContextProvider, # type: ignore[reportDeprecated]
BaseHistoryProvider, # type: ignore[reportDeprecated]
ContextProvider,
HistoryProvider,
InMemoryHistoryProvider,
SessionContext,
register_state_type,
@@ -296,6 +298,7 @@ __all__ = [
"CompactionProvider",
"CompactionStrategy",
"Content",
"ContextProvider",
"ContinuationToken",
"ConversationSplit",
"ConversationSplitter",
@@ -331,6 +334,7 @@ __all__ = [
"FunctionTool",
"GeneratedEmbeddings",
"GraphConnectivityError",
"HistoryProvider",
"InMemoryCheckpointStorage",
"InMemoryHistoryProvider",
"InProcRunnerContext",
+275 -137
View File
@@ -29,14 +29,16 @@ from . import _tools as _tool_utils # pyright: ignore[reportPrivateUsage]
from ._clients import BaseChatClient, SupportsChatGetResponse
from ._docstrings import apply_layered_docstring
from ._mcp import LOG_LEVEL_MAPPING, MCPTool
from ._middleware import AgentMiddlewareLayer, FunctionInvocationContext, MiddlewareTypes
from ._middleware import AgentMiddlewareLayer, FunctionInvocationContext, MiddlewareTypes, categorize_middleware
from ._serialization import SerializationMixin
from ._sessions import (
AgentSession,
BaseContextProvider,
BaseHistoryProvider,
ContextProvider,
HistoryProvider,
InMemoryHistoryProvider,
PerServiceCallHistoryPersistingMiddleware,
SessionContext,
is_local_history_conversation_id,
)
from ._tools import FunctionInvocationLayer, FunctionTool, ToolTypes, normalize_tools
from ._types import (
@@ -50,7 +52,7 @@ from ._types import (
map_chat_to_agent_update,
normalize_messages,
)
from .exceptions import AgentInvalidResponseException, UserInputRequiredException
from .exceptions import AgentInvalidRequestException, AgentInvalidResponseException, UserInputRequiredException
from .observability import AgentTelemetryLayer
if sys.version_info >= (3, 13):
@@ -166,6 +168,7 @@ class _RunContext(TypedDict):
input_messages: Sequence[Message]
session_messages: Sequence[Message]
agent_name: str
suppress_response_id: bool
chat_options: MutableMapping[str, Any]
compaction_strategy: CompactionStrategy | None
tokenizer: TokenizerProtocol | None
@@ -366,6 +369,7 @@ class BaseAgent(SerializationMixin):
"""
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"additional_properties"}
require_per_service_call_history_persistence: bool = False
def __init__(
self,
@@ -373,7 +377,7 @@ class BaseAgent(SerializationMixin):
id: str | None = None,
name: str | None = None,
description: str | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
context_providers: Sequence[ContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
additional_properties: MutableMapping[str, Any] | None = None,
) -> None:
@@ -393,7 +397,7 @@ class BaseAgent(SerializationMixin):
self.id = id
self.name = name
self.description = description
self.context_providers: list[BaseContextProvider] = list(context_providers or [])
self.context_providers: list[ContextProvider] = list(context_providers or [])
self.middleware: list[MiddlewareTypes] | None = (
cast(list[MiddlewareTypes], middleware) if middleware is not None else None
)
@@ -455,7 +459,12 @@ class BaseAgent(SerializationMixin):
if provider_session is None and self.context_providers:
provider_session = AgentSession()
per_service_call_history_required = self.require_per_service_call_history_persistence and any(
isinstance(provider, HistoryProvider) for provider in self.context_providers
)
for provider in reversed(self.context_providers):
if per_service_call_history_required and isinstance(provider, HistoryProvider):
continue
if provider_session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
await provider.after_run(
@@ -656,8 +665,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
description: str | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
default_options: OptionsCoT | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
context_providers: Sequence[ContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
require_per_service_call_history_persistence: bool = False,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: MutableMapping[str, Any] | None = None,
@@ -675,6 +685,11 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
description: A brief description of the agent's purpose.
context_providers: Context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
require_per_service_call_history_persistence: When True, history providers are invoked
around each model call instead of once per ``run()`` when the service
is not already storing history. If service-side storage is active for
the run, the agent skips local history providers and relies on the
service-managed conversation instead.
default_options: A TypedDict containing chat options. When using a typed agent like
``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for
provider-specific options including temperature, max_tokens, model_id,
@@ -706,6 +721,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
)
self.client = client
self.compaction_strategy = compaction_strategy
self.require_per_service_call_history_persistence = require_per_service_call_history_persistence
self.tokenizer = tokenizer
# Get tools from options or named parameter (named param takes precedence)
@@ -764,6 +780,35 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
await self._async_exit_stack.enter_async_context(context_manager)
return self
def _get_history_providers(self) -> list[HistoryProvider]:
return [provider for provider in self.context_providers if isinstance(provider, HistoryProvider)]
def _resolve_per_service_call_history_providers(
self,
*,
session: AgentSession | None,
options: Mapping[str, Any] | None,
service_stores_history: bool,
) -> list[HistoryProvider]:
history_providers = self._get_history_providers()
if not self.require_per_service_call_history_persistence or not history_providers:
return []
conversation_id = (
session.service_session_id
if session and session.service_session_id
else cast(str | None, (options or {}).get("conversation_id") or self.default_options.get("conversation_id"))
)
if service_stores_history:
return []
if conversation_id is not None:
raise AgentInvalidRequestException(
"require_per_service_call_history_persistence cannot be used "
"with an existing service-managed conversation."
)
return history_providers
async def __aexit__(
self,
exc_type: type[BaseException] | None,
@@ -885,97 +930,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
When stream=True: A ResponseStream of AgentResponseUpdate items with
``get_final_response()`` for the final AgentResponse.
"""
if not stream:
async def _run_non_streaming() -> AgentResponse[Any]:
ctx = await self._prepare_run_context(
messages=messages,
session=session,
tools=tools,
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
response = cast(
ChatResponse[Any],
await self.client.get_response( # type: ignore
messages=ctx["session_messages"],
stream=False,
options=ctx["chat_options"], # type: ignore[reportArgumentType]
compaction_strategy=ctx["compaction_strategy"],
tokenizer=ctx["tokenizer"],
function_invocation_kwargs=ctx["function_invocation_kwargs"],
client_kwargs=ctx["client_kwargs"],
),
)
if not response:
raise AgentInvalidResponseException("Chat client did not return a response.")
await self._finalize_response(
response=response,
agent_name=ctx["agent_name"],
session=ctx["session"],
session_context=ctx["session_context"],
)
response_format = ctx["chat_options"].get("response_format")
if not (
response_format is not None
and isinstance(response_format, type)
and issubclass(response_format, BaseModel)
):
response_format = None
return AgentResponse(
messages=response.messages,
response_id=response.response_id,
created_at=response.created_at,
usage_details=response.usage_details,
value=response.value,
response_format=response_format,
continuation_token=response.continuation_token,
raw_representation=response,
additional_properties=response.additional_properties,
)
return _run_non_streaming()
# Use a holder to capture the context created during stream initialization
ctx_holder: dict[str, _RunContext | None] = {"ctx": None}
async def _post_hook(response: AgentResponse) -> None:
ctx = ctx_holder["ctx"]
if ctx is None:
return # No context available (shouldn't happen in normal flow)
# Update thread with conversation_id derived from streaming raw updates.
# Using response_id here can break function-call continuation for APIs
# where response IDs are not valid conversation handles.
conversation_id = self._extract_conversation_id_from_streaming_response(response)
# Ensure author names are set for all messages
for message in response.messages:
if message.author_name is None:
message.author_name = ctx["agent_name"]
# Propagate conversation_id back to session from streaming updates.
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
# so refresh when a newer value is returned.
sess = ctx["session"]
if sess and conversation_id and sess.service_session_id != conversation_id:
sess.service_session_id = conversation_id
# Run after_run providers (reverse order)
session_context = ctx["session_context"]
session_context._response = AgentResponse( # type: ignore[assignment]
messages=response.messages,
response_id=response.response_id,
)
await self._run_after_providers(session=ctx["session"], context=session_context)
async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
ctx_holder["ctx"] = await self._prepare_run_context(
async def _prepare_run_context() -> _RunContext:
return await self._prepare_run_context(
messages=messages,
session=session,
tools=tools,
@@ -985,55 +942,177 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it
if not stream:
async def _run_non_streaming() -> AgentResponse[Any]:
ctx = await _prepare_run_context()
response = await self._call_chat_client(ctx, stream=False)
return await self._parse_non_streaming_response(ctx, response)
return _run_non_streaming()
async def _run_streaming() -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
ctx = await _prepare_run_context()
stream_response = self._call_chat_client(ctx, stream=True)
return self._parse_streaming_response(ctx, stream_response)
return cast(
ResponseStream[AgentResponseUpdate, AgentResponse[Any]],
cast(Any, ResponseStream).from_awaitable(_run_streaming()),
)
@overload
def _call_chat_client(
self,
context: _RunContext,
*,
stream: Literal[False],
) -> Awaitable[ChatResponse[Any]]: ...
@overload
def _call_chat_client(
self,
context: _RunContext,
*,
stream: Literal[True],
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
def _call_chat_client(
self,
context: _RunContext,
*,
stream: bool,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Invoke the downstream chat client for a prepared run context."""
if stream:
return self.client.get_response( # type: ignore[call-overload, no-any-return]
messages=ctx["session_messages"],
messages=context["session_messages"],
stream=True,
options=ctx["chat_options"], # type: ignore[reportArgumentType]
compaction_strategy=ctx["compaction_strategy"],
tokenizer=ctx["tokenizer"],
function_invocation_kwargs=ctx["function_invocation_kwargs"],
client_kwargs=ctx["client_kwargs"],
options=context["chat_options"], # type: ignore[reportArgumentType]
compaction_strategy=context["compaction_strategy"],
tokenizer=context["tokenizer"],
function_invocation_kwargs=context["function_invocation_kwargs"],
client_kwargs=context["client_kwargs"],
)
def _propagate_conversation_id(
update: AgentResponseUpdate,
) -> AgentResponseUpdate:
"""Eagerly propagate conversation_id to session as updates arrive.
return self.client.get_response( # type: ignore[call-overload, no-any-return]
messages=context["session_messages"],
stream=False,
options=context["chat_options"], # type: ignore[reportArgumentType]
compaction_strategy=context["compaction_strategy"],
tokenizer=context["tokenizer"],
function_invocation_kwargs=context["function_invocation_kwargs"],
client_kwargs=context["client_kwargs"],
)
This ensures session.service_session_id is set even when the user
only iterates the stream without calling get_final_response().
"""
async def _parse_non_streaming_response(
self,
context: _RunContext,
response: ChatResponse[Any],
) -> AgentResponse[Any]:
"""Finalize a non-streaming chat response into an AgentResponse."""
if not response:
raise AgentInvalidResponseException("Chat client did not return a response.")
await self._finalize_response(
response=response,
agent_name=context["agent_name"],
session=context["session"],
session_context=context["session_context"],
suppress_response_id=context["suppress_response_id"],
)
response_format = context["chat_options"].get("response_format")
if not (
response_format is not None and isinstance(response_format, type) and issubclass(response_format, BaseModel)
):
response_format = None
return AgentResponse(
messages=response.messages,
response_id=None if context["suppress_response_id"] else response.response_id,
created_at=response.created_at,
usage_details=response.usage_details,
value=response.value,
response_format=response_format,
continuation_token=response.continuation_token,
raw_representation=response,
additional_properties=response.additional_properties,
)
def _parse_streaming_response(
self,
context: _RunContext,
stream_response: ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Finalize a streaming chat response into an agent response stream."""
async def _post_hook(response: AgentResponse) -> None:
# Update thread with conversation_id derived from streaming raw updates.
# Using response_id here can break function-call continuation for APIs
# where response IDs are not valid conversation handles.
conversation_id = self._extract_conversation_id_from_streaming_response(response)
for message in response.messages:
if message.author_name is None:
message.author_name = context["agent_name"]
session = context["session"]
if (
session
and conversation_id
and not is_local_history_conversation_id(conversation_id)
and session.service_session_id != conversation_id
):
session.service_session_id = conversation_id
suppress_response_id = context["suppress_response_id"]
session_context = context["session_context"]
session_context._response = AgentResponse( # type: ignore[assignment]
messages=response.messages,
response_id=None if suppress_response_id else response.response_id,
)
await self._run_after_providers(session=session, context=session_context)
def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpdate:
"""Eagerly propagate conversation_id to session as updates arrive."""
session = context["session"]
if session is None:
return update
raw = update.raw_representation
conv_id = getattr(raw, "conversation_id", None) if raw else None
if isinstance(conv_id, str) and conv_id and session.service_session_id != conv_id:
session.service_session_id = conv_id
conversation_id = getattr(raw, "conversation_id", None) if raw else None
if (
isinstance(conversation_id, str)
and conversation_id
and not is_local_history_conversation_id(conversation_id)
and session.service_session_id != conversation_id
):
session.service_session_id = conversation_id
return update
def _suppress_response_id(update: AgentResponseUpdate) -> AgentResponseUpdate:
"""Hide raw service response ids when local per-service-call persistence owns continuation."""
update.response_id = None
return update
def _finalizer(updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]:
ctx = ctx_holder["ctx"]
rf = (
ctx.get("chat_options", {}).get("response_format")
if ctx
else (options.get("response_format") if options else None) # type: ignore[union-attr]
return self._finalize_response_updates(
updates,
response_format=context["chat_options"].get("response_format"),
)
return self._finalize_response_updates(updates, response_format=rf)
return (
ResponseStream
.from_awaitable(_get_stream()) # type: ignore[reportUnknownMemberType]
.map(
transform=partial(
map_chat_to_agent_update,
agent_name=self.name,
),
finalizer=_finalizer,
)
.with_transform_hook(_propagate_conversation_id)
.with_result_hook(_post_hook)
stream = stream_response.map(
transform=partial(
map_chat_to_agent_update,
agent_name=self.name,
),
finalizer=_finalizer,
)
if context["suppress_response_id"]:
stream = stream.with_transform_hook(_suppress_response_id)
return stream.with_transform_hook(_propagate_conversation_id).with_result_hook(_post_hook)
def _finalize_response_updates(
self,
@@ -1111,6 +1190,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
if active_session is None and self.context_providers:
active_session = AgentSession()
per_service_call_history_providers = self._resolve_per_service_call_history_providers(
session=active_session,
options=opts,
service_stores_history=bool(store_),
)
session_context, chat_options = await self._prepare_session_and_messages(
session=active_session,
input_messages=input_messages,
@@ -1191,6 +1276,43 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
if active_session is not None:
effective_client_kwargs["session"] = active_session
if per_service_call_history_providers and active_session is not None:
per_service_call_history_middleware = PerServiceCallHistoryPersistingMiddleware(
agent=self,
session=active_session,
providers=per_service_call_history_providers,
)
existing_middleware = effective_client_kwargs.get("middleware")
if isinstance(existing_middleware, Sequence) and not isinstance(existing_middleware, (str, bytes)):
effective_client_kwargs["middleware"] = [per_service_call_history_middleware, *existing_middleware]
elif existing_middleware is not None:
effective_client_kwargs["middleware"] = [
per_service_call_history_middleware,
cast(MiddlewareTypes, existing_middleware),
]
else:
effective_client_kwargs["middleware"] = [per_service_call_history_middleware]
provider_middleware = session_context.get_middleware()
if provider_middleware:
middleware_list = categorize_middleware(provider_middleware)
provider_function_chat_middleware = [
*middleware_list["function"],
*middleware_list["chat"],
]
if provider_function_chat_middleware:
existing_middleware = effective_client_kwargs.get("middleware")
if isinstance(existing_middleware, Sequence) and not isinstance(existing_middleware, (str, bytes)):
effective_client_kwargs["middleware"] = [
*existing_middleware,
*provider_function_chat_middleware,
]
elif existing_middleware is not None:
effective_client_kwargs["middleware"] = [
cast(MiddlewareTypes, existing_middleware),
*provider_function_chat_middleware,
]
else:
effective_client_kwargs["middleware"] = provider_function_chat_middleware
return {
"session": active_session,
@@ -1198,6 +1320,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
"input_messages": input_messages,
"session_messages": session_messages,
"agent_name": agent_name,
"suppress_response_id": bool(per_service_call_history_providers),
"chat_options": co,
"compaction_strategy": compaction_strategy or self.compaction_strategy,
"tokenizer": tokenizer or self.tokenizer,
@@ -1211,6 +1334,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
agent_name: str,
session: AgentSession | None,
session_context: SessionContext,
suppress_response_id: bool = False,
) -> None:
"""Finalize response by setting author names and running after_run providers.
@@ -1219,6 +1343,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
agent_name: The name of the agent to set as author.
session: The conversation session.
session_context: The invocation context.
suppress_response_id: When True, omit the raw service response ID from the public response.
"""
# Ensure that the author name is set for each message in the response.
for message in response.messages:
@@ -1228,13 +1353,18 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
# Propagate conversation_id back to session (e.g. thread ID from Assistants API).
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
# so refresh when a newer value is returned.
if session and response.conversation_id and session.service_session_id != response.conversation_id:
if (
session
and response.conversation_id
and not is_local_history_conversation_id(response.conversation_id)
and session.service_session_id != response.conversation_id
):
session.service_session_id = response.conversation_id
# Set the response on the context for after_run providers
session_context._response = AgentResponse( # type: ignore[assignment]
messages=response.messages,
response_id=response.response_id,
response_id=None if suppress_response_id else response.response_id,
)
# Run after_run providers (reverse order)
@@ -1284,9 +1414,15 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options=options or {},
)
# Run before_run providers (forward order, skip BaseHistoryProvider with load_messages=False)
per_service_call_history_required = self.require_per_service_call_history_persistence and bool(
self._get_history_providers()
)
# Run before_run providers (forward order, skip HistoryProvider when per-service-call persistence owns history)
for provider in self.context_providers:
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
if per_service_call_history_required and isinstance(provider, HistoryProvider):
continue
if isinstance(provider, HistoryProvider) and not provider.load_messages:
continue
if provider_session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
@@ -1551,8 +1687,9 @@ class Agent(
description: str | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
default_options: OptionsCoT | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
context_providers: Sequence[ContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
require_per_service_call_history_persistence: bool = False,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: MutableMapping[str, Any] | None = None,
@@ -1568,6 +1705,7 @@ class Agent(
default_options=default_options,
context_providers=context_providers,
middleware=middleware,
require_per_service_call_history_persistence=require_per_service_call_history_persistence,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=additional_properties,
@@ -572,6 +572,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
default_options: OptionsCoT | Mapping[str, Any] | None = None,
context_providers: Sequence[Any] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
require_per_service_call_history_persistence: bool = False,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
@@ -596,6 +597,10 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
and dict literals are accepted without specialized option typing.
context_providers: Context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
require_per_service_call_history_persistence: Whether to require per-service-call
chat history persistence. When enabled, history providers are invoked around
each model call instead of once per ``run()`` when the service is not already
storing history.
function_invocation_configuration: Optional function invocation configuration override.
compaction_strategy: Optional agent-level compaction override. When omitted,
client-level compaction defaults remain in effect for each call.
@@ -636,6 +641,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
"default_options": cast(Any, default_options),
"context_providers": context_providers,
"middleware": middleware,
"require_per_service_call_history_persistence": require_per_service_call_history_persistence,
"compaction_strategy": compaction_strategy,
"tokenizer": tokenizer,
"additional_properties": dict(additional_properties) if additional_properties is not None else None,
@@ -15,7 +15,7 @@ from typing import (
runtime_checkable,
)
from ._sessions import BaseContextProvider
from ._sessions import ContextProvider
from ._types import ChatResponse, Content, Message
if TYPE_CHECKING:
@@ -1152,7 +1152,7 @@ async def apply_compaction(
COMPACTION_STATE_KEY: Final[str] = "_compaction_messages"
class CompactionProvider(BaseContextProvider):
class CompactionProvider(ContextProvider):
"""Context provider that compacts messages before and after agent runs.
This provider accepts two separate strategies:
+274 -19
View File
@@ -4,8 +4,8 @@
This module provides the core types for the context provider pipeline:
- SessionContext: Per-invocation state passed through providers
- BaseContextProvider: Base class for context providers (renamed to ContextProvider in PR2)
- BaseHistoryProvider: Base class for history storage providers (renamed to HistoryProvider in PR2)
- ContextProvider: Base class for context providers
- HistoryProvider: Base class for history storage providers
- AgentSession: Lightweight session state container
- InMemoryHistoryProvider: Built-in in-memory history provider
"""
@@ -13,21 +13,42 @@ This module provides the core types for the context provider pipeline:
from __future__ import annotations
import copy
import sys
import uuid
from abc import abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, ClassVar, cast
from collections.abc import Awaitable, Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, cast
from ._types import AgentResponse, Message
if sys.version_info >= (3, 13):
from warnings import deprecated # type: ignore # pragma: no cover
else:
from typing_extensions import deprecated # type: ignore # pragma: no cover
from ._middleware import ChatContext, ChatMiddleware
from ._types import AgentResponse, ChatResponse, Message, ResponseStream
from .exceptions import ChatClientInvalidResponseException
if TYPE_CHECKING:
from ._agents import SupportsAgentRun
from ._middleware import MiddlewareTypes
# Registry of known types for state deserialization
_STATE_TYPE_REGISTRY: dict[str, type] = {}
def _is_middleware_sequence(
middleware: MiddlewareTypes | Sequence[MiddlewareTypes],
) -> TypeGuard[Sequence[MiddlewareTypes]]:
return isinstance(middleware, Sequence) and not isinstance(middleware, (str, bytes))
def _is_single_middleware(
middleware: MiddlewareTypes | Sequence[MiddlewareTypes],
) -> TypeGuard[MiddlewareTypes]:
return not _is_middleware_sequence(middleware)
def register_state_type(cls: type) -> None:
"""Register a type for automatic deserialization in session state.
@@ -131,6 +152,8 @@ class SessionContext:
Maintains insertion order (provider execution order).
instructions: Additional instructions added by providers.
tools: Additional tools added by providers.
middleware: Dict mapping source_id -> chat/function middleware added by that provider.
Maintains insertion order (provider execution order).
response: After invocation, contains the full AgentResponse, should not be changed.
options: Options passed to agent.run() - read-only, for reflection only.
metadata: Shared metadata dictionary for cross-provider communication.
@@ -145,6 +168,7 @@ class SessionContext:
context_messages: dict[str, list[Message]] | None = None,
instructions: list[str] | None = None,
tools: list[Any] | None = None,
middleware: dict[str, list[MiddlewareTypes]] | None = None,
options: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
):
@@ -157,6 +181,7 @@ class SessionContext:
context_messages: Pre-populated context messages by source.
instructions: Pre-populated instructions.
tools: Pre-populated tools.
middleware: Pre-populated chat/function middleware by source.
options: Options from agent.run() - read-only for providers.
metadata: Shared metadata for cross-provider communication.
"""
@@ -166,6 +191,10 @@ class SessionContext:
self.context_messages: dict[str, list[Message]] = context_messages or {}
self.instructions: list[str] = instructions or []
self.tools: list[Any] = tools or []
self.middleware: dict[str, list[MiddlewareTypes]] = {}
if middleware:
for source_id, provider_middleware in middleware.items():
self.extend_middleware(source_id, provider_middleware)
self._response: AgentResponse | None = None
self.options: dict[str, Any] = options or {}
self.metadata: dict[str, Any] = metadata or {}
@@ -236,6 +265,40 @@ class SessionContext:
additional_properties["context_source"] = source_id
self.tools.extend(tools)
def extend_middleware(
self,
source_id: str,
middleware: MiddlewareTypes | Sequence[MiddlewareTypes],
) -> None:
"""Add middleware to be applied for this invocation.
Args:
source_id: The provider source_id adding this middleware.
middleware: A single chat/function middleware object/callable or sequence of middleware.
"""
from ._middleware import categorize_middleware
from .exceptions import MiddlewareException
if _is_middleware_sequence(middleware):
middleware_items = list(middleware)
elif _is_single_middleware(middleware):
middleware_items = [middleware]
else:
raise TypeError("middleware must be a middleware object or a sequence of middleware objects.")
middleware_list = categorize_middleware(middleware_items)
if middleware_list["agent"]:
raise MiddlewareException("Context providers may only add chat or function middleware.")
if source_id not in self.middleware:
self.middleware[source_id] = []
self.middleware[source_id].extend(middleware_items)
def get_middleware(self) -> list[MiddlewareTypes]:
"""Get provider-added chat/function middleware in provider execution order."""
result: list[MiddlewareTypes] = []
for middleware_items in self.middleware.values():
result.extend(middleware_items)
return result
def get_messages(
self,
*,
@@ -272,17 +335,12 @@ class SessionContext:
return result
class BaseContextProvider:
"""Base class for context providers (hooks pattern).
class ContextProvider:
"""Base class for context providers.
Context providers participate in the context engineering pipeline,
adding context before model invocation and processing responses after.
Note:
This class uses a temporary name prefixed with ``_`` to avoid collision
with the existing ``ContextProvider`` in ``_memory.py``. It will be
renamed to ``ContextProvider`` in PR2 when the old class is removed.
Attributes:
source_id: Unique identifier for this provider instance (required).
Used for message/tool attribution so other providers can filter.
@@ -312,7 +370,7 @@ class BaseContextProvider:
Args:
agent: The agent running this invocation.
session: The current session.
context: The invocation context - add messages/instructions/tools here.
context: The invocation context - add messages/instructions/tools/chat/function middleware here.
state: The provider-scoped mutable state dict for this provider.
Full cross-provider state remains available at ``session.state``.
"""
@@ -339,7 +397,7 @@ class BaseContextProvider:
"""
class BaseHistoryProvider(BaseContextProvider):
class HistoryProvider(ContextProvider):
"""Base class for conversation history storage providers.
A single class configurable for different use cases:
@@ -347,10 +405,6 @@ class BaseHistoryProvider(BaseContextProvider):
- Audit/logging storage (stores only, doesn't load)
- Evaluation storage (stores only for later analysis)
Note:
This class uses a temporary name prefixed with ``_`` to avoid collision
with existing types. It will be renamed to ``HistoryProvider`` in PR2.
Subclasses only need to implement ``get_messages()`` and ``save_messages()``.
The default ``before_run``/``after_run`` handle loading and storing based on
configuration flags. Override them for custom behavior.
@@ -467,6 +521,207 @@ class BaseHistoryProvider(BaseContextProvider):
await self.save_messages(context.session_id, messages_to_store, state=state)
LOCAL_HISTORY_CONVERSATION_ID = "agent_framework_local_history_persistence"
def is_local_history_conversation_id(conversation_id: str | None) -> bool:
"""Return whether a conversation id is the local history-persistence sentinel."""
return conversation_id == LOCAL_HISTORY_CONVERSATION_ID
def _response_contains_follow_up_request(response: ChatResponse) -> bool:
"""Return whether a response requires another model call in the current run."""
return any(
item.type in {"function_call", "function_approval_request"}
for message in response.messages
for item in message.contents
)
def _split_service_call_messages(messages: Sequence[Message]) -> tuple[list[Message], dict[str, list[Message]]]:
"""Split service-call messages into input messages and attributed context messages."""
input_messages: list[Message] = []
context_messages: dict[str, list[Message]] = {}
for message in messages:
attribution = message.additional_properties.get("_attribution")
if isinstance(attribution, Mapping):
attribution_mapping = cast(Mapping[str, Any], attribution)
source_id = attribution_mapping.get("source_id")
if isinstance(source_id, str):
context_messages.setdefault(source_id, []).append(message)
continue
input_messages.append(message)
return input_messages, context_messages
class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
"""Persist local chat history after each service call when history is framework-managed.
This middleware runs around each model call when
``require_per_service_call_history_persistence`` is enabled. It loads history providers
before the model call, persists them after the model call, and uses a local
sentinel conversation id so the function loop follows the existing
service-managed branch without forwarding that sentinel to the leaf client.
"""
def __init__(
self,
*,
agent: SupportsAgentRun,
session: AgentSession,
providers: Sequence[HistoryProvider],
) -> None:
"""Initialize the middleware.
Args:
agent: The agent that owns the history providers.
session: The active session for the current run.
providers: The history providers participating in per-service-call persistence.
"""
self._agent = agent
self._session = session
self._providers = list(providers)
async def _prepare_service_call_context(self, messages: Sequence[Message]) -> SessionContext:
"""Create a per-call SessionContext and load history providers into it."""
input_messages, context_messages = _split_service_call_messages(messages)
service_call_context = SessionContext(
session_id=self._session.session_id,
service_session_id=None,
input_messages=list(input_messages),
)
for source_id, source_messages in context_messages.items():
service_call_context.extend_messages(source_id, source_messages)
for provider in self._providers:
if not provider.load_messages:
continue
await provider.before_run(
agent=self._agent,
session=self._session,
context=service_call_context,
state=self._session.state.setdefault(provider.source_id, {}),
)
return service_call_context
async def _persist_service_call_response(
self,
*,
service_call_context: SessionContext,
response: ChatResponse,
) -> None:
"""Persist a single model-call response through the configured history providers."""
service_call_context._response = AgentResponse( # type: ignore[assignment]
messages=response.messages,
response_id=None,
)
for provider in reversed(self._providers):
await provider.after_run(
agent=self._agent,
session=self._session,
context=service_call_context,
state=self._session.state.setdefault(provider.source_id, {}),
)
def _strip_local_conversation_id(self, context: ChatContext) -> None:
"""Remove the local sentinel before the leaf chat client is invoked."""
if is_local_history_conversation_id(cast(str | None, context.kwargs.get("conversation_id"))):
context.kwargs.pop("conversation_id", None)
if context.options is None:
return
mutable_options = dict(context.options)
if is_local_history_conversation_id(cast(str | None, mutable_options.get("conversation_id"))):
mutable_options.pop("conversation_id", None)
context.options = mutable_options
async def _finalize_response(
self,
*,
service_call_context: SessionContext,
response: ChatResponse,
) -> ChatResponse:
"""Persist a model response and apply the local follow-up sentinel when needed."""
if response.conversation_id is not None and not is_local_history_conversation_id(response.conversation_id):
raise ChatClientInvalidResponseException(
"require_per_service_call_history_persistence cannot be used "
"when the chat client returns a real conversation_id."
)
await self._persist_service_call_response(
service_call_context=service_call_context,
response=response,
)
if _response_contains_follow_up_request(response):
response.mark_internal_conversation_id()
response.conversation_id = LOCAL_HISTORY_CONVERSATION_ID
return response
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
"""Load and persist history providers around a single model call.
Args:
context: The chat invocation context for the current model call.
call_next: The next middleware or the leaf chat client.
Raises:
ChatClientInvalidResponseException: If the leaf client returns a real
service-managed conversation id while local per-service-call persistence is enabled.
ValueError: If the downstream middleware contract returns the wrong
result type for streaming or non-streaming execution.
"""
service_call_context = await self._prepare_service_call_context(context.messages)
context.messages = service_call_context.get_messages(include_input=True)
self._strip_local_conversation_id(context)
await call_next()
if context.result is None:
return
if context.stream:
if not isinstance(context.result, ResponseStream):
raise ValueError("Streaming chat middleware requires a ResponseStream result.")
context.result = context.result.with_result_hook(
lambda response: self._finalize_response(
service_call_context=service_call_context,
response=response,
)
)
return
if isinstance(context.result, ResponseStream):
raise ValueError("Non-streaming chat middleware requires a ChatResponse result.")
context.result = await self._finalize_response(
service_call_context=service_call_context,
response=context.result,
)
@deprecated(
"BaseContextProvider is deprecated. Use ContextProvider instead.",
category=DeprecationWarning,
)
class BaseContextProvider(ContextProvider):
"""Deprecated alias for :class:`ContextProvider`.
.. deprecated::
BaseContextProvider is deprecated. Use :class:`ContextProvider` instead.
"""
@deprecated(
"BaseHistoryProvider is deprecated. Use HistoryProvider instead.",
category=DeprecationWarning,
)
class BaseHistoryProvider(HistoryProvider):
"""Deprecated alias for :class:`HistoryProvider`.
.. deprecated::
BaseHistoryProvider is deprecated. Use :class:`HistoryProvider` instead.
"""
class AgentSession:
"""A conversation session with an agent.
@@ -535,7 +790,7 @@ class AgentSession:
return session
class InMemoryHistoryProvider(BaseHistoryProvider):
class InMemoryHistoryProvider(HistoryProvider):
"""Built-in history provider that stores messages in session.state.
Messages are stored in ``state["messages"]`` as a list of
@@ -36,7 +36,7 @@ from pathlib import Path, PurePosixPath
from typing import TYPE_CHECKING, Any, ClassVar, Final, Protocol, runtime_checkable
from ._feature_stage import ExperimentalFeature, experimental
from ._sessions import BaseContextProvider
from ._sessions import ContextProvider
from ._tools import FunctionTool
if TYPE_CHECKING:
@@ -519,7 +519,7 @@ SCRIPT_RUNNER_INSTRUCTIONS: Final[str] = (
@experimental(feature_id=ExperimentalFeature.SKILLS)
class SkillsProvider(BaseContextProvider):
class SkillsProvider(ContextProvider):
"""Context provider that advertises skills and exposes skill tools.
Supports both **file-based** skills (discovered from ``SKILL.md`` files)
+55 -5
View File
@@ -1688,6 +1688,34 @@ def _update_conversation_id(
options["conversation_id"] = conversation_id
def _update_continuation_state(
kwargs: dict[str, Any],
response: ChatResponse[Any],
*,
session: AgentSession | None,
options: dict[str, Any] | None = None,
) -> None:
"""Update in-flight and persisted continuation state from a response."""
conversation_id = response.conversation_id
if conversation_id is None:
return
_update_conversation_id(kwargs, conversation_id, options)
if (
session is not None
and not response.has_internal_conversation_id()
and session.service_session_id != conversation_id
):
session.service_session_id = conversation_id
def _clear_internal_conversation_id(response: ChatResponse[Any]) -> ChatResponse[Any]:
if response.has_internal_conversation_id():
response.conversation_id = None
response.clear_internal_conversation_id()
return response
def _extract_tools(
options: dict[str, Any] | None,
) -> ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None:
@@ -2206,9 +2234,14 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
),
)
aggregated_usage = add_usage_details(aggregated_usage, response.usage_details)
_update_continuation_state(
filtered_kwargs,
response,
session=invocation_session,
options=mutable_options,
)
if response.conversation_id is not None:
_update_conversation_id(filtered_kwargs, response.conversation_id, mutable_options)
prepped_messages = []
result = await _process_function_requests(
@@ -2223,7 +2256,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
)
if result.get("action") == "return":
response.usage_details = aggregated_usage
return response
return _clear_internal_conversation_id(response)
total_function_calls += result.get("function_call_count", 0)
if result.get("action") == "stop":
# Error threshold reached: force a final non-tool turn so
@@ -2279,11 +2312,17 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
),
)
aggregated_usage = add_usage_details(aggregated_usage, response.usage_details)
_update_continuation_state(
filtered_kwargs,
response,
session=invocation_session,
options=mutable_options,
)
response.usage_details = aggregated_usage
if fcc_messages:
for msg in reversed(fcc_messages):
response.messages.insert(0, msg)
return response
return _clear_internal_conversation_id(response)
return _get_response()
@@ -2343,6 +2382,12 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
# Get the finalized response from the inner stream
# This triggers the inner stream's finalizer and result hooks
response = await inner_stream.get_final_response()
_update_continuation_state(
filtered_kwargs,
response,
session=invocation_session,
options=mutable_options,
)
if not any(
item.type in ("function_call", "function_approval_request")
@@ -2352,7 +2397,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
return
if response.conversation_id is not None:
_update_conversation_id(filtered_kwargs, response.conversation_id, mutable_options)
prepped_messages = []
result = await _process_function_requests(
@@ -2430,7 +2474,13 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
async for update in final_inner_stream:
yield update
# Finalize the inner stream to trigger its hooks
await final_inner_stream.get_final_response()
final_response = await final_inner_stream.get_final_response()
_update_continuation_state(
filtered_kwargs,
final_response,
session=invocation_session,
options=mutable_options,
)
def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse[Any]:
# Note: stream_result_hooks are already run via inner stream's get_final_response()
@@ -2001,6 +2001,7 @@ class ChatResponse(SerializationMixin, Generic[ResponseModelT]):
"""
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"raw_representation", "additional_properties"}
_INTERNAL_CONVERSATION_ID_KEY: ClassVar[str] = "_agent_framework_internal_conversation_id"
def __init__(
self,
@@ -2069,6 +2070,18 @@ class ChatResponse(SerializationMixin, Generic[ResponseModelT]):
self.continuation_token = continuation_token
self.raw_representation: Any | list[Any] | None = raw_representation
def mark_internal_conversation_id(self) -> None:
"""Mark the current conversation_id as internal control-flow state."""
self.additional_properties[self._INTERNAL_CONVERSATION_ID_KEY] = True
def clear_internal_conversation_id(self) -> None:
"""Remove the internal conversation-id marker."""
self.additional_properties.pop(self._INTERNAL_CONVERSATION_ID_KEY, None)
def has_internal_conversation_id(self) -> bool:
"""Return whether conversation_id is internal control-flow state."""
return bool(self.additional_properties.get(self._INTERNAL_CONVERSATION_ID_KEY, False))
@property
def model_id(self) -> str | None:
"""Deprecated alias for :attr:`model`."""
@@ -14,8 +14,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
from .._agents import BaseAgent
from .._sessions import (
AgentSession,
BaseContextProvider,
BaseHistoryProvider,
ContextProvider,
HistoryProvider,
InMemoryHistoryProvider,
SessionContext,
)
@@ -86,7 +86,7 @@ class WorkflowAgent(BaseAgent):
id: str | None = None,
name: str | None = None,
description: str | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
context_providers: Sequence[ContextProvider] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the WorkflowAgent.
@@ -249,7 +249,7 @@ class WorkflowAgent(BaseAgent):
options={},
)
for provider in self.context_providers:
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
if isinstance(provider, HistoryProvider) and not provider.load_messages:
continue
if provider_session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
@@ -314,7 +314,7 @@ class WorkflowAgent(BaseAgent):
options={},
)
for provider in self.context_providers:
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
if isinstance(provider, HistoryProvider) and not provider.load_messages:
continue
if provider_session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
@@ -1502,6 +1502,161 @@ class AgentTelemetryLayer:
self.token_usage_histogram = _get_token_usage_histogram()
self.duration_histogram = _get_duration_histogram()
def _trace_agent_invocation(
self,
*,
messages: AgentRunInputs | None,
session: AgentSession | None,
merged_options: Mapping[str, Any],
client_kwargs: Mapping[str, Any] | None,
stream: bool,
execute: Callable[[], Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]],
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Trace an agent invocation while delegating execution to ``execute``."""
global OBSERVABILITY_SETTINGS
from ._types import ResponseStream
if not OBSERVABILITY_SETTINGS.ENABLED:
return execute()
provider_name = str(self.otel_provider_name)
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
attributes = _get_span_attributes(
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
provider_name=provider_name,
agent_id=getattr(self, "id", "unknown"),
agent_name=getattr(self, "name", None) or getattr(self, "id", "unknown"),
agent_description=getattr(self, "description", None),
thread_id=session.service_session_id if session else None,
all_options=dict(merged_options),
**merged_client_kwargs,
)
inner_response_telemetry_captured_fields: set[str] = set()
inner_response_telemetry_captured_fields_token = INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.set(
inner_response_telemetry_captured_fields
)
inner_accumulated_usage_token = INNER_ACCUMULATED_USAGE.set({})
if stream:
try:
run_result: object = execute()
if isinstance(run_result, ResponseStream):
result_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType]
elif isinstance(run_result, Awaitable):
result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
else:
raise RuntimeError("Streaming telemetry requires a ResponseStream result.")
except Exception:
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
raise
operation = attributes.get(OtelAttr.OPERATION, "operation")
span_name = attributes.get(OtelAttr.AGENT_NAME, "unknown")
span = get_tracer().start_span(f"{operation} {span_name}")
span.set_attributes(attributes)
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
_capture_messages(
span=span,
provider_name=provider_name,
messages=messages,
system_instructions=_get_instructions_from_options(dict(merged_options)),
)
span_state = {"closed": False}
duration_state: dict[str, float] = {}
start_time = perf_counter()
def _close_span() -> None:
if span_state["closed"]:
return
span_state["closed"] = True
span.end()
def _record_duration() -> None:
duration_state["duration"] = perf_counter() - start_time
async def _finalize_stream() -> None:
from ._types import AgentResponse
try:
response: AgentResponse[Any] = await result_stream.get_final_response()
duration = duration_state.get("duration")
response_attributes = _get_response_attributes(
attributes,
response,
capture_response_id=INNER_RESPONSE_ID_CAPTURED_FIELD
not in inner_response_telemetry_captured_fields,
capture_usage=INNER_USAGE_CAPTURED_FIELD not in inner_response_telemetry_captured_fields,
)
_apply_accumulated_usage(response_attributes, inner_response_telemetry_captured_fields)
_capture_response(span=span, attributes=response_attributes, duration=duration)
if (
OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED
and isinstance(response, AgentResponse)
and response.messages
):
_capture_messages(
span=span,
provider_name=provider_name,
messages=response.messages,
output=True,
)
except Exception as exception:
capture_exception(span=span, exception=exception, timestamp=time_ns())
finally:
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
_close_span()
wrapped_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = result_stream.with_cleanup_hook(
_record_duration
).with_cleanup_hook(_finalize_stream)
weakref.finalize(wrapped_stream, _close_span)
return wrapped_stream
async def _run() -> AgentResponse[Any]:
try:
with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span:
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
_capture_messages(
span=span,
provider_name=provider_name,
messages=messages,
system_instructions=_get_instructions_from_options(dict(merged_options)),
)
start_time_stamp = perf_counter()
try:
response: AgentResponse[Any] = await execute()
except Exception as exception:
capture_exception(span=span, exception=exception, timestamp=time_ns())
raise
duration = perf_counter() - start_time_stamp
if response:
response_attributes = _get_response_attributes(
attributes,
response,
capture_response_id=INNER_RESPONSE_ID_CAPTURED_FIELD
not in inner_response_telemetry_captured_fields,
capture_usage=INNER_USAGE_CAPTURED_FIELD not in inner_response_telemetry_captured_fields,
)
_apply_accumulated_usage(response_attributes, inner_response_telemetry_captured_fields)
_capture_response(span=span, attributes=response_attributes, duration=duration)
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages:
_capture_messages(
span=span,
provider_name=provider_name,
messages=response.messages,
output=True,
)
return response # type: ignore[return-value,no-any-return]
finally:
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
return _run()
@overload
def run(
self,
@@ -1565,14 +1720,12 @@ class AgentTelemetryLayer:
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Trace agent runs with OpenTelemetry spans and metrics."""
global OBSERVABILITY_SETTINGS
from ._types import ResponseStream, merge_chat_options
from ._types import merge_chat_options
super_run = cast(
"Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]",
super().run, # type: ignore[misc]
)
provider_name = str(self.otel_provider_name)
super_run_kwargs: dict[str, Any] = {
"messages": messages,
"stream": stream,
@@ -1586,156 +1739,21 @@ class AgentTelemetryLayer:
}
if middleware is not None:
super_run_kwargs["middleware"] = middleware
if not OBSERVABILITY_SETTINGS.ENABLED:
return super_run(**super_run_kwargs) # type: ignore[no-any-return]
default_options = dict(getattr(self, "default_options", {}))
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
merged_options: dict[str, Any] = merge_chat_options(
default_options, dict(options) if options is not None else {}
)
attributes = _get_span_attributes(
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
provider_name=provider_name,
agent_id=getattr(self, "id", "unknown"),
agent_name=getattr(self, "name", None) or getattr(self, "id", "unknown"),
agent_description=getattr(self, "description", None),
thread_id=session.service_session_id if session else None,
all_options=merged_options,
**merged_client_kwargs,
return self._trace_agent_invocation(
messages=messages,
session=session,
merged_options=merged_options,
client_kwargs=merged_client_kwargs,
stream=stream,
execute=lambda: super_run(**super_run_kwargs),
)
inner_response_telemetry_captured_fields: set[str] = set()
inner_response_telemetry_captured_fields_token = INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.set(
inner_response_telemetry_captured_fields
)
inner_accumulated_usage_token = INNER_ACCUMULATED_USAGE.set({})
if stream:
try:
run_result: object = super_run(**super_run_kwargs)
if isinstance(run_result, ResponseStream):
result_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType]
elif isinstance(run_result, Awaitable):
result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
else:
raise RuntimeError("Streaming telemetry requires a ResponseStream result.")
except Exception:
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
raise
# Create span directly without trace.use_span() context attachment.
# Streaming spans are closed asynchronously in cleanup hooks, which run
# in a different async context than creation — using use_span() would
# cause "Failed to detach context" errors from OpenTelemetry.
operation = attributes.get(OtelAttr.OPERATION, "operation")
span_name = attributes.get(OtelAttr.AGENT_NAME, "unknown")
span = get_tracer().start_span(f"{operation} {span_name}")
span.set_attributes(attributes)
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
_capture_messages(
span=span,
provider_name=provider_name,
messages=messages,
system_instructions=_get_instructions_from_options(merged_options),
)
span_state = {"closed": False}
duration_state: dict[str, float] = {}
start_time = perf_counter()
def _close_span() -> None:
if span_state["closed"]:
return
span_state["closed"] = True
span.end()
def _record_duration() -> None:
duration_state["duration"] = perf_counter() - start_time
async def _finalize_stream() -> None:
from ._types import AgentResponse
try:
response: AgentResponse[Any] = await result_stream.get_final_response()
duration = duration_state.get("duration")
response_attributes = _get_response_attributes(
attributes,
response,
capture_response_id=INNER_RESPONSE_ID_CAPTURED_FIELD
not in inner_response_telemetry_captured_fields,
capture_usage=INNER_USAGE_CAPTURED_FIELD not in inner_response_telemetry_captured_fields,
)
_apply_accumulated_usage(response_attributes, inner_response_telemetry_captured_fields)
_capture_response(span=span, attributes=response_attributes, duration=duration)
if (
OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED
and isinstance(response, AgentResponse)
and response.messages
):
_capture_messages(
span=span,
provider_name=provider_name,
messages=response.messages,
output=True,
)
except Exception as exception:
capture_exception(span=span, exception=exception, timestamp=time_ns())
finally:
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
_close_span()
# Register a weak reference callback to close the span if stream is garbage collected
# without being consumed. This ensures spans don't leak if users don't consume streams.
wrapped_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = result_stream.with_cleanup_hook(
_record_duration
).with_cleanup_hook(_finalize_stream)
weakref.finalize(wrapped_stream, _close_span)
return wrapped_stream
async def _run() -> AgentResponse:
try:
with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span:
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
_capture_messages(
span=span,
provider_name=provider_name,
messages=messages,
system_instructions=_get_instructions_from_options(merged_options),
)
start_time_stamp = perf_counter()
try:
response: AgentResponse[Any] = await super_run(**super_run_kwargs)
except Exception as exception:
capture_exception(span=span, exception=exception, timestamp=time_ns())
raise
duration = perf_counter() - start_time_stamp
if response:
response_attributes = _get_response_attributes(
attributes,
response,
capture_response_id=INNER_RESPONSE_ID_CAPTURED_FIELD
not in inner_response_telemetry_captured_fields,
capture_usage=INNER_USAGE_CAPTURED_FIELD not in inner_response_telemetry_captured_fields,
)
_apply_accumulated_usage(response_attributes, inner_response_telemetry_captured_fields)
_capture_response(span=span, attributes=response_attributes, duration=duration)
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages:
_capture_messages(
span=span,
provider_name=provider_name,
messages=response.messages,
output=True,
)
return response # type: ignore[return-value,no-any-return]
finally:
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
return _run()
# region Otel Helpers
+491 -7
View File
@@ -3,8 +3,8 @@
import contextlib
import inspect
import json
from collections.abc import AsyncIterable, MutableSequence
from typing import Any
from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence
from typing import Any, cast
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
@@ -18,22 +18,29 @@ from agent_framework import (
AgentResponse,
AgentResponseUpdate,
AgentSession,
BaseContextProvider,
ChatContext,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
Content,
ContextProvider,
FunctionTool,
HistoryProvider,
InMemoryHistoryProvider,
Message,
ResponseStream,
SessionContext,
SlidingWindowStrategy,
SupportsAgentRun,
SupportsChatGetResponse,
TruncationStrategy,
chat_middleware,
tool,
)
from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name
from agent_framework._mcp import MCPTool, _build_prefixed_mcp_name, _normalize_mcp_name
from agent_framework._middleware import FunctionInvocationContext
from agent_framework.exceptions import AgentInvalidRequestException, ChatClientInvalidResponseException
class _FixedTokenizer:
@@ -68,6 +75,49 @@ class _ConnectedMCPTool(MCPTool):
raise NotImplementedError
class _RecordingHistoryProvider(HistoryProvider):
def __init__(self, source_id: str = "recording_history") -> None:
super().__init__(source_id=source_id)
async def get_messages(
self,
session_id: str | None,
*,
state: dict[str, Any] | None = None,
**kwargs: Any,
) -> list[Message]:
if state is None:
return []
state["get_call_count"] = state.get("get_call_count", 0) + 1
return list(cast(list[Message], state.get("messages", [])))
async def save_messages(
self,
session_id: str | None,
messages: Sequence[Message],
*,
state: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
if state is None:
return
state["save_call_count"] = state.get("save_call_count", 0) + 1
state.setdefault("messages", []).extend(messages)
class _ResponseIdRecordingHistoryProvider(_RecordingHistoryProvider):
async def after_run(
self,
*,
agent: SupportsAgentRun,
session: AgentSession,
context: SessionContext,
state: dict[str, Any],
) -> None:
state.setdefault("response_ids", []).append(context.response.response_id if context.response else None)
await super().after_run(agent=agent, session=session, context=context, state=state)
def test_agent_session_type(agent_session: AgentSession) -> None:
assert isinstance(agent_session, AgentSession)
@@ -314,6 +364,413 @@ async def test_prepare_run_context_handles_function_kwargs(
assert ctx["client_kwargs"]["session"] is session
async def test_chat_agent_persists_history_per_service_call(
chat_client_base: SupportsChatGetResponse,
) -> None:
provider = _RecordingHistoryProvider()
@tool(name="lookup_weather", approval_mode="never_require")
def lookup_weather(location: str) -> str:
return f"Weather in {location}: sunny"
session = AgentSession()
session.state[provider.source_id] = {
"messages": [
Message(role="user", text="Earlier question"),
Message(role="assistant", text="Earlier answer"),
]
}
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1",
name="lookup_weather",
arguments='{"location": "Seattle"}',
)
],
),
response_id="resp_call_1",
),
ChatResponse(messages=Message(role="assistant", text="It is sunny in Seattle."), response_id="resp_call_2"),
]
agent = Agent(
client=chat_client_base,
tools=[lookup_weather],
context_providers=[provider],
require_per_service_call_history_persistence=True,
)
result = await agent.run("What's the weather in Seattle?", session=session)
provider_state = session.state[provider.source_id]
stored_messages = cast(list[Message], provider_state["messages"])
assert result.text == "It is sunny in Seattle."
assert result.response_id is None
assert chat_client_base.call_count == 2
assert provider_state["get_call_count"] == 2
assert provider_state["save_call_count"] == 2
assert stored_messages[-1].text == "It is sunny in Seattle."
assert session.service_session_id is None
async def test_chat_agent_persists_history_per_service_call_streaming(
chat_client_base: SupportsChatGetResponse,
) -> None:
provider = _RecordingHistoryProvider()
@tool(name="lookup_weather", approval_mode="never_require")
def lookup_weather(location: str) -> str:
return f"Weather in {location}: sunny"
session = AgentSession()
session.state[provider.source_id] = {
"messages": [
Message(role="user", text="Earlier question"),
Message(role="assistant", text="Earlier answer"),
]
}
chat_client_base.streaming_responses = [
[
ChatResponseUpdate(
contents=[
Content.from_function_call(
call_id="call_1",
name="lookup_weather",
arguments='{"location": "Seattle"}',
)
],
role="assistant",
finish_reason="stop",
response_id="resp_call_1",
)
],
[
ChatResponseUpdate(
contents=[Content.from_text("It is sunny in Seattle.")],
role="assistant",
finish_reason="stop",
response_id="resp_call_2",
)
],
]
agent = Agent(
client=chat_client_base,
tools=[lookup_weather],
context_providers=[provider],
require_per_service_call_history_persistence=True,
)
stream = agent.run("What's the weather in Seattle?", session=session, stream=True)
async for _ in stream:
pass
result = await stream.get_final_response()
provider_state = session.state[provider.source_id]
stored_messages = cast(list[Message], provider_state["messages"])
assert result.text == "It is sunny in Seattle."
assert result.response_id is None
assert chat_client_base.call_count == 2
assert provider_state["get_call_count"] == 2
assert provider_state["save_call_count"] == 2
assert stored_messages[-1].text == "It is sunny in Seattle."
assert session.service_session_id is None
async def test_streaming_per_service_call_persistence_hides_response_id_from_after_run(
chat_client_base: SupportsChatGetResponse,
) -> None:
provider = _ResponseIdRecordingHistoryProvider()
@tool(name="lookup_weather", approval_mode="never_require")
def lookup_weather(location: str) -> str:
return f"Weather in {location}: sunny"
session = AgentSession()
session.state[provider.source_id] = {"messages": []}
chat_client_base.streaming_responses = [
[
ChatResponseUpdate(
contents=[
Content.from_function_call(
call_id="call_1",
name="lookup_weather",
arguments='{"location": "Seattle"}',
)
],
role="assistant",
finish_reason="stop",
response_id="resp_call_1",
)
],
[
ChatResponseUpdate(
contents=[Content.from_text("It is sunny in Seattle.")],
role="assistant",
finish_reason="stop",
response_id="resp_call_2",
)
],
]
agent = Agent(
client=chat_client_base,
tools=[lookup_weather],
context_providers=[provider],
require_per_service_call_history_persistence=True,
)
stream = agent.run("What's the weather in Seattle?", session=session, stream=True)
async for _ in stream:
pass
result = await stream.get_final_response()
provider_state = session.state[provider.source_id]
assert result.response_id is None
assert provider_state["response_ids"] == [None, None]
async def test_per_service_call_persistence_uses_real_service_storage_when_client_stores_by_default(
chat_client_base: SupportsChatGetResponse,
) -> None:
provider = _RecordingHistoryProvider()
@tool(name="lookup_weather", approval_mode="never_require")
def lookup_weather(location: str) -> str:
return f"Weather in {location}: sunny"
chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined]
session = AgentSession()
session.state[provider.source_id] = {"messages": []}
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1",
name="lookup_weather",
arguments='{"location": "Seattle"}',
)
],
),
conversation_id="resp_service_managed",
response_id="resp_call_1",
),
ChatResponse(
messages=Message(role="assistant", text="It is sunny in Seattle."),
conversation_id="resp_service_managed",
response_id="resp_call_2",
),
]
agent = Agent(
client=chat_client_base,
tools=[lookup_weather],
context_providers=[provider],
require_per_service_call_history_persistence=True,
)
result = await agent.run("What's the weather in Seattle?", session=session)
provider_state = session.state[provider.source_id]
assert result.text == "It is sunny in Seattle."
assert result.response_id == "resp_call_2"
assert chat_client_base.call_count == 2
assert "get_call_count" not in provider_state
assert "save_call_count" not in provider_state
assert session.service_session_id == "resp_service_managed"
async def test_service_storage_updates_session_handle_per_service_call_before_non_streaming_failure(
chat_client_base: SupportsChatGetResponse,
) -> None:
provider = _RecordingHistoryProvider()
@tool(name="lookup_weather", approval_mode="never_require")
def lookup_weather(location: str) -> str:
return f"Weather in {location}: sunny"
chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined]
session = AgentSession()
session.state[provider.source_id] = {"messages": []}
first_response = ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1",
name="lookup_weather",
arguments='{"location": "Seattle"}',
)
],
),
conversation_id="resp_call_1",
response_id="resp_call_1",
)
mock_get_non_streaming_response = AsyncMock(
side_effect=[first_response, RuntimeError("service down")],
)
agent = Agent(
client=chat_client_base,
tools=[lookup_weather],
context_providers=[provider],
require_per_service_call_history_persistence=True,
)
with (
patch.object(chat_client_base, "_get_non_streaming_response", new=mock_get_non_streaming_response),
pytest.raises(RuntimeError, match="service down"),
):
await agent.run("What's the weather in Seattle?", session=session)
assert mock_get_non_streaming_response.await_count == 2
assert session.service_session_id == "resp_call_1"
async def test_service_storage_updates_session_handle_per_service_call_before_streaming_failure(
chat_client_base: SupportsChatGetResponse,
) -> None:
provider = _RecordingHistoryProvider()
@tool(name="lookup_weather", approval_mode="never_require")
def lookup_weather(location: str) -> str:
return f"Weather in {location}: sunny"
chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined]
session = AgentSession()
session.state[provider.source_id] = {"messages": []}
async def _first_stream_updates() -> AsyncIterable[ChatResponseUpdate]:
yield ChatResponseUpdate(
contents=[
Content.from_function_call(
call_id="call_1",
name="lookup_weather",
arguments='{"location": "Seattle"}',
)
],
role="assistant",
finish_reason="stop",
)
def _finalize_first_stream(_updates: Sequence[ChatResponseUpdate]) -> ChatResponse[Any]:
return ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1",
name="lookup_weather",
arguments='{"location": "Seattle"}',
)
],
),
conversation_id="resp_call_1",
response_id="resp_call_1",
)
first_stream = ResponseStream(_first_stream_updates(), finalizer=_finalize_first_stream)
mock_get_streaming_response = MagicMock(side_effect=[first_stream, RuntimeError("service down")])
agent = Agent(
client=chat_client_base,
tools=[lookup_weather],
context_providers=[provider],
require_per_service_call_history_persistence=True,
)
with (
patch.object(chat_client_base, "_get_streaming_response", new=mock_get_streaming_response),
pytest.raises(RuntimeError, match="service down"),
):
stream = agent.run("What's the weather in Seattle?", session=session, stream=True)
async for _ in stream:
pass
assert mock_get_streaming_response.call_count == 2
assert session.service_session_id == "resp_call_1"
async def test_chat_agent_without_per_service_call_persistence_preserves_response_id(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.run_responses = [
ChatResponse(
messages=Message(role="assistant", text="Hello"),
response_id="resp_call_1",
)
]
agent = Agent(
client=chat_client_base,
context_providers=[InMemoryHistoryProvider()],
)
result = await agent.run("Hello", session=AgentSession(), options={"store": False})
assert result.response_id == "resp_call_1"
async def test_per_service_call_persistence_rejects_real_service_conversation_id(
chat_client_base: SupportsChatGetResponse,
) -> None:
provider = _RecordingHistoryProvider()
chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined]
session = AgentSession()
session.state[provider.source_id] = {"messages": []}
chat_client_base.run_responses = [
ChatResponse(
messages=Message(role="assistant", text="Hello"),
conversation_id="resp_service_managed",
)
]
agent = Agent(
client=chat_client_base,
context_providers=[provider],
require_per_service_call_history_persistence=True,
)
with pytest.raises(
ChatClientInvalidResponseException,
match="require_per_service_call_history_persistence cannot be used",
):
await agent.run("Hello", session=session, options={"store": False})
async def test_per_service_call_persistence_rejects_existing_conversation_id_when_service_not_storing_history(
chat_client_base: SupportsChatGetResponse,
) -> None:
provider = _RecordingHistoryProvider()
session = AgentSession()
session.state[provider.source_id] = {"messages": []}
agent = Agent(
client=chat_client_base,
context_providers=[provider],
require_per_service_call_history_persistence=True,
)
with pytest.raises(
AgentInvalidRequestException,
match="require_per_service_call_history_persistence cannot be used",
):
await agent.run("Hello", session=session, options={"store": False, "conversation_id": "existing_conversation"})
async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChatGetResponse) -> None:
mock_response = ChatResponse(
messages=[Message(role="assistant", contents=[Content.from_text("test response")])],
@@ -586,7 +1043,7 @@ async def test_chat_client_agent_author_name_is_used_from_response(
# Mock context provider for testing
class MockContextProvider(BaseContextProvider):
class MockContextProvider(ContextProvider):
def __init__(self, messages: list[Message] | None = None) -> None:
super().__init__(source_id="mock")
self.context_messages = messages
@@ -1723,7 +2180,7 @@ async def test_agent_create_session_with_context_providers(
):
"""Test that create_session works when context_providers are set on the agent."""
class TestContextProvider(BaseContextProvider):
class TestContextProvider(ContextProvider):
def __init__(self):
super().__init__(source_id="test")
@@ -1798,7 +2255,7 @@ async def test_chat_agent_context_provider_adds_tools_when_agent_has_none(
"""A tool provided by context."""
return text
class ToolContextProvider(BaseContextProvider):
class ToolContextProvider(ContextProvider):
def __init__(self):
super().__init__(source_id="tool-context")
@@ -1827,7 +2284,7 @@ async def test_chat_agent_context_provider_adds_instructions_when_agent_has_none
):
"""Test that context provider instructions are used when agent has no default instructions."""
class InstructionContextProvider(BaseContextProvider):
class InstructionContextProvider(ContextProvider):
def __init__(self):
super().__init__(source_id="instruction-context")
@@ -1849,6 +2306,33 @@ async def test_chat_agent_context_provider_adds_instructions_when_agent_has_none
assert options.get("instructions") == "Context-provided instructions"
async def test_chat_agent_context_provider_adds_middleware_when_agent_has_none(
chat_client_base: SupportsChatGetResponse,
) -> None:
"""Test that context provider middleware is collected during preparation."""
@chat_middleware
async def context_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
class MiddlewareContextProvider(ContextProvider):
def __init__(self) -> None:
super().__init__(source_id="middleware-context")
async def before_run(self, *, agent, session, context, state) -> None:
context.extend_middleware("middleware-context", context_chat_middleware)
agent = Agent(client=chat_client_base, context_providers=[MiddlewareContextProvider()])
session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage]
session=None,
input_messages=[Message(role="user", text="Hello")],
)
assert session_context.middleware["middleware-context"] == [context_chat_middleware]
assert session_context.get_middleware() == [context_chat_middleware]
# region STORES_BY_DEFAULT tests
@@ -15,6 +15,7 @@ from agent_framework import (
ChatResponse,
ChatResponseUpdate,
Content,
ContextProvider,
FunctionInvocationContext,
FunctionMiddleware,
FunctionTool,
@@ -464,6 +465,31 @@ class TestChatAgentMultipleMiddlewareOrdering:
expected_order = ["class_agent_before", "function_agent_before", "function_agent_after", "class_agent_after"]
assert execution_order == expected_order
async def test_provider_added_agent_middleware_is_rejected(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test provider-added agent middleware is rejected explicitly."""
@agent_middleware
async def provider_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
class ProviderMiddlewareContextProvider(ContextProvider):
def __init__(self) -> None:
super().__init__(source_id="provider-middleware")
async def before_run(self, *, agent, session, context, state) -> None:
context.extend_middleware(self.source_id, provider_middleware)
agent = Agent(
client=chat_client_base,
context_providers=[ProviderMiddlewareContextProvider()],
)
with pytest.raises(
MiddlewareException,
match="Context providers may only add chat or function middleware",
):
await agent.run([Message(role="user", text="test message")])
# region Tool Functions for Testing
@@ -2066,6 +2092,121 @@ class TestChatAgentChatMiddleware:
"agent_middleware_after",
]
async def test_provider_added_chat_and_function_middleware_are_forwarded(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test provider-added chat and function middleware forwarding and ordering."""
execution_order: list[str] = []
@chat_middleware
async def constructor_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("constructor_chat_before")
await call_next()
execution_order.append("constructor_chat_after")
@chat_middleware
async def provider_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("provider_chat_before")
await call_next()
execution_order.append("provider_chat_after")
@chat_middleware
async def run_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("run_chat_before")
await call_next()
execution_order.append("run_chat_after")
@function_middleware
async def constructor_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("constructor_function_before")
await call_next()
execution_order.append("constructor_function_after")
@function_middleware
async def provider_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("provider_function_before")
await call_next()
execution_order.append("provider_function_after")
@function_middleware
async def run_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("run_function_before")
await call_next()
execution_order.append("run_function_after")
class ProviderMiddlewareContextProvider(ContextProvider):
def __init__(self) -> None:
super().__init__(source_id="provider-middleware")
async def before_run(self, *, agent, session, context, state) -> None:
context.extend_middleware(
self.source_id,
[
provider_chat_middleware,
provider_function_middleware,
],
)
chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_provider",
name="sample_tool_function",
arguments='{"location": "Seattle"}',
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Final response")]),
]
agent = Agent(
client=chat_client_base,
middleware=[constructor_chat_middleware, constructor_function_middleware],
context_providers=[ProviderMiddlewareContextProvider()],
tools=[sample_tool_function],
)
response = await agent.run(
[Message(role="user", text="Get weather for Seattle")],
middleware=[run_chat_middleware, run_function_middleware],
)
assert response is not None
assert chat_client_base.call_count == 2
assert response.messages[-1].text == "Final response"
assert execution_order == [
"constructor_chat_before",
"run_chat_before",
"provider_chat_before",
"provider_chat_after",
"run_chat_after",
"constructor_chat_after",
"constructor_function_before",
"run_function_before",
"provider_function_before",
"provider_function_after",
"run_function_after",
"constructor_function_after",
"constructor_chat_before",
"run_chat_before",
"provider_chat_before",
"provider_chat_after",
"run_chat_after",
"constructor_chat_after",
]
async def test_agent_middleware_can_access_and_override_options(self) -> None:
"""Test that agent middleware can access and override runtime options."""
captured_options: dict[str, Any] = {}
@@ -1,16 +1,26 @@
# Copyright (c) Microsoft. All rights reserved.
import json
from collections.abc import Sequence
from collections.abc import Awaitable, Callable, Sequence
from agent_framework import Message
from agent_framework._sessions import (
import pytest
from agent_framework import (
AgentContext,
AgentSession,
BaseContextProvider,
BaseHistoryProvider,
ChatContext,
ContextProvider,
HistoryProvider,
InMemoryHistoryProvider,
Message,
SessionContext,
agent_middleware,
chat_middleware,
)
from agent_framework._sessions import LOCAL_HISTORY_CONVERSATION_ID, is_local_history_conversation_id
from agent_framework.exceptions import MiddlewareException
# ---------------------------------------------------------------------------
# SessionContext tests
@@ -102,6 +112,50 @@ class TestSessionContext:
ctx.extend_instructions("sys", ["Be helpful", "Be concise"])
assert ctx.instructions == ["Be helpful", "Be concise"]
def test_extend_middleware_creates_key_and_appends(self) -> None:
ctx = SessionContext(input_messages=[])
@chat_middleware
async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
@chat_middleware
async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
ctx.extend_middleware("rag", first_middleware)
ctx.extend_middleware("rag", [second_middleware])
assert ctx.middleware["rag"] == [first_middleware, second_middleware]
assert ctx.get_middleware() == [first_middleware, second_middleware]
def test_extend_middleware_preserves_source_order(self) -> None:
ctx = SessionContext(input_messages=[])
@chat_middleware
async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
@chat_middleware
async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
ctx.extend_middleware("a", first_middleware)
ctx.extend_middleware("b", second_middleware)
assert list(ctx.middleware.keys()) == ["a", "b"]
assert ctx.get_middleware() == [first_middleware, second_middleware]
def test_extend_middleware_rejects_agent_middleware(self) -> None:
ctx = SessionContext(input_messages=[])
@agent_middleware
async def provider_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
with pytest.raises(MiddlewareException, match="Context providers may only add chat or function middleware"):
ctx.extend_middleware("rag", provider_agent_middleware)
def test_get_messages_all(self) -> None:
ctx = SessionContext(input_messages=[])
ctx.extend_messages("a", [Message(role="user", contents=["a"])])
@@ -154,37 +208,58 @@ class TestSessionContext:
ctx._response = resp
assert ctx.response is resp
def test_local_history_conversation_id_sentinel(self) -> None:
assert is_local_history_conversation_id(LOCAL_HISTORY_CONVERSATION_ID) is True
assert is_local_history_conversation_id("some_other_id") is False
# ---------------------------------------------------------------------------
# BaseContextProvider tests
# ContextProvider tests
# ---------------------------------------------------------------------------
class TestContextProviderBase:
class TestContextProvider:
def test_source_id_required(self) -> None:
provider = BaseContextProvider(source_id="test")
provider = ContextProvider(source_id="test")
assert provider.source_id == "test"
async def test_before_run_is_noop(self) -> None:
provider = BaseContextProvider(source_id="test")
provider = ContextProvider(source_id="test")
session = AgentSession()
ctx = SessionContext(input_messages=[])
# Should not raise
await provider.before_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type]
async def test_after_run_is_noop(self) -> None:
provider = BaseContextProvider(source_id="test")
provider = ContextProvider(source_id="test")
session = AgentSession()
ctx = SessionContext(input_messages=[])
await provider.after_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# BaseHistoryProvider tests
# Deprecated provider alias tests
# ---------------------------------------------------------------------------
class ConcreteHistoryProvider(BaseHistoryProvider):
class TestDeprecatedProviderAliases:
def test_base_context_provider_warns_and_is_compatible(self) -> None:
with pytest.warns(DeprecationWarning, match="BaseContextProvider is deprecated. Use ContextProvider instead."):
provider = BaseContextProvider(source_id="test")
assert isinstance(provider, ContextProvider)
def test_base_provider_aliases_preserve_subtyping(self) -> None:
assert issubclass(BaseContextProvider, ContextProvider)
assert issubclass(BaseHistoryProvider, HistoryProvider)
# ---------------------------------------------------------------------------
# HistoryProvider tests
# ---------------------------------------------------------------------------
class ConcreteHistoryProvider(HistoryProvider):
"""Concrete test implementation."""
def __init__(self, source_id: str, stored_messages: list[Message] | None = None, **kwargs) -> None:
@@ -17,9 +17,9 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, cast
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
AgentMiddlewareLayer,
BaseContextProvider,
ChatAndFunctionMiddlewareTypes,
ChatMiddlewareLayer,
ContextProvider,
FunctionInvocationConfiguration,
FunctionInvocationLayer,
FunctionTool,
@@ -50,8 +50,8 @@ else:
if TYPE_CHECKING:
from agent_framework import (
Agent,
BaseContextProvider,
ChatAndFunctionMiddlewareTypes,
ContextProvider,
MiddlewareTypes,
ToolTypes,
)
@@ -224,8 +224,9 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
instructions: str | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
context_providers: Sequence[ContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
require_per_service_call_history_persistence: bool = False,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
@@ -246,6 +247,7 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
tools=function_tools,
context_providers=context_providers,
middleware=middleware,
require_per_service_call_history_persistence=require_per_service_call_history_persistence,
client_type=cast(type[RawFoundryAgentChatClient], self.__class__),
id=id,
name=self.agent_name if name is None else name,
@@ -468,7 +470,7 @@ class RawFoundryAgent( # type: ignore[misc]
project_client: AIProjectClient | None = None,
allow_preview: bool | None = None,
tools: FunctionTool | Callable[..., Any] | Sequence[FunctionTool | Callable[..., Any]] | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
context_providers: Sequence[ContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
client_type: type[RawFoundryAgentChatClient] | None = None,
env_file_path: str | None = None,
@@ -478,6 +480,7 @@ class RawFoundryAgent( # type: ignore[misc]
description: str | None = None,
instructions: str | None = None,
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
require_per_service_call_history_persistence: bool = False,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
@@ -507,6 +510,8 @@ class RawFoundryAgent( # type: ignore[misc]
description: Optional local description for the local agent wrapper.
instructions: Optional instructions for the local agent wrapper.
default_options: Default chat options for the local agent wrapper.
require_per_service_call_history_persistence: Whether to require per-service-call
chat history persistence when using local history providers.
function_invocation_configuration: Optional function invocation configuration override.
compaction_strategy: Optional agent-level in-run compaction override.
tokenizer: Optional agent-level tokenizer override.
@@ -548,6 +553,7 @@ class RawFoundryAgent( # type: ignore[misc]
default_options=cast(FoundryAgentOptionsT | None, default_options),
context_providers=context_providers,
middleware=middleware,
require_per_service_call_history_persistence=require_per_service_call_history_persistence,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=dict(additional_properties) if additional_properties is not None else None,
@@ -661,7 +667,7 @@ class FoundryAgent( # type: ignore[misc]
project_client: AIProjectClient | None = None,
allow_preview: bool | None = None,
tools: FunctionTool | Callable[..., Any] | Sequence[FunctionTool | Callable[..., Any]] | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
context_providers: Sequence[ContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
client_type: type[RawFoundryAgentChatClient] | None = None,
env_file_path: str | None = None,
@@ -671,6 +677,7 @@ class FoundryAgent( # type: ignore[misc]
description: str | None = None,
instructions: str | None = None,
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
require_per_service_call_history_persistence: bool = False,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
@@ -696,6 +703,8 @@ class FoundryAgent( # type: ignore[misc]
description: Optional local description for the local agent wrapper.
instructions: Optional instructions for the local agent wrapper.
default_options: Default chat options for the local agent wrapper.
require_per_service_call_history_persistence: Whether to require per-service-call
chat history persistence when using local history providers.
function_invocation_configuration: Optional function invocation configuration override.
compaction_strategy: Optional agent-level in-run compaction override.
tokenizer: Optional agent-level tokenizer override.
@@ -719,6 +728,7 @@ class FoundryAgent( # type: ignore[misc]
description=description,
instructions=instructions,
default_options=default_options,
require_per_service_call_history_persistence=require_per_service_call_history_persistence,
function_invocation_configuration=function_invocation_configuration,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
@@ -1,9 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
"""Foundry Memory Context Provider using BaseContextProvider.
"""Foundry Memory Context Provider using ContextProvider.
This module provides ``FoundryMemoryProvider``, built on
:class:`BaseContextProvider`.
:class:`ContextProvider`.
"""
from __future__ import annotations
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, ClassVar
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
AgentSession,
BaseContextProvider,
ContextProvider,
Message,
SessionContext,
load_settings,
@@ -46,8 +46,8 @@ class FoundryProjectSettings(TypedDict, total=False):
project_endpoint: str | None
class FoundryMemoryProvider(BaseContextProvider):
"""Foundry Memory context provider using the new BaseContextProvider hooks pattern.
class FoundryMemoryProvider(ContextProvider):
"""Foundry Memory context provider using the new ContextProvider hooks pattern.
Integrates Azure AI Foundry Memory Store for persistent semantic memory,
searching and storing memories via the Azure AI Projects SDK.
@@ -15,8 +15,8 @@ from agent_framework import (
AgentResponseUpdate,
AgentSession,
BaseAgent,
BaseContextProvider,
Content,
ContextProvider,
Message,
ResponseStream,
normalize_messages,
@@ -178,7 +178,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
id: str | None = None,
name: str | None = None,
description: str | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
context_providers: Sequence[ContextProvider] | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
default_options: OptionsT | None = None,
@@ -1,9 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
"""New-pattern Mem0 context provider using BaseContextProvider.
"""New-pattern Mem0 context provider using ContextProvider.
This module provides ``Mem0ContextProvider``, built on the new
:class:`BaseContextProvider` hooks pattern.
:class:`ContextProvider` hooks pattern.
"""
from __future__ import annotations
@@ -13,7 +13,7 @@ from contextlib import AbstractAsyncContextManager
from typing import TYPE_CHECKING, Any, ClassVar
from agent_framework import Message
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
from agent_framework._sessions import AgentSession, ContextProvider, SessionContext
from mem0 import AsyncMemory, AsyncMemoryClient
if sys.version_info >= (3, 11):
@@ -33,8 +33,8 @@ class _MemorySearchResponse_v1_1(TypedDict):
_MemorySearchResponse_v2 = list[dict[str, Any]]
class Mem0ContextProvider(BaseContextProvider):
"""Mem0 context provider using the new BaseContextProvider hooks pattern.
class Mem0ContextProvider(ContextProvider):
"""Mem0 context provider using the new ContextProvider hooks pattern.
Integrates Mem0 for persistent semantic memory, searching and storing
memories via the Mem0 API.
@@ -39,7 +39,7 @@ from dataclasses import dataclass
from typing import Any
from agent_framework import Agent, SupportsAgentRun
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware, MiddlewareTermination
from agent_framework._sessions import AgentSession
from agent_framework._tools import FunctionTool, tool
from agent_framework._types import AgentResponse, Content, Message
@@ -138,8 +138,6 @@ class _AutoHandoffMiddleware(FunctionMiddleware):
await call_next()
return
from agent_framework._middleware import MiddlewareTermination
# Short-circuit execution and provide deterministic response payload for the tool call.
# Parse the result using the default parser to ensure in a form that can be passed directly to LLM APIs.
context.result = FunctionTool.parse_result({
@@ -375,6 +373,7 @@ class HandoffAgentExecutor(AgentExecutor):
description=agent.description,
context_providers=agent.context_providers,
middleware=agent.agent_middleware,
require_per_service_call_history_persistence=agent.require_per_service_call_history_persistence,
default_options=cloned_options, # type: ignore[assignment]
)
@@ -8,10 +8,11 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from agent_framework import (
Agent,
BaseContextProvider,
ChatResponse,
ChatResponseUpdate,
Content,
ContextProvider,
InMemoryHistoryProvider,
Message,
ResponseStream,
WorkflowEvent,
@@ -695,6 +696,48 @@ def test_handoff_clone_disables_provider_side_storage() -> None:
assert executor._agent.default_options.get("store") is False
async def test_handoff_clone_preserves_per_service_call_history_persistence() -> None:
"""Handoff clones should keep per-service-call history persistence active for auto-handoff termination."""
triage_history = InMemoryHistoryProvider()
triage = Agent(
id="triage",
name="triage",
client=MockChatClient(name="triage", handoff_to="specialist"),
context_providers=[triage_history],
require_per_service_call_history_persistence=True,
)
specialist = Agent(
id="specialist",
name="specialist",
client=MockChatClient(name="specialist"),
default_options={"tool_choice": "none"},
)
workflow = (
HandoffBuilder(participants=[triage, specialist], termination_condition=lambda _: False)
.with_start_agent(triage)
.add_handoff(triage, [specialist])
.add_handoff(specialist, [triage])
.build()
)
await _drain(workflow.run("start", stream=True))
executor = workflow.executors[resolve_agent_id(triage)]
assert isinstance(executor, HandoffAgentExecutor)
assert executor._agent.require_per_service_call_history_persistence is True
provider_state = executor._session.state[triage_history.source_id]
stored_messages = await triage_history.get_messages(
executor._session.session_id,
state=provider_state,
)
assert [message.role for message in stored_messages] == ["user", "assistant"]
assert any(content.type == "function_call" for content in stored_messages[-1].contents)
assert all(message.role != "tool" for message in stored_messages)
async def test_handoff_clears_stale_service_session_id_before_run() -> None:
"""Stale service session IDs must be dropped before each handoff agent turn."""
triage = MockHandoffAgent(name="triage", handoff_to="specialist")
@@ -997,7 +1040,7 @@ async def test_context_provider_preserved_during_handoff():
# Track whether context provider methods were called
provider_calls: list[str] = []
class TestContextProvider(BaseContextProvider):
class TestContextProvider(ContextProvider):
"""A test context provider that tracks its invocations."""
def __init__(self) -> None:
@@ -1,9 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
"""New-pattern Redis context provider using BaseContextProvider.
"""New-pattern Redis context provider using ContextProvider.
This module provides ``RedisContextProvider``, built on the new
:class:`BaseContextProvider` hooks pattern.
:class:`ContextProvider` hooks pattern.
"""
from __future__ import annotations
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal
import numpy as np
from agent_framework import Message
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
from agent_framework._sessions import AgentSession, ContextProvider, SessionContext
from agent_framework.exceptions import (
AgentException,
IntegrationInvalidRequestException,
@@ -41,8 +41,8 @@ if TYPE_CHECKING:
from agent_framework._agents import SupportsAgentRun
class RedisContextProvider(BaseContextProvider):
"""Redis context provider using the new BaseContextProvider hooks pattern.
class RedisContextProvider(ContextProvider):
"""Redis context provider using the new ContextProvider hooks pattern.
Stores context in Redis and retrieves scoped context via full-text or
optional hybrid vector search.
@@ -1,9 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
"""New-pattern Redis history provider using BaseHistoryProvider.
"""New-pattern Redis history provider using HistoryProvider.
This module provides ``RedisHistoryProvider``, built on the new
:class:`BaseHistoryProvider` hooks pattern.
:class:`HistoryProvider` hooks pattern.
"""
from __future__ import annotations
@@ -13,12 +13,12 @@ from typing import Any, ClassVar
import redis.asyncio as redis
from agent_framework import Message
from agent_framework._sessions import BaseHistoryProvider
from agent_framework._sessions import HistoryProvider
from redis.credentials import CredentialProvider
class RedisHistoryProvider(BaseHistoryProvider):
"""Redis-backed history provider using the new BaseHistoryProvider hooks pattern.
class RedisHistoryProvider(HistoryProvider):
"""Redis-backed history provider using the new HistoryProvider hooks pattern.
Stores conversation history in Redis Lists, with each session isolated by a
unique Redis key.
@@ -475,7 +475,7 @@ class TestRedisHistoryProviderClear:
class TestRedisHistoryProviderBeforeAfterRun:
"""Test before_run/after_run integration via BaseHistoryProvider defaults."""
"""Test before_run/after_run integration via HistoryProvider defaults."""
async def test_before_run_loads_history(self, mock_redis_client: MagicMock):
msg = Message(role="user", contents=["old msg"])
+2 -2
View File
@@ -3,7 +3,7 @@
import asyncio
from typing import Any
from agent_framework import Agent, AgentSession, BaseContextProvider, SessionContext
from agent_framework import Agent, AgentSession, ContextProvider, SessionContext
from agent_framework.foundry import FoundryChatClient
from azure.identity import AzureCliCredential
@@ -17,7 +17,7 @@ responses — the name persists across turns via the session.
# <context_provider>
class UserMemoryProvider(BaseContextProvider):
class UserMemoryProvider(ContextProvider):
"""A context provider that remembers user info in session state."""
DEFAULT_SOURCE_ID = "user_memory"
@@ -9,6 +9,7 @@ This folder contains examples for direct chat client usage patterns.
| [`built_in_chat_clients.py`](built_in_chat_clients.py) | Consolidated sample for built-in chat clients. Uses `get_client()` to create the selected client and pass it to `main()`. |
| [`chat_response_cancellation.py`](chat_response_cancellation.py) | Demonstrates how to cancel chat responses during streaming, showing proper cancellation handling and cleanup. |
| [`custom_chat_client.py`](custom_chat_client.py) | Demonstrates how to create custom chat clients by extending the `BaseChatClient` class. Shows a `EchoingChatClient` implementation and how to integrate it with `Agent` using the `as_agent()` method. |
| [`require_per_service_call_history_persistence.py`](require_per_service_call_history_persistence.py) | Compares two otherwise identical `FoundryChatClient` agents with `store=False`; the only difference is whether `require_per_service_call_history_persistence` is enabled, and only the run without it stores the synthesized tool result when middleware terminates the loop early. |
## Selecting a built-in client
@@ -35,6 +36,15 @@ Example:
uv run samples/02-agents/chat_client/built_in_chat_clients.py
```
The `require_per_service_call_history_persistence.py` sample uses `FoundryChatClient`, so set the usual Foundry settings first and sign in with the Azure CLI:
```bash
export FOUNDRY_PROJECT_ENDPOINT="https://<your-project>.services.ai.azure.com/api/projects/<project-name>"
export FOUNDRY_MODEL="<your-model-deployment-name>"
az login
uv run samples/02-agents/chat_client/require_per_service_call_history_persistence.py
```
## Environment Variables
Depending on the selected client, set the appropriate environment variables:
@@ -0,0 +1,194 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from typing import Annotated
from agent_framework import (
Agent,
FunctionInvocationContext,
FunctionMiddleware,
InMemoryHistoryProvider,
Message,
MiddlewareTermination,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
"""
Compare Foundry agents with and without per-service-call chat history persistence.
This sample runs two otherwise identical Foundry agents with ``store=False`` so
history stays local for both runs.
The sample adds a function middleware that raises ``MiddlewareTermination``
immediately after the tool runs, so the request stops before a second model
call.
That early termination is the important difference:
- Without per-service-call chat history persistence, the synthesized tool result is
still written to local history.
- With ``require_per_service_call_history_persistence=True``, that synthesized tool result is
not written to local history.
The per-service-call persistence case matches service-side storage behavior. When a terminated
request never sends the tool result back to the service, that result also never
becomes part of the service-managed history.
"""
# Load environment variables from .env file
load_dotenv()
def lookup_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Return a deterministic weather result for the requested location."""
return f"The weather in {location} is sunny."
class TerminateAfterToolMiddleware(FunctionMiddleware):
"""Stop the tool loop after the first tool finishes."""
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Run the tool, then terminate the loop with that tool result."""
await call_next()
raise MiddlewareTermination(result=context.result)
def _describe_message(message: Message) -> str:
"""Render one stored message in a compact, readable format."""
parts: list[str] = []
for content in message.contents:
if content.type == "text" and content.text:
parts.append(content.text)
elif content.type == "function_call":
parts.append(f"function_call -> {content.name}({content.arguments})")
elif content.type == "function_result":
parts.append(f"function_result -> {content.result}")
else:
parts.append(content.type)
return f"{message.role}: {' | '.join(parts)}"
def _includes_tool_result(messages: list[Message]) -> bool:
"""Return whether any stored message contains a tool result."""
return any(content.type == "function_result" for message in messages for content in message.contents)
async def main() -> None:
"""Run both comparison scenarios."""
print("=== require_per_service_call_history_persistence when middleware terminates the tool loop ===\n")
# 1. Create one Foundry chat client that both agents will share.
client = FoundryChatClient(credential=AzureCliCredential())
query = "What is the weather in Seattle, and should I bring sunglasses?"
# 2. Create and run the agent without per-service-call persistence.
agent_without_persistence = Agent(
client=client,
instructions=(
"You are a weather assistant. Call lookup_weather exactly once before answering "
"any weather question, then summarize the tool result in one short paragraph."
),
tools=[lookup_weather],
context_providers=[InMemoryHistoryProvider()],
middleware=[TerminateAfterToolMiddleware()],
default_options={"tool_choice": "required", "store": False},
)
session_without_persistence = agent_without_persistence.create_session()
await agent_without_persistence.run(
query,
session=session_without_persistence,
)
stored_messages_without_persistence = session_without_persistence.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID][
"messages"
]
print("=== Without per-service-call persistence ===")
print("Loop terminated immediately after the tool finished.")
print(f"Stored synthesized tool result: {_includes_tool_result(stored_messages_without_persistence)}")
print("Stored history:")
for index, message in enumerate(stored_messages_without_persistence, start=1):
print(f" {index}. {_describe_message(message)}")
print()
# 3. Create and run the agent with per-service-call persistence enabled.
agent_with_persistence = Agent(
client=client,
instructions=(
"You are a weather assistant. Call lookup_weather exactly once before answering "
"any weather question, then summarize the tool result in one short paragraph."
),
tools=[lookup_weather],
context_providers=[InMemoryHistoryProvider()],
middleware=[TerminateAfterToolMiddleware()],
require_per_service_call_history_persistence=True,
default_options={"tool_choice": "required", "store": False},
)
session_with_persistence = agent_with_persistence.create_session()
await agent_with_persistence.run(
query,
session=session_with_persistence,
)
stored_messages_with_persistence = session_with_persistence.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID][
"messages"
]
print("=== With per-service-call persistence ===")
print("Loop terminated immediately after the tool finished.")
print(f"Stored synthesized tool result: {_includes_tool_result(stored_messages_with_persistence)}")
print("Stored history:")
for index, message in enumerate(stored_messages_with_persistence, start=1):
print(f" {index}. {_describe_message(message)}")
print()
# 4. Summarize the effect of the flag.
print(
"Both runs used FoundryChatClient with store=False and terminated right after the tool. "
"Without per-service-call persistence, local history still stored the synthesized tool result. "
"With per-service-call persistence, local history stopped at the assistant function-call message instead, "
"which matches service-side storage because the terminated tool result is never sent back to the service."
)
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output:
=== require_per_service_call_history_persistence when middleware terminates the tool loop ===
=== Without per-service-call persistence ===
Loop terminated immediately after the tool finished.
Stored synthesized tool result: True
Stored history:
1. user: What is the weather in Seattle, and should I bring sunglasses?
2. assistant: function_call -> lookup_weather({"location":"Seattle"})
3. tool: function_result -> The weather in Seattle is sunny.
=== With per-service-call persistence ===
Loop terminated immediately after the tool finished.
Stored synthesized tool result: False
Stored history:
1. user: What is the weather in Seattle, and should I bring sunglasses?
2. assistant: function_call -> lookup_weather({"location":"Seattle"})
Both runs used FoundryChatClient with store=False and terminated right after
the tool. Without per-service-call persistence, local history still stored the
synthesized tool result. With per-service-call persistence, local history
stopped at the assistant function-call message instead, which matches
service-side storage because the terminated tool result is never sent back to
the service.
"""
@@ -5,7 +5,7 @@ import os
from contextlib import suppress
from typing import Any
from agent_framework import Agent, AgentSession, BaseContextProvider, SessionContext, SupportsChatGetResponse
from agent_framework import Agent, AgentSession, ContextProvider, SessionContext, SupportsChatGetResponse
from agent_framework.foundry import FoundryChatClient
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
@@ -20,7 +20,7 @@ class UserInfo(BaseModel):
age: int | None = None
class UserInfoMemory(BaseContextProvider):
class UserInfoMemory(ContextProvider):
DEFAULT_SOURCE_ID = "user_info_memory"
def __init__(self, source_id: str = DEFAULT_SOURCE_ID, *, client: SupportsChatGetResponse, **kwargs: Any):
@@ -4,7 +4,7 @@ import asyncio
from collections.abc import Sequence
from typing import Any
from agent_framework import Agent, AgentSession, BaseHistoryProvider, Message
from agent_framework import Agent, AgentSession, HistoryProvider, Message
from agent_framework.openai import OpenAIChatClient
from dotenv import load_dotenv
@@ -20,7 +20,7 @@ preferred storage solution (database, file system, etc.).
"""
class CustomHistoryProvider(BaseHistoryProvider):
class CustomHistoryProvider(HistoryProvider):
"""Implementation of custom history provider.
In real applications, this can be an implementation of relational database or vector store."""
@@ -7,7 +7,7 @@ These samples demonstrate how to build and host AI agents in Python using the [A
| Sample | Description |
| ----------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------- |
| [`agent_with_hosted_mcp`](./agent_with_hosted_mcp/) | Hosted MCP tool that connects to Microsoft Learn via `https://learn.microsoft.com/api/mcp` |
| [`agent_with_text_search_rag`](./agent_with_text_search_rag/) | Retrieval-augmented generation using a custom `BaseContextProvider` with Contoso Outdoors sample data |
| [`agent_with_text_search_rag`](./agent_with_text_search_rag/) | Retrieval-augmented generation using a custom `ContextProvider` with Contoso Outdoors sample data |
| [`agents_in_workflow`](./agents_in_workflow/) | Concurrent workflow that combines researcher, marketer, and legal specialist agents |
| [`agent_with_local_tools`](./agent_with_local_tools/) | Local Python tool execution for Seattle hotel search |
| [`writer_reviewer_agents_in_workflow`](./writer_reviewer_agents_in_workflow/) | Writer/Reviewer workflow using `FoundryChatClient` |
@@ -5,7 +5,7 @@ import sys
from dataclasses import dataclass
from typing import Any
from agent_framework import Agent, AgentSession, BaseContextProvider, Message, SessionContext
from agent_framework import Agent, AgentSession, ContextProvider, Message, SessionContext
from agent_framework.foundry import FoundryChatClient
from azure.ai.agentserver.agentframework import from_agent_framework # pyright: ignore[reportUnknownVariableType]
from azure.identity import DefaultAzureCredential
@@ -28,7 +28,7 @@ class TextSearchResult:
text: str
class TextSearchContextProvider(BaseContextProvider):
class TextSearchContextProvider(ContextProvider):
"""A simple context provider that simulates text search results based on keywords in the user's message."""
def __init__(self):
@@ -7,7 +7,7 @@ This sample demonstrates a simple Weather Forecast Agent built with the Python M
- Python 3.11+
- [uv](https://github.com/astral-sh/uv) for fast dependency management
- [devtunnel](https://learn.microsoft.com/azure/developer/dev-tunnels/get-started?tabs=windows)
- [Microsoft 365 Agents Toolkit](https://github.com/OfficeDev/microsoft-365-agents-toolkit) for playground/testing
- `agentsplayground` for playground/testing
- Access to OpenAI or Azure OpenAI with a model like `gpt-4o-mini`
## Configuration