mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] update context provider APIs, middleware, and per-service-call history persistence (#4992)
* Rename provider base APIs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Allow provider-added chat and function middleware Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Simulate service-stored history per model call Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix typing regressions in CI Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix response ID suppression review feedback Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Rename per-service-call history persistence APIs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address context persistence review feedback Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Stabilize markdown sample docs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Persist service continuation state per call Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
38de991481
commit
b065a4ce51
@@ -35,9 +35,9 @@ from agent_framework import (
|
||||
AgentResponseUpdate,
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
BaseHistoryProvider,
|
||||
Content,
|
||||
ContinuationToken,
|
||||
HistoryProvider,
|
||||
Message,
|
||||
ResponseStream,
|
||||
SessionContext,
|
||||
@@ -353,7 +353,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
|
||||
# Run before_run providers (forward order)
|
||||
for provider in self.context_providers:
|
||||
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
|
||||
if isinstance(provider, HistoryProvider) and not provider.load_messages:
|
||||
continue
|
||||
if session is None:
|
||||
raise RuntimeError("Provider session must be available when context providers are configured.")
|
||||
|
||||
@@ -24,8 +24,8 @@ from agent_framework import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentSession,
|
||||
BaseContextProvider,
|
||||
Content,
|
||||
ContextProvider,
|
||||
Message,
|
||||
SessionContext,
|
||||
)
|
||||
@@ -869,7 +869,7 @@ async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2A
|
||||
# region Context Provider Tests
|
||||
|
||||
|
||||
class TrackingContextProvider(BaseContextProvider):
|
||||
class TrackingContextProvider(ContextProvider):
|
||||
"""A context provider that records when before_run and after_run are called."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
||||
+5
-5
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""New-pattern Azure AI Search context provider using BaseContextProvider.
|
||||
"""New-pattern Azure AI Search context provider using ContextProvider.
|
||||
|
||||
This module provides ``AzureAISearchContextProvider``, built on the new
|
||||
:class:`BaseContextProvider` hooks pattern.
|
||||
:class:`ContextProvider` hooks pattern.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -17,8 +17,8 @@ from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
AgentSession,
|
||||
Annotation,
|
||||
BaseContextProvider,
|
||||
Content,
|
||||
ContextProvider,
|
||||
Message,
|
||||
SecretString,
|
||||
SessionContext,
|
||||
@@ -154,8 +154,8 @@ class AzureAISearchSettings(TypedDict, total=False):
|
||||
api_key: SecretString | None
|
||||
|
||||
|
||||
class AzureAISearchContextProvider(BaseContextProvider):
|
||||
"""Azure AI Search context provider using the new BaseContextProvider hooks pattern.
|
||||
class AzureAISearchContextProvider(ContextProvider):
|
||||
"""Azure AI Search context provider using the new ContextProvider hooks pattern.
|
||||
|
||||
Retrieves relevant context from Azure AI Search using semantic or agentic search
|
||||
modes.
|
||||
|
||||
@@ -11,7 +11,7 @@ from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, TypedDict
|
||||
|
||||
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
|
||||
from agent_framework._sessions import BaseHistoryProvider
|
||||
from agent_framework._sessions import HistoryProvider
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
@@ -32,8 +32,8 @@ class AzureCosmosHistorySettings(TypedDict, total=False):
|
||||
key: SecretString | None
|
||||
|
||||
|
||||
class CosmosHistoryProvider(BaseHistoryProvider):
|
||||
"""Azure Cosmos DB-backed history provider using BaseHistoryProvider hooks."""
|
||||
class CosmosHistoryProvider(HistoryProvider):
|
||||
"""Azure Cosmos DB-backed history provider using HistoryProvider hooks."""
|
||||
|
||||
DEFAULT_SOURCE_ID: ClassVar[str] = "azure_cosmos_history"
|
||||
_BATCH_OPERATION_LIMIT: ClassVar[int] = 100
|
||||
|
||||
@@ -16,8 +16,8 @@ from agent_framework import (
|
||||
AgentRunInputs,
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
BaseContextProvider,
|
||||
Content,
|
||||
ContextProvider,
|
||||
FunctionTool,
|
||||
Message,
|
||||
ResponseStream,
|
||||
@@ -223,7 +223,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
middleware: Sequence[AgentMiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | str | Sequence[ToolTypes | Callable[..., Any] | str] | None = None,
|
||||
default_options: OptionsT | MutableMapping[str, Any] | None = None,
|
||||
|
||||
@@ -11,8 +11,8 @@ from agent_framework import (
|
||||
AgentResponseUpdate,
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
BaseContextProvider,
|
||||
Content,
|
||||
ContextProvider,
|
||||
Message,
|
||||
ResponseStream,
|
||||
normalize_messages,
|
||||
@@ -60,7 +60,7 @@ class CopilotStudioAgent(BaseAgent):
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
middleware: list[AgentMiddlewareTypes] | None = None,
|
||||
environment_id: str | None = None,
|
||||
agent_identifier: str | None = None,
|
||||
|
||||
@@ -61,8 +61,8 @@ agent_framework/
|
||||
|
||||
- **`AgentSession`** - Manages conversation state and session metadata
|
||||
- **`SessionContext`** - Context object for session-scoped data during agent runs
|
||||
- **`BaseContextProvider`** - Base class for context providers (RAG, memory systems)
|
||||
- **`BaseHistoryProvider`** - Base class for conversation history storage
|
||||
- **`ContextProvider`** - Base class for context providers (RAG, memory systems)
|
||||
- **`HistoryProvider`** - Base class for conversation history storage
|
||||
|
||||
### Skills (`_skills.py`)
|
||||
|
||||
@@ -70,7 +70,7 @@ agent_framework/
|
||||
- **`SkillResource`** - Named supplementary content attached to a skill; holds either static `content` or a dynamic `function` (sync or async). Exactly one must be provided.
|
||||
- **`SkillScript`** - An executable script attached to a skill; holds either an inline `function` (code-defined, runs in-process) or a `path` to a file on disk (file-based, delegated to a runner). Exactly one must be provided.
|
||||
- **`SkillScriptRunner`** - Protocol for file-based script execution. Any callable matching `(skill, script, args) -> Any` satisfies it. Code-defined scripts do not use a runner.
|
||||
- **`SkillsProvider`** - Context provider (extends `BaseContextProvider`) that discovers file-based skills from `SKILL.md` files and/or accepts code-defined `Skill` instances. Follows progressive disclosure: advertise → load → read resources / run scripts.
|
||||
- **`SkillsProvider`** - Context provider (extends `ContextProvider`) that discovers file-based skills from `SKILL.md` files and/or accepts code-defined `Skill` instances. Follows progressive disclosure: advertise → load → read resources / run scripts.
|
||||
|
||||
### Workflows (`_workflows/`)
|
||||
|
||||
|
||||
@@ -102,8 +102,10 @@ from ._middleware import (
|
||||
)
|
||||
from ._sessions import (
|
||||
AgentSession,
|
||||
BaseContextProvider,
|
||||
BaseHistoryProvider,
|
||||
BaseContextProvider, # type: ignore[reportDeprecated]
|
||||
BaseHistoryProvider, # type: ignore[reportDeprecated]
|
||||
ContextProvider,
|
||||
HistoryProvider,
|
||||
InMemoryHistoryProvider,
|
||||
SessionContext,
|
||||
register_state_type,
|
||||
@@ -296,6 +298,7 @@ __all__ = [
|
||||
"CompactionProvider",
|
||||
"CompactionStrategy",
|
||||
"Content",
|
||||
"ContextProvider",
|
||||
"ContinuationToken",
|
||||
"ConversationSplit",
|
||||
"ConversationSplitter",
|
||||
@@ -331,6 +334,7 @@ __all__ = [
|
||||
"FunctionTool",
|
||||
"GeneratedEmbeddings",
|
||||
"GraphConnectivityError",
|
||||
"HistoryProvider",
|
||||
"InMemoryCheckpointStorage",
|
||||
"InMemoryHistoryProvider",
|
||||
"InProcRunnerContext",
|
||||
|
||||
@@ -29,14 +29,16 @@ from . import _tools as _tool_utils # pyright: ignore[reportPrivateUsage]
|
||||
from ._clients import BaseChatClient, SupportsChatGetResponse
|
||||
from ._docstrings import apply_layered_docstring
|
||||
from ._mcp import LOG_LEVEL_MAPPING, MCPTool
|
||||
from ._middleware import AgentMiddlewareLayer, FunctionInvocationContext, MiddlewareTypes
|
||||
from ._middleware import AgentMiddlewareLayer, FunctionInvocationContext, MiddlewareTypes, categorize_middleware
|
||||
from ._serialization import SerializationMixin
|
||||
from ._sessions import (
|
||||
AgentSession,
|
||||
BaseContextProvider,
|
||||
BaseHistoryProvider,
|
||||
ContextProvider,
|
||||
HistoryProvider,
|
||||
InMemoryHistoryProvider,
|
||||
PerServiceCallHistoryPersistingMiddleware,
|
||||
SessionContext,
|
||||
is_local_history_conversation_id,
|
||||
)
|
||||
from ._tools import FunctionInvocationLayer, FunctionTool, ToolTypes, normalize_tools
|
||||
from ._types import (
|
||||
@@ -50,7 +52,7 @@ from ._types import (
|
||||
map_chat_to_agent_update,
|
||||
normalize_messages,
|
||||
)
|
||||
from .exceptions import AgentInvalidResponseException, UserInputRequiredException
|
||||
from .exceptions import AgentInvalidRequestException, AgentInvalidResponseException, UserInputRequiredException
|
||||
from .observability import AgentTelemetryLayer
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
@@ -166,6 +168,7 @@ class _RunContext(TypedDict):
|
||||
input_messages: Sequence[Message]
|
||||
session_messages: Sequence[Message]
|
||||
agent_name: str
|
||||
suppress_response_id: bool
|
||||
chat_options: MutableMapping[str, Any]
|
||||
compaction_strategy: CompactionStrategy | None
|
||||
tokenizer: TokenizerProtocol | None
|
||||
@@ -366,6 +369,7 @@ class BaseAgent(SerializationMixin):
|
||||
"""
|
||||
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"additional_properties"}
|
||||
require_per_service_call_history_persistence: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -373,7 +377,7 @@ class BaseAgent(SerializationMixin):
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
additional_properties: MutableMapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
@@ -393,7 +397,7 @@ class BaseAgent(SerializationMixin):
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.context_providers: list[BaseContextProvider] = list(context_providers or [])
|
||||
self.context_providers: list[ContextProvider] = list(context_providers or [])
|
||||
self.middleware: list[MiddlewareTypes] | None = (
|
||||
cast(list[MiddlewareTypes], middleware) if middleware is not None else None
|
||||
)
|
||||
@@ -455,7 +459,12 @@ class BaseAgent(SerializationMixin):
|
||||
if provider_session is None and self.context_providers:
|
||||
provider_session = AgentSession()
|
||||
|
||||
per_service_call_history_required = self.require_per_service_call_history_persistence and any(
|
||||
isinstance(provider, HistoryProvider) for provider in self.context_providers
|
||||
)
|
||||
for provider in reversed(self.context_providers):
|
||||
if per_service_call_history_required and isinstance(provider, HistoryProvider):
|
||||
continue
|
||||
if provider_session is None:
|
||||
raise RuntimeError("Provider session must be available when context providers are configured.")
|
||||
await provider.after_run(
|
||||
@@ -656,8 +665,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
description: str | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
default_options: OptionsCoT | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
require_per_service_call_history_persistence: bool = False,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: MutableMapping[str, Any] | None = None,
|
||||
@@ -675,6 +685,11 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
description: A brief description of the agent's purpose.
|
||||
context_providers: Context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
require_per_service_call_history_persistence: When True, history providers are invoked
|
||||
around each model call instead of once per ``run()`` when the service
|
||||
is not already storing history. If service-side storage is active for
|
||||
the run, the agent skips local history providers and relies on the
|
||||
service-managed conversation instead.
|
||||
default_options: A TypedDict containing chat options. When using a typed agent like
|
||||
``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for
|
||||
provider-specific options including temperature, max_tokens, model_id,
|
||||
@@ -706,6 +721,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
)
|
||||
self.client = client
|
||||
self.compaction_strategy = compaction_strategy
|
||||
self.require_per_service_call_history_persistence = require_per_service_call_history_persistence
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# Get tools from options or named parameter (named param takes precedence)
|
||||
@@ -764,6 +780,35 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
await self._async_exit_stack.enter_async_context(context_manager)
|
||||
return self
|
||||
|
||||
def _get_history_providers(self) -> list[HistoryProvider]:
|
||||
return [provider for provider in self.context_providers if isinstance(provider, HistoryProvider)]
|
||||
|
||||
def _resolve_per_service_call_history_providers(
|
||||
self,
|
||||
*,
|
||||
session: AgentSession | None,
|
||||
options: Mapping[str, Any] | None,
|
||||
service_stores_history: bool,
|
||||
) -> list[HistoryProvider]:
|
||||
history_providers = self._get_history_providers()
|
||||
if not self.require_per_service_call_history_persistence or not history_providers:
|
||||
return []
|
||||
|
||||
conversation_id = (
|
||||
session.service_session_id
|
||||
if session and session.service_session_id
|
||||
else cast(str | None, (options or {}).get("conversation_id") or self.default_options.get("conversation_id"))
|
||||
)
|
||||
if service_stores_history:
|
||||
return []
|
||||
|
||||
if conversation_id is not None:
|
||||
raise AgentInvalidRequestException(
|
||||
"require_per_service_call_history_persistence cannot be used "
|
||||
"with an existing service-managed conversation."
|
||||
)
|
||||
return history_providers
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
@@ -885,97 +930,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
When stream=True: A ResponseStream of AgentResponseUpdate items with
|
||||
``get_final_response()`` for the final AgentResponse.
|
||||
"""
|
||||
if not stream:
|
||||
|
||||
async def _run_non_streaming() -> AgentResponse[Any]:
|
||||
ctx = await self._prepare_run_context(
|
||||
messages=messages,
|
||||
session=session,
|
||||
tools=tools,
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
response = cast(
|
||||
ChatResponse[Any],
|
||||
await self.client.get_response( # type: ignore
|
||||
messages=ctx["session_messages"],
|
||||
stream=False,
|
||||
options=ctx["chat_options"], # type: ignore[reportArgumentType]
|
||||
compaction_strategy=ctx["compaction_strategy"],
|
||||
tokenizer=ctx["tokenizer"],
|
||||
function_invocation_kwargs=ctx["function_invocation_kwargs"],
|
||||
client_kwargs=ctx["client_kwargs"],
|
||||
),
|
||||
)
|
||||
|
||||
if not response:
|
||||
raise AgentInvalidResponseException("Chat client did not return a response.")
|
||||
|
||||
await self._finalize_response(
|
||||
response=response,
|
||||
agent_name=ctx["agent_name"],
|
||||
session=ctx["session"],
|
||||
session_context=ctx["session_context"],
|
||||
)
|
||||
response_format = ctx["chat_options"].get("response_format")
|
||||
if not (
|
||||
response_format is not None
|
||||
and isinstance(response_format, type)
|
||||
and issubclass(response_format, BaseModel)
|
||||
):
|
||||
response_format = None
|
||||
|
||||
return AgentResponse(
|
||||
messages=response.messages,
|
||||
response_id=response.response_id,
|
||||
created_at=response.created_at,
|
||||
usage_details=response.usage_details,
|
||||
value=response.value,
|
||||
response_format=response_format,
|
||||
continuation_token=response.continuation_token,
|
||||
raw_representation=response,
|
||||
additional_properties=response.additional_properties,
|
||||
)
|
||||
|
||||
return _run_non_streaming()
|
||||
|
||||
# Use a holder to capture the context created during stream initialization
|
||||
ctx_holder: dict[str, _RunContext | None] = {"ctx": None}
|
||||
|
||||
async def _post_hook(response: AgentResponse) -> None:
|
||||
ctx = ctx_holder["ctx"]
|
||||
if ctx is None:
|
||||
return # No context available (shouldn't happen in normal flow)
|
||||
|
||||
# Update thread with conversation_id derived from streaming raw updates.
|
||||
# Using response_id here can break function-call continuation for APIs
|
||||
# where response IDs are not valid conversation handles.
|
||||
conversation_id = self._extract_conversation_id_from_streaming_response(response)
|
||||
# Ensure author names are set for all messages
|
||||
for message in response.messages:
|
||||
if message.author_name is None:
|
||||
message.author_name = ctx["agent_name"]
|
||||
|
||||
# Propagate conversation_id back to session from streaming updates.
|
||||
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
|
||||
# so refresh when a newer value is returned.
|
||||
sess = ctx["session"]
|
||||
if sess and conversation_id and sess.service_session_id != conversation_id:
|
||||
sess.service_session_id = conversation_id
|
||||
|
||||
# Run after_run providers (reverse order)
|
||||
session_context = ctx["session_context"]
|
||||
session_context._response = AgentResponse( # type: ignore[assignment]
|
||||
messages=response.messages,
|
||||
response_id=response.response_id,
|
||||
)
|
||||
await self._run_after_providers(session=ctx["session"], context=session_context)
|
||||
|
||||
async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
ctx_holder["ctx"] = await self._prepare_run_context(
|
||||
async def _prepare_run_context() -> _RunContext:
|
||||
return await self._prepare_run_context(
|
||||
messages=messages,
|
||||
session=session,
|
||||
tools=tools,
|
||||
@@ -985,55 +942,177 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it
|
||||
|
||||
if not stream:
|
||||
|
||||
async def _run_non_streaming() -> AgentResponse[Any]:
|
||||
ctx = await _prepare_run_context()
|
||||
response = await self._call_chat_client(ctx, stream=False)
|
||||
return await self._parse_non_streaming_response(ctx, response)
|
||||
|
||||
return _run_non_streaming()
|
||||
|
||||
async def _run_streaming() -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
ctx = await _prepare_run_context()
|
||||
stream_response = self._call_chat_client(ctx, stream=True)
|
||||
return self._parse_streaming_response(ctx, stream_response)
|
||||
|
||||
return cast(
|
||||
ResponseStream[AgentResponseUpdate, AgentResponse[Any]],
|
||||
cast(Any, ResponseStream).from_awaitable(_run_streaming()),
|
||||
)
|
||||
|
||||
@overload
|
||||
def _call_chat_client(
|
||||
self,
|
||||
context: _RunContext,
|
||||
*,
|
||||
stream: Literal[False],
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
def _call_chat_client(
|
||||
self,
|
||||
context: _RunContext,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
def _call_chat_client(
|
||||
self,
|
||||
context: _RunContext,
|
||||
*,
|
||||
stream: bool,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Invoke the downstream chat client for a prepared run context."""
|
||||
if stream:
|
||||
return self.client.get_response( # type: ignore[call-overload, no-any-return]
|
||||
messages=ctx["session_messages"],
|
||||
messages=context["session_messages"],
|
||||
stream=True,
|
||||
options=ctx["chat_options"], # type: ignore[reportArgumentType]
|
||||
compaction_strategy=ctx["compaction_strategy"],
|
||||
tokenizer=ctx["tokenizer"],
|
||||
function_invocation_kwargs=ctx["function_invocation_kwargs"],
|
||||
client_kwargs=ctx["client_kwargs"],
|
||||
options=context["chat_options"], # type: ignore[reportArgumentType]
|
||||
compaction_strategy=context["compaction_strategy"],
|
||||
tokenizer=context["tokenizer"],
|
||||
function_invocation_kwargs=context["function_invocation_kwargs"],
|
||||
client_kwargs=context["client_kwargs"],
|
||||
)
|
||||
|
||||
def _propagate_conversation_id(
|
||||
update: AgentResponseUpdate,
|
||||
) -> AgentResponseUpdate:
|
||||
"""Eagerly propagate conversation_id to session as updates arrive.
|
||||
return self.client.get_response( # type: ignore[call-overload, no-any-return]
|
||||
messages=context["session_messages"],
|
||||
stream=False,
|
||||
options=context["chat_options"], # type: ignore[reportArgumentType]
|
||||
compaction_strategy=context["compaction_strategy"],
|
||||
tokenizer=context["tokenizer"],
|
||||
function_invocation_kwargs=context["function_invocation_kwargs"],
|
||||
client_kwargs=context["client_kwargs"],
|
||||
)
|
||||
|
||||
This ensures session.service_session_id is set even when the user
|
||||
only iterates the stream without calling get_final_response().
|
||||
"""
|
||||
async def _parse_non_streaming_response(
|
||||
self,
|
||||
context: _RunContext,
|
||||
response: ChatResponse[Any],
|
||||
) -> AgentResponse[Any]:
|
||||
"""Finalize a non-streaming chat response into an AgentResponse."""
|
||||
if not response:
|
||||
raise AgentInvalidResponseException("Chat client did not return a response.")
|
||||
|
||||
await self._finalize_response(
|
||||
response=response,
|
||||
agent_name=context["agent_name"],
|
||||
session=context["session"],
|
||||
session_context=context["session_context"],
|
||||
suppress_response_id=context["suppress_response_id"],
|
||||
)
|
||||
|
||||
response_format = context["chat_options"].get("response_format")
|
||||
if not (
|
||||
response_format is not None and isinstance(response_format, type) and issubclass(response_format, BaseModel)
|
||||
):
|
||||
response_format = None
|
||||
|
||||
return AgentResponse(
|
||||
messages=response.messages,
|
||||
response_id=None if context["suppress_response_id"] else response.response_id,
|
||||
created_at=response.created_at,
|
||||
usage_details=response.usage_details,
|
||||
value=response.value,
|
||||
response_format=response_format,
|
||||
continuation_token=response.continuation_token,
|
||||
raw_representation=response,
|
||||
additional_properties=response.additional_properties,
|
||||
)
|
||||
|
||||
def _parse_streaming_response(
|
||||
self,
|
||||
context: _RunContext,
|
||||
stream_response: ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Finalize a streaming chat response into an agent response stream."""
|
||||
|
||||
async def _post_hook(response: AgentResponse) -> None:
|
||||
# Update thread with conversation_id derived from streaming raw updates.
|
||||
# Using response_id here can break function-call continuation for APIs
|
||||
# where response IDs are not valid conversation handles.
|
||||
conversation_id = self._extract_conversation_id_from_streaming_response(response)
|
||||
|
||||
for message in response.messages:
|
||||
if message.author_name is None:
|
||||
message.author_name = context["agent_name"]
|
||||
|
||||
session = context["session"]
|
||||
if (
|
||||
session
|
||||
and conversation_id
|
||||
and not is_local_history_conversation_id(conversation_id)
|
||||
and session.service_session_id != conversation_id
|
||||
):
|
||||
session.service_session_id = conversation_id
|
||||
|
||||
suppress_response_id = context["suppress_response_id"]
|
||||
session_context = context["session_context"]
|
||||
session_context._response = AgentResponse( # type: ignore[assignment]
|
||||
messages=response.messages,
|
||||
response_id=None if suppress_response_id else response.response_id,
|
||||
)
|
||||
await self._run_after_providers(session=session, context=session_context)
|
||||
|
||||
def _propagate_conversation_id(update: AgentResponseUpdate) -> AgentResponseUpdate:
|
||||
"""Eagerly propagate conversation_id to session as updates arrive."""
|
||||
session = context["session"]
|
||||
if session is None:
|
||||
return update
|
||||
raw = update.raw_representation
|
||||
conv_id = getattr(raw, "conversation_id", None) if raw else None
|
||||
if isinstance(conv_id, str) and conv_id and session.service_session_id != conv_id:
|
||||
session.service_session_id = conv_id
|
||||
conversation_id = getattr(raw, "conversation_id", None) if raw else None
|
||||
if (
|
||||
isinstance(conversation_id, str)
|
||||
and conversation_id
|
||||
and not is_local_history_conversation_id(conversation_id)
|
||||
and session.service_session_id != conversation_id
|
||||
):
|
||||
session.service_session_id = conversation_id
|
||||
return update
|
||||
|
||||
def _suppress_response_id(update: AgentResponseUpdate) -> AgentResponseUpdate:
|
||||
"""Hide raw service response ids when local per-service-call persistence owns continuation."""
|
||||
update.response_id = None
|
||||
return update
|
||||
|
||||
def _finalizer(updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]:
|
||||
ctx = ctx_holder["ctx"]
|
||||
rf = (
|
||||
ctx.get("chat_options", {}).get("response_format")
|
||||
if ctx
|
||||
else (options.get("response_format") if options else None) # type: ignore[union-attr]
|
||||
return self._finalize_response_updates(
|
||||
updates,
|
||||
response_format=context["chat_options"].get("response_format"),
|
||||
)
|
||||
return self._finalize_response_updates(updates, response_format=rf)
|
||||
|
||||
return (
|
||||
ResponseStream
|
||||
.from_awaitable(_get_stream()) # type: ignore[reportUnknownMemberType]
|
||||
.map(
|
||||
transform=partial(
|
||||
map_chat_to_agent_update,
|
||||
agent_name=self.name,
|
||||
),
|
||||
finalizer=_finalizer,
|
||||
)
|
||||
.with_transform_hook(_propagate_conversation_id)
|
||||
.with_result_hook(_post_hook)
|
||||
stream = stream_response.map(
|
||||
transform=partial(
|
||||
map_chat_to_agent_update,
|
||||
agent_name=self.name,
|
||||
),
|
||||
finalizer=_finalizer,
|
||||
)
|
||||
if context["suppress_response_id"]:
|
||||
stream = stream.with_transform_hook(_suppress_response_id)
|
||||
|
||||
return stream.with_transform_hook(_propagate_conversation_id).with_result_hook(_post_hook)
|
||||
|
||||
def _finalize_response_updates(
|
||||
self,
|
||||
@@ -1111,6 +1190,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
if active_session is None and self.context_providers:
|
||||
active_session = AgentSession()
|
||||
|
||||
per_service_call_history_providers = self._resolve_per_service_call_history_providers(
|
||||
session=active_session,
|
||||
options=opts,
|
||||
service_stores_history=bool(store_),
|
||||
)
|
||||
|
||||
session_context, chat_options = await self._prepare_session_and_messages(
|
||||
session=active_session,
|
||||
input_messages=input_messages,
|
||||
@@ -1191,6 +1276,43 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
if active_session is not None:
|
||||
effective_client_kwargs["session"] = active_session
|
||||
if per_service_call_history_providers and active_session is not None:
|
||||
per_service_call_history_middleware = PerServiceCallHistoryPersistingMiddleware(
|
||||
agent=self,
|
||||
session=active_session,
|
||||
providers=per_service_call_history_providers,
|
||||
)
|
||||
existing_middleware = effective_client_kwargs.get("middleware")
|
||||
if isinstance(existing_middleware, Sequence) and not isinstance(existing_middleware, (str, bytes)):
|
||||
effective_client_kwargs["middleware"] = [per_service_call_history_middleware, *existing_middleware]
|
||||
elif existing_middleware is not None:
|
||||
effective_client_kwargs["middleware"] = [
|
||||
per_service_call_history_middleware,
|
||||
cast(MiddlewareTypes, existing_middleware),
|
||||
]
|
||||
else:
|
||||
effective_client_kwargs["middleware"] = [per_service_call_history_middleware]
|
||||
provider_middleware = session_context.get_middleware()
|
||||
if provider_middleware:
|
||||
middleware_list = categorize_middleware(provider_middleware)
|
||||
provider_function_chat_middleware = [
|
||||
*middleware_list["function"],
|
||||
*middleware_list["chat"],
|
||||
]
|
||||
if provider_function_chat_middleware:
|
||||
existing_middleware = effective_client_kwargs.get("middleware")
|
||||
if isinstance(existing_middleware, Sequence) and not isinstance(existing_middleware, (str, bytes)):
|
||||
effective_client_kwargs["middleware"] = [
|
||||
*existing_middleware,
|
||||
*provider_function_chat_middleware,
|
||||
]
|
||||
elif existing_middleware is not None:
|
||||
effective_client_kwargs["middleware"] = [
|
||||
cast(MiddlewareTypes, existing_middleware),
|
||||
*provider_function_chat_middleware,
|
||||
]
|
||||
else:
|
||||
effective_client_kwargs["middleware"] = provider_function_chat_middleware
|
||||
|
||||
return {
|
||||
"session": active_session,
|
||||
@@ -1198,6 +1320,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
"input_messages": input_messages,
|
||||
"session_messages": session_messages,
|
||||
"agent_name": agent_name,
|
||||
"suppress_response_id": bool(per_service_call_history_providers),
|
||||
"chat_options": co,
|
||||
"compaction_strategy": compaction_strategy or self.compaction_strategy,
|
||||
"tokenizer": tokenizer or self.tokenizer,
|
||||
@@ -1211,6 +1334,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
agent_name: str,
|
||||
session: AgentSession | None,
|
||||
session_context: SessionContext,
|
||||
suppress_response_id: bool = False,
|
||||
) -> None:
|
||||
"""Finalize response by setting author names and running after_run providers.
|
||||
|
||||
@@ -1219,6 +1343,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
agent_name: The name of the agent to set as author.
|
||||
session: The conversation session.
|
||||
session_context: The invocation context.
|
||||
suppress_response_id: When True, omit the raw service response ID from the public response.
|
||||
"""
|
||||
# Ensure that the author name is set for each message in the response.
|
||||
for message in response.messages:
|
||||
@@ -1228,13 +1353,18 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
# Propagate conversation_id back to session (e.g. thread ID from Assistants API).
|
||||
# For Responses-style APIs this can rotate every turn (response_id-based continuation),
|
||||
# so refresh when a newer value is returned.
|
||||
if session and response.conversation_id and session.service_session_id != response.conversation_id:
|
||||
if (
|
||||
session
|
||||
and response.conversation_id
|
||||
and not is_local_history_conversation_id(response.conversation_id)
|
||||
and session.service_session_id != response.conversation_id
|
||||
):
|
||||
session.service_session_id = response.conversation_id
|
||||
|
||||
# Set the response on the context for after_run providers
|
||||
session_context._response = AgentResponse( # type: ignore[assignment]
|
||||
messages=response.messages,
|
||||
response_id=response.response_id,
|
||||
response_id=None if suppress_response_id else response.response_id,
|
||||
)
|
||||
|
||||
# Run after_run providers (reverse order)
|
||||
@@ -1284,9 +1414,15 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options=options or {},
|
||||
)
|
||||
|
||||
# Run before_run providers (forward order, skip BaseHistoryProvider with load_messages=False)
|
||||
per_service_call_history_required = self.require_per_service_call_history_persistence and bool(
|
||||
self._get_history_providers()
|
||||
)
|
||||
|
||||
# Run before_run providers (forward order, skip HistoryProvider when per-service-call persistence owns history)
|
||||
for provider in self.context_providers:
|
||||
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
|
||||
if per_service_call_history_required and isinstance(provider, HistoryProvider):
|
||||
continue
|
||||
if isinstance(provider, HistoryProvider) and not provider.load_messages:
|
||||
continue
|
||||
if provider_session is None:
|
||||
raise RuntimeError("Provider session must be available when context providers are configured.")
|
||||
@@ -1551,8 +1687,9 @@ class Agent(
|
||||
description: str | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
default_options: OptionsCoT | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
require_per_service_call_history_persistence: bool = False,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: MutableMapping[str, Any] | None = None,
|
||||
@@ -1568,6 +1705,7 @@ class Agent(
|
||||
default_options=default_options,
|
||||
context_providers=context_providers,
|
||||
middleware=middleware,
|
||||
require_per_service_call_history_persistence=require_per_service_call_history_persistence,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=additional_properties,
|
||||
|
||||
@@ -572,6 +572,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
default_options: OptionsCoT | Mapping[str, Any] | None = None,
|
||||
context_providers: Sequence[Any] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
require_per_service_call_history_persistence: bool = False,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
@@ -596,6 +597,10 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
and dict literals are accepted without specialized option typing.
|
||||
context_providers: Context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
require_per_service_call_history_persistence: Whether to require per-service-call
|
||||
chat history persistence. When enabled, history providers are invoked around
|
||||
each model call instead of once per ``run()`` when the service is not already
|
||||
storing history.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
compaction_strategy: Optional agent-level compaction override. When omitted,
|
||||
client-level compaction defaults remain in effect for each call.
|
||||
@@ -636,6 +641,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
"default_options": cast(Any, default_options),
|
||||
"context_providers": context_providers,
|
||||
"middleware": middleware,
|
||||
"require_per_service_call_history_persistence": require_per_service_call_history_persistence,
|
||||
"compaction_strategy": compaction_strategy,
|
||||
"tokenizer": tokenizer,
|
||||
"additional_properties": dict(additional_properties) if additional_properties is not None else None,
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import (
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from ._sessions import BaseContextProvider
|
||||
from ._sessions import ContextProvider
|
||||
from ._types import ChatResponse, Content, Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -1152,7 +1152,7 @@ async def apply_compaction(
|
||||
COMPACTION_STATE_KEY: Final[str] = "_compaction_messages"
|
||||
|
||||
|
||||
class CompactionProvider(BaseContextProvider):
|
||||
class CompactionProvider(ContextProvider):
|
||||
"""Context provider that compacts messages before and after agent runs.
|
||||
|
||||
This provider accepts two separate strategies:
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
|
||||
This module provides the core types for the context provider pipeline:
|
||||
- SessionContext: Per-invocation state passed through providers
|
||||
- BaseContextProvider: Base class for context providers (renamed to ContextProvider in PR2)
|
||||
- BaseHistoryProvider: Base class for history storage providers (renamed to HistoryProvider in PR2)
|
||||
- ContextProvider: Base class for context providers
|
||||
- HistoryProvider: Base class for history storage providers
|
||||
- AgentSession: Lightweight session state container
|
||||
- InMemoryHistoryProvider: Built-in in-memory history provider
|
||||
"""
|
||||
@@ -13,21 +13,42 @@ This module provides the core types for the context provider pipeline:
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import sys
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, cast
|
||||
|
||||
from ._types import AgentResponse, Message
|
||||
if sys.version_info >= (3, 13):
|
||||
from warnings import deprecated # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import deprecated # type: ignore # pragma: no cover
|
||||
|
||||
from ._middleware import ChatContext, ChatMiddleware
|
||||
from ._types import AgentResponse, ChatResponse, Message, ResponseStream
|
||||
from .exceptions import ChatClientInvalidResponseException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._agents import SupportsAgentRun
|
||||
from ._middleware import MiddlewareTypes
|
||||
|
||||
|
||||
# Registry of known types for state deserialization
|
||||
_STATE_TYPE_REGISTRY: dict[str, type] = {}
|
||||
|
||||
|
||||
def _is_middleware_sequence(
|
||||
middleware: MiddlewareTypes | Sequence[MiddlewareTypes],
|
||||
) -> TypeGuard[Sequence[MiddlewareTypes]]:
|
||||
return isinstance(middleware, Sequence) and not isinstance(middleware, (str, bytes))
|
||||
|
||||
|
||||
def _is_single_middleware(
|
||||
middleware: MiddlewareTypes | Sequence[MiddlewareTypes],
|
||||
) -> TypeGuard[MiddlewareTypes]:
|
||||
return not _is_middleware_sequence(middleware)
|
||||
|
||||
|
||||
def register_state_type(cls: type) -> None:
|
||||
"""Register a type for automatic deserialization in session state.
|
||||
|
||||
@@ -131,6 +152,8 @@ class SessionContext:
|
||||
Maintains insertion order (provider execution order).
|
||||
instructions: Additional instructions added by providers.
|
||||
tools: Additional tools added by providers.
|
||||
middleware: Dict mapping source_id -> chat/function middleware added by that provider.
|
||||
Maintains insertion order (provider execution order).
|
||||
response: After invocation, contains the full AgentResponse, should not be changed.
|
||||
options: Options passed to agent.run() - read-only, for reflection only.
|
||||
metadata: Shared metadata dictionary for cross-provider communication.
|
||||
@@ -145,6 +168,7 @@ class SessionContext:
|
||||
context_messages: dict[str, list[Message]] | None = None,
|
||||
instructions: list[str] | None = None,
|
||||
tools: list[Any] | None = None,
|
||||
middleware: dict[str, list[MiddlewareTypes]] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
@@ -157,6 +181,7 @@ class SessionContext:
|
||||
context_messages: Pre-populated context messages by source.
|
||||
instructions: Pre-populated instructions.
|
||||
tools: Pre-populated tools.
|
||||
middleware: Pre-populated chat/function middleware by source.
|
||||
options: Options from agent.run() - read-only for providers.
|
||||
metadata: Shared metadata for cross-provider communication.
|
||||
"""
|
||||
@@ -166,6 +191,10 @@ class SessionContext:
|
||||
self.context_messages: dict[str, list[Message]] = context_messages or {}
|
||||
self.instructions: list[str] = instructions or []
|
||||
self.tools: list[Any] = tools or []
|
||||
self.middleware: dict[str, list[MiddlewareTypes]] = {}
|
||||
if middleware:
|
||||
for source_id, provider_middleware in middleware.items():
|
||||
self.extend_middleware(source_id, provider_middleware)
|
||||
self._response: AgentResponse | None = None
|
||||
self.options: dict[str, Any] = options or {}
|
||||
self.metadata: dict[str, Any] = metadata or {}
|
||||
@@ -236,6 +265,40 @@ class SessionContext:
|
||||
additional_properties["context_source"] = source_id
|
||||
self.tools.extend(tools)
|
||||
|
||||
def extend_middleware(
|
||||
self,
|
||||
source_id: str,
|
||||
middleware: MiddlewareTypes | Sequence[MiddlewareTypes],
|
||||
) -> None:
|
||||
"""Add middleware to be applied for this invocation.
|
||||
|
||||
Args:
|
||||
source_id: The provider source_id adding this middleware.
|
||||
middleware: A single chat/function middleware object/callable or sequence of middleware.
|
||||
"""
|
||||
from ._middleware import categorize_middleware
|
||||
from .exceptions import MiddlewareException
|
||||
|
||||
if _is_middleware_sequence(middleware):
|
||||
middleware_items = list(middleware)
|
||||
elif _is_single_middleware(middleware):
|
||||
middleware_items = [middleware]
|
||||
else:
|
||||
raise TypeError("middleware must be a middleware object or a sequence of middleware objects.")
|
||||
middleware_list = categorize_middleware(middleware_items)
|
||||
if middleware_list["agent"]:
|
||||
raise MiddlewareException("Context providers may only add chat or function middleware.")
|
||||
if source_id not in self.middleware:
|
||||
self.middleware[source_id] = []
|
||||
self.middleware[source_id].extend(middleware_items)
|
||||
|
||||
def get_middleware(self) -> list[MiddlewareTypes]:
|
||||
"""Get provider-added chat/function middleware in provider execution order."""
|
||||
result: list[MiddlewareTypes] = []
|
||||
for middleware_items in self.middleware.values():
|
||||
result.extend(middleware_items)
|
||||
return result
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
*,
|
||||
@@ -272,17 +335,12 @@ class SessionContext:
|
||||
return result
|
||||
|
||||
|
||||
class BaseContextProvider:
|
||||
"""Base class for context providers (hooks pattern).
|
||||
class ContextProvider:
|
||||
"""Base class for context providers.
|
||||
|
||||
Context providers participate in the context engineering pipeline,
|
||||
adding context before model invocation and processing responses after.
|
||||
|
||||
Note:
|
||||
This class uses a temporary name prefixed with ``_`` to avoid collision
|
||||
with the existing ``ContextProvider`` in ``_memory.py``. It will be
|
||||
renamed to ``ContextProvider`` in PR2 when the old class is removed.
|
||||
|
||||
Attributes:
|
||||
source_id: Unique identifier for this provider instance (required).
|
||||
Used for message/tool attribution so other providers can filter.
|
||||
@@ -312,7 +370,7 @@ class BaseContextProvider:
|
||||
Args:
|
||||
agent: The agent running this invocation.
|
||||
session: The current session.
|
||||
context: The invocation context - add messages/instructions/tools here.
|
||||
context: The invocation context - add messages/instructions/tools/chat/function middleware here.
|
||||
state: The provider-scoped mutable state dict for this provider.
|
||||
Full cross-provider state remains available at ``session.state``.
|
||||
"""
|
||||
@@ -339,7 +397,7 @@ class BaseContextProvider:
|
||||
"""
|
||||
|
||||
|
||||
class BaseHistoryProvider(BaseContextProvider):
|
||||
class HistoryProvider(ContextProvider):
|
||||
"""Base class for conversation history storage providers.
|
||||
|
||||
A single class configurable for different use cases:
|
||||
@@ -347,10 +405,6 @@ class BaseHistoryProvider(BaseContextProvider):
|
||||
- Audit/logging storage (stores only, doesn't load)
|
||||
- Evaluation storage (stores only for later analysis)
|
||||
|
||||
Note:
|
||||
This class uses a temporary name prefixed with ``_`` to avoid collision
|
||||
with existing types. It will be renamed to ``HistoryProvider`` in PR2.
|
||||
|
||||
Subclasses only need to implement ``get_messages()`` and ``save_messages()``.
|
||||
The default ``before_run``/``after_run`` handle loading and storing based on
|
||||
configuration flags. Override them for custom behavior.
|
||||
@@ -467,6 +521,207 @@ class BaseHistoryProvider(BaseContextProvider):
|
||||
await self.save_messages(context.session_id, messages_to_store, state=state)
|
||||
|
||||
|
||||
LOCAL_HISTORY_CONVERSATION_ID = "agent_framework_local_history_persistence"
|
||||
|
||||
|
||||
def is_local_history_conversation_id(conversation_id: str | None) -> bool:
|
||||
"""Return whether a conversation id is the local history-persistence sentinel."""
|
||||
return conversation_id == LOCAL_HISTORY_CONVERSATION_ID
|
||||
|
||||
|
||||
def _response_contains_follow_up_request(response: ChatResponse) -> bool:
|
||||
"""Return whether a response requires another model call in the current run."""
|
||||
return any(
|
||||
item.type in {"function_call", "function_approval_request"}
|
||||
for message in response.messages
|
||||
for item in message.contents
|
||||
)
|
||||
|
||||
|
||||
def _split_service_call_messages(messages: Sequence[Message]) -> tuple[list[Message], dict[str, list[Message]]]:
|
||||
"""Split service-call messages into input messages and attributed context messages."""
|
||||
input_messages: list[Message] = []
|
||||
context_messages: dict[str, list[Message]] = {}
|
||||
for message in messages:
|
||||
attribution = message.additional_properties.get("_attribution")
|
||||
if isinstance(attribution, Mapping):
|
||||
attribution_mapping = cast(Mapping[str, Any], attribution)
|
||||
source_id = attribution_mapping.get("source_id")
|
||||
if isinstance(source_id, str):
|
||||
context_messages.setdefault(source_id, []).append(message)
|
||||
continue
|
||||
input_messages.append(message)
|
||||
return input_messages, context_messages
|
||||
|
||||
|
||||
class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
|
||||
"""Persist local chat history after each service call when history is framework-managed.
|
||||
|
||||
This middleware runs around each model call when
|
||||
``require_per_service_call_history_persistence`` is enabled. It loads history providers
|
||||
before the model call, persists them after the model call, and uses a local
|
||||
sentinel conversation id so the function loop follows the existing
|
||||
service-managed branch without forwarding that sentinel to the leaf client.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent: SupportsAgentRun,
|
||||
session: AgentSession,
|
||||
providers: Sequence[HistoryProvider],
|
||||
) -> None:
|
||||
"""Initialize the middleware.
|
||||
|
||||
Args:
|
||||
agent: The agent that owns the history providers.
|
||||
session: The active session for the current run.
|
||||
providers: The history providers participating in per-service-call persistence.
|
||||
"""
|
||||
self._agent = agent
|
||||
self._session = session
|
||||
self._providers = list(providers)
|
||||
|
||||
async def _prepare_service_call_context(self, messages: Sequence[Message]) -> SessionContext:
|
||||
"""Create a per-call SessionContext and load history providers into it."""
|
||||
input_messages, context_messages = _split_service_call_messages(messages)
|
||||
service_call_context = SessionContext(
|
||||
session_id=self._session.session_id,
|
||||
service_session_id=None,
|
||||
input_messages=list(input_messages),
|
||||
)
|
||||
for source_id, source_messages in context_messages.items():
|
||||
service_call_context.extend_messages(source_id, source_messages)
|
||||
for provider in self._providers:
|
||||
if not provider.load_messages:
|
||||
continue
|
||||
await provider.before_run(
|
||||
agent=self._agent,
|
||||
session=self._session,
|
||||
context=service_call_context,
|
||||
state=self._session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
return service_call_context
|
||||
|
||||
async def _persist_service_call_response(
|
||||
self,
|
||||
*,
|
||||
service_call_context: SessionContext,
|
||||
response: ChatResponse,
|
||||
) -> None:
|
||||
"""Persist a single model-call response through the configured history providers."""
|
||||
service_call_context._response = AgentResponse( # type: ignore[assignment]
|
||||
messages=response.messages,
|
||||
response_id=None,
|
||||
)
|
||||
for provider in reversed(self._providers):
|
||||
await provider.after_run(
|
||||
agent=self._agent,
|
||||
session=self._session,
|
||||
context=service_call_context,
|
||||
state=self._session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
|
||||
def _strip_local_conversation_id(self, context: ChatContext) -> None:
|
||||
"""Remove the local sentinel before the leaf chat client is invoked."""
|
||||
if is_local_history_conversation_id(cast(str | None, context.kwargs.get("conversation_id"))):
|
||||
context.kwargs.pop("conversation_id", None)
|
||||
|
||||
if context.options is None:
|
||||
return
|
||||
|
||||
mutable_options = dict(context.options)
|
||||
if is_local_history_conversation_id(cast(str | None, mutable_options.get("conversation_id"))):
|
||||
mutable_options.pop("conversation_id", None)
|
||||
context.options = mutable_options
|
||||
|
||||
async def _finalize_response(
|
||||
self,
|
||||
*,
|
||||
service_call_context: SessionContext,
|
||||
response: ChatResponse,
|
||||
) -> ChatResponse:
|
||||
"""Persist a model response and apply the local follow-up sentinel when needed."""
|
||||
if response.conversation_id is not None and not is_local_history_conversation_id(response.conversation_id):
|
||||
raise ChatClientInvalidResponseException(
|
||||
"require_per_service_call_history_persistence cannot be used "
|
||||
"when the chat client returns a real conversation_id."
|
||||
)
|
||||
|
||||
await self._persist_service_call_response(
|
||||
service_call_context=service_call_context,
|
||||
response=response,
|
||||
)
|
||||
if _response_contains_follow_up_request(response):
|
||||
response.mark_internal_conversation_id()
|
||||
response.conversation_id = LOCAL_HISTORY_CONVERSATION_ID
|
||||
return response
|
||||
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
"""Load and persist history providers around a single model call.
|
||||
|
||||
Args:
|
||||
context: The chat invocation context for the current model call.
|
||||
call_next: The next middleware or the leaf chat client.
|
||||
|
||||
Raises:
|
||||
ChatClientInvalidResponseException: If the leaf client returns a real
|
||||
service-managed conversation id while local per-service-call persistence is enabled.
|
||||
ValueError: If the downstream middleware contract returns the wrong
|
||||
result type for streaming or non-streaming execution.
|
||||
"""
|
||||
service_call_context = await self._prepare_service_call_context(context.messages)
|
||||
context.messages = service_call_context.get_messages(include_input=True)
|
||||
self._strip_local_conversation_id(context)
|
||||
|
||||
await call_next()
|
||||
|
||||
if context.result is None:
|
||||
return
|
||||
|
||||
if context.stream:
|
||||
if not isinstance(context.result, ResponseStream):
|
||||
raise ValueError("Streaming chat middleware requires a ResponseStream result.")
|
||||
context.result = context.result.with_result_hook(
|
||||
lambda response: self._finalize_response(
|
||||
service_call_context=service_call_context,
|
||||
response=response,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(context.result, ResponseStream):
|
||||
raise ValueError("Non-streaming chat middleware requires a ChatResponse result.")
|
||||
context.result = await self._finalize_response(
|
||||
service_call_context=service_call_context,
|
||||
response=context.result,
|
||||
)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"BaseContextProvider is deprecated. Use ContextProvider instead.",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
class BaseContextProvider(ContextProvider):
|
||||
"""Deprecated alias for :class:`ContextProvider`.
|
||||
|
||||
.. deprecated::
|
||||
BaseContextProvider is deprecated. Use :class:`ContextProvider` instead.
|
||||
"""
|
||||
|
||||
|
||||
@deprecated(
|
||||
"BaseHistoryProvider is deprecated. Use HistoryProvider instead.",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
class BaseHistoryProvider(HistoryProvider):
|
||||
"""Deprecated alias for :class:`HistoryProvider`.
|
||||
|
||||
.. deprecated::
|
||||
BaseHistoryProvider is deprecated. Use :class:`HistoryProvider` instead.
|
||||
"""
|
||||
|
||||
|
||||
class AgentSession:
|
||||
"""A conversation session with an agent.
|
||||
|
||||
@@ -535,7 +790,7 @@ class AgentSession:
|
||||
return session
|
||||
|
||||
|
||||
class InMemoryHistoryProvider(BaseHistoryProvider):
|
||||
class InMemoryHistoryProvider(HistoryProvider):
|
||||
"""Built-in history provider that stores messages in session.state.
|
||||
|
||||
Messages are stored in ``state["messages"]`` as a list of
|
||||
|
||||
@@ -36,7 +36,7 @@ from pathlib import Path, PurePosixPath
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Final, Protocol, runtime_checkable
|
||||
|
||||
from ._feature_stage import ExperimentalFeature, experimental
|
||||
from ._sessions import BaseContextProvider
|
||||
from ._sessions import ContextProvider
|
||||
from ._tools import FunctionTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -519,7 +519,7 @@ SCRIPT_RUNNER_INSTRUCTIONS: Final[str] = (
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.SKILLS)
|
||||
class SkillsProvider(BaseContextProvider):
|
||||
class SkillsProvider(ContextProvider):
|
||||
"""Context provider that advertises skills and exposes skill tools.
|
||||
|
||||
Supports both **file-based** skills (discovered from ``SKILL.md`` files)
|
||||
|
||||
@@ -1688,6 +1688,34 @@ def _update_conversation_id(
|
||||
options["conversation_id"] = conversation_id
|
||||
|
||||
|
||||
def _update_continuation_state(
|
||||
kwargs: dict[str, Any],
|
||||
response: ChatResponse[Any],
|
||||
*,
|
||||
session: AgentSession | None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Update in-flight and persisted continuation state from a response."""
|
||||
conversation_id = response.conversation_id
|
||||
if conversation_id is None:
|
||||
return
|
||||
|
||||
_update_conversation_id(kwargs, conversation_id, options)
|
||||
if (
|
||||
session is not None
|
||||
and not response.has_internal_conversation_id()
|
||||
and session.service_session_id != conversation_id
|
||||
):
|
||||
session.service_session_id = conversation_id
|
||||
|
||||
|
||||
def _clear_internal_conversation_id(response: ChatResponse[Any]) -> ChatResponse[Any]:
|
||||
if response.has_internal_conversation_id():
|
||||
response.conversation_id = None
|
||||
response.clear_internal_conversation_id()
|
||||
return response
|
||||
|
||||
|
||||
def _extract_tools(
|
||||
options: dict[str, Any] | None,
|
||||
) -> ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None:
|
||||
@@ -2206,9 +2234,14 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
),
|
||||
)
|
||||
aggregated_usage = add_usage_details(aggregated_usage, response.usage_details)
|
||||
_update_continuation_state(
|
||||
filtered_kwargs,
|
||||
response,
|
||||
session=invocation_session,
|
||||
options=mutable_options,
|
||||
)
|
||||
|
||||
if response.conversation_id is not None:
|
||||
_update_conversation_id(filtered_kwargs, response.conversation_id, mutable_options)
|
||||
prepped_messages = []
|
||||
|
||||
result = await _process_function_requests(
|
||||
@@ -2223,7 +2256,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
)
|
||||
if result.get("action") == "return":
|
||||
response.usage_details = aggregated_usage
|
||||
return response
|
||||
return _clear_internal_conversation_id(response)
|
||||
total_function_calls += result.get("function_call_count", 0)
|
||||
if result.get("action") == "stop":
|
||||
# Error threshold reached: force a final non-tool turn so
|
||||
@@ -2279,11 +2312,17 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
),
|
||||
)
|
||||
aggregated_usage = add_usage_details(aggregated_usage, response.usage_details)
|
||||
_update_continuation_state(
|
||||
filtered_kwargs,
|
||||
response,
|
||||
session=invocation_session,
|
||||
options=mutable_options,
|
||||
)
|
||||
response.usage_details = aggregated_usage
|
||||
if fcc_messages:
|
||||
for msg in reversed(fcc_messages):
|
||||
response.messages.insert(0, msg)
|
||||
return response
|
||||
return _clear_internal_conversation_id(response)
|
||||
|
||||
return _get_response()
|
||||
|
||||
@@ -2343,6 +2382,12 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
# Get the finalized response from the inner stream
|
||||
# This triggers the inner stream's finalizer and result hooks
|
||||
response = await inner_stream.get_final_response()
|
||||
_update_continuation_state(
|
||||
filtered_kwargs,
|
||||
response,
|
||||
session=invocation_session,
|
||||
options=mutable_options,
|
||||
)
|
||||
|
||||
if not any(
|
||||
item.type in ("function_call", "function_approval_request")
|
||||
@@ -2352,7 +2397,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
return
|
||||
|
||||
if response.conversation_id is not None:
|
||||
_update_conversation_id(filtered_kwargs, response.conversation_id, mutable_options)
|
||||
prepped_messages = []
|
||||
|
||||
result = await _process_function_requests(
|
||||
@@ -2430,7 +2474,13 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
async for update in final_inner_stream:
|
||||
yield update
|
||||
# Finalize the inner stream to trigger its hooks
|
||||
await final_inner_stream.get_final_response()
|
||||
final_response = await final_inner_stream.get_final_response()
|
||||
_update_continuation_state(
|
||||
filtered_kwargs,
|
||||
final_response,
|
||||
session=invocation_session,
|
||||
options=mutable_options,
|
||||
)
|
||||
|
||||
def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse[Any]:
|
||||
# Note: stream_result_hooks are already run via inner stream's get_final_response()
|
||||
|
||||
@@ -2001,6 +2001,7 @@ class ChatResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
"""
|
||||
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"raw_representation", "additional_properties"}
|
||||
_INTERNAL_CONVERSATION_ID_KEY: ClassVar[str] = "_agent_framework_internal_conversation_id"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -2069,6 +2070,18 @@ class ChatResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
self.continuation_token = continuation_token
|
||||
self.raw_representation: Any | list[Any] | None = raw_representation
|
||||
|
||||
def mark_internal_conversation_id(self) -> None:
|
||||
"""Mark the current conversation_id as internal control-flow state."""
|
||||
self.additional_properties[self._INTERNAL_CONVERSATION_ID_KEY] = True
|
||||
|
||||
def clear_internal_conversation_id(self) -> None:
|
||||
"""Remove the internal conversation-id marker."""
|
||||
self.additional_properties.pop(self._INTERNAL_CONVERSATION_ID_KEY, None)
|
||||
|
||||
def has_internal_conversation_id(self) -> bool:
|
||||
"""Return whether conversation_id is internal control-flow state."""
|
||||
return bool(self.additional_properties.get(self._INTERNAL_CONVERSATION_ID_KEY, False))
|
||||
|
||||
@property
|
||||
def model_id(self) -> str | None:
|
||||
"""Deprecated alias for :attr:`model`."""
|
||||
|
||||
@@ -14,8 +14,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
|
||||
from .._agents import BaseAgent
|
||||
from .._sessions import (
|
||||
AgentSession,
|
||||
BaseContextProvider,
|
||||
BaseHistoryProvider,
|
||||
ContextProvider,
|
||||
HistoryProvider,
|
||||
InMemoryHistoryProvider,
|
||||
SessionContext,
|
||||
)
|
||||
@@ -86,7 +86,7 @@ class WorkflowAgent(BaseAgent):
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the WorkflowAgent.
|
||||
@@ -249,7 +249,7 @@ class WorkflowAgent(BaseAgent):
|
||||
options={},
|
||||
)
|
||||
for provider in self.context_providers:
|
||||
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
|
||||
if isinstance(provider, HistoryProvider) and not provider.load_messages:
|
||||
continue
|
||||
if provider_session is None:
|
||||
raise RuntimeError("Provider session must be available when context providers are configured.")
|
||||
@@ -314,7 +314,7 @@ class WorkflowAgent(BaseAgent):
|
||||
options={},
|
||||
)
|
||||
for provider in self.context_providers:
|
||||
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
|
||||
if isinstance(provider, HistoryProvider) and not provider.load_messages:
|
||||
continue
|
||||
if provider_session is None:
|
||||
raise RuntimeError("Provider session must be available when context providers are configured.")
|
||||
|
||||
@@ -1502,6 +1502,161 @@ class AgentTelemetryLayer:
|
||||
self.token_usage_histogram = _get_token_usage_histogram()
|
||||
self.duration_histogram = _get_duration_histogram()
|
||||
|
||||
def _trace_agent_invocation(
|
||||
self,
|
||||
*,
|
||||
messages: AgentRunInputs | None,
|
||||
session: AgentSession | None,
|
||||
merged_options: Mapping[str, Any],
|
||||
client_kwargs: Mapping[str, Any] | None,
|
||||
stream: bool,
|
||||
execute: Callable[[], Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]],
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Trace an agent invocation while delegating execution to ``execute``."""
|
||||
global OBSERVABILITY_SETTINGS
|
||||
from ._types import ResponseStream
|
||||
|
||||
if not OBSERVABILITY_SETTINGS.ENABLED:
|
||||
return execute()
|
||||
|
||||
provider_name = str(self.otel_provider_name)
|
||||
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
attributes = _get_span_attributes(
|
||||
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
|
||||
provider_name=provider_name,
|
||||
agent_id=getattr(self, "id", "unknown"),
|
||||
agent_name=getattr(self, "name", None) or getattr(self, "id", "unknown"),
|
||||
agent_description=getattr(self, "description", None),
|
||||
thread_id=session.service_session_id if session else None,
|
||||
all_options=dict(merged_options),
|
||||
**merged_client_kwargs,
|
||||
)
|
||||
|
||||
inner_response_telemetry_captured_fields: set[str] = set()
|
||||
inner_response_telemetry_captured_fields_token = INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.set(
|
||||
inner_response_telemetry_captured_fields
|
||||
)
|
||||
inner_accumulated_usage_token = INNER_ACCUMULATED_USAGE.set({})
|
||||
|
||||
if stream:
|
||||
try:
|
||||
run_result: object = execute()
|
||||
if isinstance(run_result, ResponseStream):
|
||||
result_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType]
|
||||
elif isinstance(run_result, Awaitable):
|
||||
result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
else:
|
||||
raise RuntimeError("Streaming telemetry requires a ResponseStream result.")
|
||||
except Exception:
|
||||
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
|
||||
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
|
||||
raise
|
||||
|
||||
operation = attributes.get(OtelAttr.OPERATION, "operation")
|
||||
span_name = attributes.get(OtelAttr.AGENT_NAME, "unknown")
|
||||
span = get_tracer().start_span(f"{operation} {span_name}")
|
||||
span.set_attributes(attributes)
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
|
||||
_capture_messages(
|
||||
span=span,
|
||||
provider_name=provider_name,
|
||||
messages=messages,
|
||||
system_instructions=_get_instructions_from_options(dict(merged_options)),
|
||||
)
|
||||
|
||||
span_state = {"closed": False}
|
||||
duration_state: dict[str, float] = {}
|
||||
start_time = perf_counter()
|
||||
|
||||
def _close_span() -> None:
|
||||
if span_state["closed"]:
|
||||
return
|
||||
span_state["closed"] = True
|
||||
span.end()
|
||||
|
||||
def _record_duration() -> None:
|
||||
duration_state["duration"] = perf_counter() - start_time
|
||||
|
||||
async def _finalize_stream() -> None:
|
||||
from ._types import AgentResponse
|
||||
|
||||
try:
|
||||
response: AgentResponse[Any] = await result_stream.get_final_response()
|
||||
duration = duration_state.get("duration")
|
||||
response_attributes = _get_response_attributes(
|
||||
attributes,
|
||||
response,
|
||||
capture_response_id=INNER_RESPONSE_ID_CAPTURED_FIELD
|
||||
not in inner_response_telemetry_captured_fields,
|
||||
capture_usage=INNER_USAGE_CAPTURED_FIELD not in inner_response_telemetry_captured_fields,
|
||||
)
|
||||
_apply_accumulated_usage(response_attributes, inner_response_telemetry_captured_fields)
|
||||
_capture_response(span=span, attributes=response_attributes, duration=duration)
|
||||
if (
|
||||
OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED
|
||||
and isinstance(response, AgentResponse)
|
||||
and response.messages
|
||||
):
|
||||
_capture_messages(
|
||||
span=span,
|
||||
provider_name=provider_name,
|
||||
messages=response.messages,
|
||||
output=True,
|
||||
)
|
||||
except Exception as exception:
|
||||
capture_exception(span=span, exception=exception, timestamp=time_ns())
|
||||
finally:
|
||||
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
|
||||
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
|
||||
_close_span()
|
||||
|
||||
wrapped_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = result_stream.with_cleanup_hook(
|
||||
_record_duration
|
||||
).with_cleanup_hook(_finalize_stream)
|
||||
weakref.finalize(wrapped_stream, _close_span)
|
||||
return wrapped_stream
|
||||
|
||||
async def _run() -> AgentResponse[Any]:
|
||||
try:
|
||||
with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span:
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
|
||||
_capture_messages(
|
||||
span=span,
|
||||
provider_name=provider_name,
|
||||
messages=messages,
|
||||
system_instructions=_get_instructions_from_options(dict(merged_options)),
|
||||
)
|
||||
start_time_stamp = perf_counter()
|
||||
try:
|
||||
response: AgentResponse[Any] = await execute()
|
||||
except Exception as exception:
|
||||
capture_exception(span=span, exception=exception, timestamp=time_ns())
|
||||
raise
|
||||
duration = perf_counter() - start_time_stamp
|
||||
if response:
|
||||
response_attributes = _get_response_attributes(
|
||||
attributes,
|
||||
response,
|
||||
capture_response_id=INNER_RESPONSE_ID_CAPTURED_FIELD
|
||||
not in inner_response_telemetry_captured_fields,
|
||||
capture_usage=INNER_USAGE_CAPTURED_FIELD not in inner_response_telemetry_captured_fields,
|
||||
)
|
||||
_apply_accumulated_usage(response_attributes, inner_response_telemetry_captured_fields)
|
||||
_capture_response(span=span, attributes=response_attributes, duration=duration)
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages:
|
||||
_capture_messages(
|
||||
span=span,
|
||||
provider_name=provider_name,
|
||||
messages=response.messages,
|
||||
output=True,
|
||||
)
|
||||
return response # type: ignore[return-value,no-any-return]
|
||||
finally:
|
||||
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
|
||||
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
|
||||
|
||||
return _run()
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
@@ -1565,14 +1720,12 @@ class AgentTelemetryLayer:
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Trace agent runs with OpenTelemetry spans and metrics."""
|
||||
global OBSERVABILITY_SETTINGS
|
||||
from ._types import ResponseStream, merge_chat_options
|
||||
from ._types import merge_chat_options
|
||||
|
||||
super_run = cast(
|
||||
"Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]",
|
||||
super().run, # type: ignore[misc]
|
||||
)
|
||||
provider_name = str(self.otel_provider_name)
|
||||
super_run_kwargs: dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
@@ -1586,156 +1739,21 @@ class AgentTelemetryLayer:
|
||||
}
|
||||
if middleware is not None:
|
||||
super_run_kwargs["middleware"] = middleware
|
||||
if not OBSERVABILITY_SETTINGS.ENABLED:
|
||||
return super_run(**super_run_kwargs) # type: ignore[no-any-return]
|
||||
|
||||
default_options = dict(getattr(self, "default_options", {}))
|
||||
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
merged_options: dict[str, Any] = merge_chat_options(
|
||||
default_options, dict(options) if options is not None else {}
|
||||
)
|
||||
attributes = _get_span_attributes(
|
||||
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
|
||||
provider_name=provider_name,
|
||||
agent_id=getattr(self, "id", "unknown"),
|
||||
agent_name=getattr(self, "name", None) or getattr(self, "id", "unknown"),
|
||||
agent_description=getattr(self, "description", None),
|
||||
thread_id=session.service_session_id if session else None,
|
||||
all_options=merged_options,
|
||||
**merged_client_kwargs,
|
||||
return self._trace_agent_invocation(
|
||||
messages=messages,
|
||||
session=session,
|
||||
merged_options=merged_options,
|
||||
client_kwargs=merged_client_kwargs,
|
||||
stream=stream,
|
||||
execute=lambda: super_run(**super_run_kwargs),
|
||||
)
|
||||
|
||||
inner_response_telemetry_captured_fields: set[str] = set()
|
||||
inner_response_telemetry_captured_fields_token = INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.set(
|
||||
inner_response_telemetry_captured_fields
|
||||
)
|
||||
inner_accumulated_usage_token = INNER_ACCUMULATED_USAGE.set({})
|
||||
|
||||
if stream:
|
||||
try:
|
||||
run_result: object = super_run(**super_run_kwargs)
|
||||
if isinstance(run_result, ResponseStream):
|
||||
result_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType]
|
||||
elif isinstance(run_result, Awaitable):
|
||||
result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
else:
|
||||
raise RuntimeError("Streaming telemetry requires a ResponseStream result.")
|
||||
except Exception:
|
||||
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
|
||||
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
|
||||
raise
|
||||
|
||||
# Create span directly without trace.use_span() context attachment.
|
||||
# Streaming spans are closed asynchronously in cleanup hooks, which run
|
||||
# in a different async context than creation — using use_span() would
|
||||
# cause "Failed to detach context" errors from OpenTelemetry.
|
||||
operation = attributes.get(OtelAttr.OPERATION, "operation")
|
||||
span_name = attributes.get(OtelAttr.AGENT_NAME, "unknown")
|
||||
span = get_tracer().start_span(f"{operation} {span_name}")
|
||||
span.set_attributes(attributes)
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
|
||||
_capture_messages(
|
||||
span=span,
|
||||
provider_name=provider_name,
|
||||
messages=messages,
|
||||
system_instructions=_get_instructions_from_options(merged_options),
|
||||
)
|
||||
|
||||
span_state = {"closed": False}
|
||||
duration_state: dict[str, float] = {}
|
||||
start_time = perf_counter()
|
||||
|
||||
def _close_span() -> None:
|
||||
if span_state["closed"]:
|
||||
return
|
||||
span_state["closed"] = True
|
||||
span.end()
|
||||
|
||||
def _record_duration() -> None:
|
||||
duration_state["duration"] = perf_counter() - start_time
|
||||
|
||||
async def _finalize_stream() -> None:
|
||||
from ._types import AgentResponse
|
||||
|
||||
try:
|
||||
response: AgentResponse[Any] = await result_stream.get_final_response()
|
||||
duration = duration_state.get("duration")
|
||||
response_attributes = _get_response_attributes(
|
||||
attributes,
|
||||
response,
|
||||
capture_response_id=INNER_RESPONSE_ID_CAPTURED_FIELD
|
||||
not in inner_response_telemetry_captured_fields,
|
||||
capture_usage=INNER_USAGE_CAPTURED_FIELD not in inner_response_telemetry_captured_fields,
|
||||
)
|
||||
_apply_accumulated_usage(response_attributes, inner_response_telemetry_captured_fields)
|
||||
_capture_response(span=span, attributes=response_attributes, duration=duration)
|
||||
if (
|
||||
OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED
|
||||
and isinstance(response, AgentResponse)
|
||||
and response.messages
|
||||
):
|
||||
_capture_messages(
|
||||
span=span,
|
||||
provider_name=provider_name,
|
||||
messages=response.messages,
|
||||
output=True,
|
||||
)
|
||||
except Exception as exception:
|
||||
capture_exception(span=span, exception=exception, timestamp=time_ns())
|
||||
finally:
|
||||
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
|
||||
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
|
||||
_close_span()
|
||||
|
||||
# Register a weak reference callback to close the span if stream is garbage collected
|
||||
# without being consumed. This ensures spans don't leak if users don't consume streams.
|
||||
wrapped_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = result_stream.with_cleanup_hook(
|
||||
_record_duration
|
||||
).with_cleanup_hook(_finalize_stream)
|
||||
weakref.finalize(wrapped_stream, _close_span)
|
||||
return wrapped_stream
|
||||
|
||||
async def _run() -> AgentResponse:
|
||||
try:
|
||||
with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span:
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
|
||||
_capture_messages(
|
||||
span=span,
|
||||
provider_name=provider_name,
|
||||
messages=messages,
|
||||
system_instructions=_get_instructions_from_options(merged_options),
|
||||
)
|
||||
start_time_stamp = perf_counter()
|
||||
try:
|
||||
response: AgentResponse[Any] = await super_run(**super_run_kwargs)
|
||||
except Exception as exception:
|
||||
capture_exception(span=span, exception=exception, timestamp=time_ns())
|
||||
raise
|
||||
duration = perf_counter() - start_time_stamp
|
||||
if response:
|
||||
response_attributes = _get_response_attributes(
|
||||
attributes,
|
||||
response,
|
||||
capture_response_id=INNER_RESPONSE_ID_CAPTURED_FIELD
|
||||
not in inner_response_telemetry_captured_fields,
|
||||
capture_usage=INNER_USAGE_CAPTURED_FIELD not in inner_response_telemetry_captured_fields,
|
||||
)
|
||||
_apply_accumulated_usage(response_attributes, inner_response_telemetry_captured_fields)
|
||||
_capture_response(span=span, attributes=response_attributes, duration=duration)
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages:
|
||||
_capture_messages(
|
||||
span=span,
|
||||
provider_name=provider_name,
|
||||
messages=response.messages,
|
||||
output=True,
|
||||
)
|
||||
return response # type: ignore[return-value,no-any-return]
|
||||
finally:
|
||||
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
|
||||
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
|
||||
|
||||
return _run()
|
||||
|
||||
|
||||
# region Otel Helpers
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
import contextlib
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import AsyncIterable, MutableSequence
|
||||
from typing import Any
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence
|
||||
from typing import Any, cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -18,22 +18,29 @@ from agent_framework import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentSession,
|
||||
BaseContextProvider,
|
||||
ChatContext,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
ContextProvider,
|
||||
FunctionTool,
|
||||
HistoryProvider,
|
||||
InMemoryHistoryProvider,
|
||||
Message,
|
||||
ResponseStream,
|
||||
SessionContext,
|
||||
SlidingWindowStrategy,
|
||||
SupportsAgentRun,
|
||||
SupportsChatGetResponse,
|
||||
TruncationStrategy,
|
||||
chat_middleware,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name
|
||||
from agent_framework._mcp import MCPTool, _build_prefixed_mcp_name, _normalize_mcp_name
|
||||
from agent_framework._middleware import FunctionInvocationContext
|
||||
from agent_framework.exceptions import AgentInvalidRequestException, ChatClientInvalidResponseException
|
||||
|
||||
|
||||
class _FixedTokenizer:
|
||||
@@ -68,6 +75,49 @@ class _ConnectedMCPTool(MCPTool):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _RecordingHistoryProvider(HistoryProvider):
|
||||
def __init__(self, source_id: str = "recording_history") -> None:
|
||||
super().__init__(source_id=source_id)
|
||||
|
||||
async def get_messages(
|
||||
self,
|
||||
session_id: str | None,
|
||||
*,
|
||||
state: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Message]:
|
||||
if state is None:
|
||||
return []
|
||||
state["get_call_count"] = state.get("get_call_count", 0) + 1
|
||||
return list(cast(list[Message], state.get("messages", [])))
|
||||
|
||||
async def save_messages(
|
||||
self,
|
||||
session_id: str | None,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
state: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if state is None:
|
||||
return
|
||||
state["save_call_count"] = state.get("save_call_count", 0) + 1
|
||||
state.setdefault("messages", []).extend(messages)
|
||||
|
||||
|
||||
class _ResponseIdRecordingHistoryProvider(_RecordingHistoryProvider):
|
||||
async def after_run(
|
||||
self,
|
||||
*,
|
||||
agent: SupportsAgentRun,
|
||||
session: AgentSession,
|
||||
context: SessionContext,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
state.setdefault("response_ids", []).append(context.response.response_id if context.response else None)
|
||||
await super().after_run(agent=agent, session=session, context=context, state=state)
|
||||
|
||||
|
||||
def test_agent_session_type(agent_session: AgentSession) -> None:
|
||||
assert isinstance(agent_session, AgentSession)
|
||||
|
||||
@@ -314,6 +364,413 @@ async def test_prepare_run_context_handles_function_kwargs(
|
||||
assert ctx["client_kwargs"]["session"] is session
|
||||
|
||||
|
||||
async def test_chat_agent_persists_history_per_service_call(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
provider = _RecordingHistoryProvider()
|
||||
|
||||
@tool(name="lookup_weather", approval_mode="never_require")
|
||||
def lookup_weather(location: str) -> str:
|
||||
return f"Weather in {location}: sunny"
|
||||
|
||||
session = AgentSession()
|
||||
session.state[provider.source_id] = {
|
||||
"messages": [
|
||||
Message(role="user", text="Earlier question"),
|
||||
Message(role="assistant", text="Earlier answer"),
|
||||
]
|
||||
}
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1",
|
||||
name="lookup_weather",
|
||||
arguments='{"location": "Seattle"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
response_id="resp_call_1",
|
||||
),
|
||||
ChatResponse(messages=Message(role="assistant", text="It is sunny in Seattle."), response_id="resp_call_2"),
|
||||
]
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
tools=[lookup_weather],
|
||||
context_providers=[provider],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
result = await agent.run("What's the weather in Seattle?", session=session)
|
||||
|
||||
provider_state = session.state[provider.source_id]
|
||||
stored_messages = cast(list[Message], provider_state["messages"])
|
||||
|
||||
assert result.text == "It is sunny in Seattle."
|
||||
assert result.response_id is None
|
||||
assert chat_client_base.call_count == 2
|
||||
assert provider_state["get_call_count"] == 2
|
||||
assert provider_state["save_call_count"] == 2
|
||||
assert stored_messages[-1].text == "It is sunny in Seattle."
|
||||
assert session.service_session_id is None
|
||||
|
||||
|
||||
async def test_chat_agent_persists_history_per_service_call_streaming(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
provider = _RecordingHistoryProvider()
|
||||
|
||||
@tool(name="lookup_weather", approval_mode="never_require")
|
||||
def lookup_weather(location: str) -> str:
|
||||
return f"Weather in {location}: sunny"
|
||||
|
||||
session = AgentSession()
|
||||
session.state[provider.source_id] = {
|
||||
"messages": [
|
||||
Message(role="user", text="Earlier question"),
|
||||
Message(role="assistant", text="Earlier answer"),
|
||||
]
|
||||
}
|
||||
chat_client_base.streaming_responses = [
|
||||
[
|
||||
ChatResponseUpdate(
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1",
|
||||
name="lookup_weather",
|
||||
arguments='{"location": "Seattle"}',
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
finish_reason="stop",
|
||||
response_id="resp_call_1",
|
||||
)
|
||||
],
|
||||
[
|
||||
ChatResponseUpdate(
|
||||
contents=[Content.from_text("It is sunny in Seattle.")],
|
||||
role="assistant",
|
||||
finish_reason="stop",
|
||||
response_id="resp_call_2",
|
||||
)
|
||||
],
|
||||
]
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
tools=[lookup_weather],
|
||||
context_providers=[provider],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
stream = agent.run("What's the weather in Seattle?", session=session, stream=True)
|
||||
async for _ in stream:
|
||||
pass
|
||||
result = await stream.get_final_response()
|
||||
|
||||
provider_state = session.state[provider.source_id]
|
||||
stored_messages = cast(list[Message], provider_state["messages"])
|
||||
|
||||
assert result.text == "It is sunny in Seattle."
|
||||
assert result.response_id is None
|
||||
assert chat_client_base.call_count == 2
|
||||
assert provider_state["get_call_count"] == 2
|
||||
assert provider_state["save_call_count"] == 2
|
||||
assert stored_messages[-1].text == "It is sunny in Seattle."
|
||||
assert session.service_session_id is None
|
||||
|
||||
|
||||
async def test_streaming_per_service_call_persistence_hides_response_id_from_after_run(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
provider = _ResponseIdRecordingHistoryProvider()
|
||||
|
||||
@tool(name="lookup_weather", approval_mode="never_require")
|
||||
def lookup_weather(location: str) -> str:
|
||||
return f"Weather in {location}: sunny"
|
||||
|
||||
session = AgentSession()
|
||||
session.state[provider.source_id] = {"messages": []}
|
||||
chat_client_base.streaming_responses = [
|
||||
[
|
||||
ChatResponseUpdate(
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1",
|
||||
name="lookup_weather",
|
||||
arguments='{"location": "Seattle"}',
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
finish_reason="stop",
|
||||
response_id="resp_call_1",
|
||||
)
|
||||
],
|
||||
[
|
||||
ChatResponseUpdate(
|
||||
contents=[Content.from_text("It is sunny in Seattle.")],
|
||||
role="assistant",
|
||||
finish_reason="stop",
|
||||
response_id="resp_call_2",
|
||||
)
|
||||
],
|
||||
]
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
tools=[lookup_weather],
|
||||
context_providers=[provider],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
stream = agent.run("What's the weather in Seattle?", session=session, stream=True)
|
||||
async for _ in stream:
|
||||
pass
|
||||
result = await stream.get_final_response()
|
||||
|
||||
provider_state = session.state[provider.source_id]
|
||||
|
||||
assert result.response_id is None
|
||||
assert provider_state["response_ids"] == [None, None]
|
||||
|
||||
|
||||
async def test_per_service_call_persistence_uses_real_service_storage_when_client_stores_by_default(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
provider = _RecordingHistoryProvider()
|
||||
|
||||
@tool(name="lookup_weather", approval_mode="never_require")
|
||||
def lookup_weather(location: str) -> str:
|
||||
return f"Weather in {location}: sunny"
|
||||
|
||||
chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined]
|
||||
|
||||
session = AgentSession()
|
||||
session.state[provider.source_id] = {"messages": []}
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1",
|
||||
name="lookup_weather",
|
||||
arguments='{"location": "Seattle"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
conversation_id="resp_service_managed",
|
||||
response_id="resp_call_1",
|
||||
),
|
||||
ChatResponse(
|
||||
messages=Message(role="assistant", text="It is sunny in Seattle."),
|
||||
conversation_id="resp_service_managed",
|
||||
response_id="resp_call_2",
|
||||
),
|
||||
]
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
tools=[lookup_weather],
|
||||
context_providers=[provider],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
result = await agent.run("What's the weather in Seattle?", session=session)
|
||||
|
||||
provider_state = session.state[provider.source_id]
|
||||
|
||||
assert result.text == "It is sunny in Seattle."
|
||||
assert result.response_id == "resp_call_2"
|
||||
assert chat_client_base.call_count == 2
|
||||
assert "get_call_count" not in provider_state
|
||||
assert "save_call_count" not in provider_state
|
||||
assert session.service_session_id == "resp_service_managed"
|
||||
|
||||
|
||||
async def test_service_storage_updates_session_handle_per_service_call_before_non_streaming_failure(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
provider = _RecordingHistoryProvider()
|
||||
|
||||
@tool(name="lookup_weather", approval_mode="never_require")
|
||||
def lookup_weather(location: str) -> str:
|
||||
return f"Weather in {location}: sunny"
|
||||
|
||||
chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined]
|
||||
|
||||
session = AgentSession()
|
||||
session.state[provider.source_id] = {"messages": []}
|
||||
first_response = ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1",
|
||||
name="lookup_weather",
|
||||
arguments='{"location": "Seattle"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
conversation_id="resp_call_1",
|
||||
response_id="resp_call_1",
|
||||
)
|
||||
mock_get_non_streaming_response = AsyncMock(
|
||||
side_effect=[first_response, RuntimeError("service down")],
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
tools=[lookup_weather],
|
||||
context_providers=[provider],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(chat_client_base, "_get_non_streaming_response", new=mock_get_non_streaming_response),
|
||||
pytest.raises(RuntimeError, match="service down"),
|
||||
):
|
||||
await agent.run("What's the weather in Seattle?", session=session)
|
||||
|
||||
assert mock_get_non_streaming_response.await_count == 2
|
||||
assert session.service_session_id == "resp_call_1"
|
||||
|
||||
|
||||
async def test_service_storage_updates_session_handle_per_service_call_before_streaming_failure(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
provider = _RecordingHistoryProvider()
|
||||
|
||||
@tool(name="lookup_weather", approval_mode="never_require")
|
||||
def lookup_weather(location: str) -> str:
|
||||
return f"Weather in {location}: sunny"
|
||||
|
||||
chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined]
|
||||
|
||||
session = AgentSession()
|
||||
session.state[provider.source_id] = {"messages": []}
|
||||
|
||||
async def _first_stream_updates() -> AsyncIterable[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1",
|
||||
name="lookup_weather",
|
||||
arguments='{"location": "Seattle"}',
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
def _finalize_first_stream(_updates: Sequence[ChatResponseUpdate]) -> ChatResponse[Any]:
|
||||
return ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1",
|
||||
name="lookup_weather",
|
||||
arguments='{"location": "Seattle"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
conversation_id="resp_call_1",
|
||||
response_id="resp_call_1",
|
||||
)
|
||||
|
||||
first_stream = ResponseStream(_first_stream_updates(), finalizer=_finalize_first_stream)
|
||||
mock_get_streaming_response = MagicMock(side_effect=[first_stream, RuntimeError("service down")])
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
tools=[lookup_weather],
|
||||
context_providers=[provider],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(chat_client_base, "_get_streaming_response", new=mock_get_streaming_response),
|
||||
pytest.raises(RuntimeError, match="service down"),
|
||||
):
|
||||
stream = agent.run("What's the weather in Seattle?", session=session, stream=True)
|
||||
async for _ in stream:
|
||||
pass
|
||||
|
||||
assert mock_get_streaming_response.call_count == 2
|
||||
assert session.service_session_id == "resp_call_1"
|
||||
|
||||
|
||||
async def test_chat_agent_without_per_service_call_persistence_preserves_response_id(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(role="assistant", text="Hello"),
|
||||
response_id="resp_call_1",
|
||||
)
|
||||
]
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
context_providers=[InMemoryHistoryProvider()],
|
||||
)
|
||||
|
||||
result = await agent.run("Hello", session=AgentSession(), options={"store": False})
|
||||
|
||||
assert result.response_id == "resp_call_1"
|
||||
|
||||
|
||||
async def test_per_service_call_persistence_rejects_real_service_conversation_id(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
provider = _RecordingHistoryProvider()
|
||||
chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined]
|
||||
session = AgentSession()
|
||||
session.state[provider.source_id] = {"messages": []}
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(role="assistant", text="Hello"),
|
||||
conversation_id="resp_service_managed",
|
||||
)
|
||||
]
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
context_providers=[provider],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ChatClientInvalidResponseException,
|
||||
match="require_per_service_call_history_persistence cannot be used",
|
||||
):
|
||||
await agent.run("Hello", session=session, options={"store": False})
|
||||
|
||||
|
||||
async def test_per_service_call_persistence_rejects_existing_conversation_id_when_service_not_storing_history(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
provider = _RecordingHistoryProvider()
|
||||
session = AgentSession()
|
||||
session.state[provider.source_id] = {"messages": []}
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
context_providers=[provider],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
AgentInvalidRequestException,
|
||||
match="require_per_service_call_history_persistence cannot be used",
|
||||
):
|
||||
await agent.run("Hello", session=session, options={"store": False, "conversation_id": "existing_conversation"})
|
||||
|
||||
|
||||
async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChatGetResponse) -> None:
|
||||
mock_response = ChatResponse(
|
||||
messages=[Message(role="assistant", contents=[Content.from_text("test response")])],
|
||||
@@ -586,7 +1043,7 @@ async def test_chat_client_agent_author_name_is_used_from_response(
|
||||
|
||||
|
||||
# Mock context provider for testing
|
||||
class MockContextProvider(BaseContextProvider):
|
||||
class MockContextProvider(ContextProvider):
|
||||
def __init__(self, messages: list[Message] | None = None) -> None:
|
||||
super().__init__(source_id="mock")
|
||||
self.context_messages = messages
|
||||
@@ -1723,7 +2180,7 @@ async def test_agent_create_session_with_context_providers(
|
||||
):
|
||||
"""Test that create_session works when context_providers are set on the agent."""
|
||||
|
||||
class TestContextProvider(BaseContextProvider):
|
||||
class TestContextProvider(ContextProvider):
|
||||
def __init__(self):
|
||||
super().__init__(source_id="test")
|
||||
|
||||
@@ -1798,7 +2255,7 @@ async def test_chat_agent_context_provider_adds_tools_when_agent_has_none(
|
||||
"""A tool provided by context."""
|
||||
return text
|
||||
|
||||
class ToolContextProvider(BaseContextProvider):
|
||||
class ToolContextProvider(ContextProvider):
|
||||
def __init__(self):
|
||||
super().__init__(source_id="tool-context")
|
||||
|
||||
@@ -1827,7 +2284,7 @@ async def test_chat_agent_context_provider_adds_instructions_when_agent_has_none
|
||||
):
|
||||
"""Test that context provider instructions are used when agent has no default instructions."""
|
||||
|
||||
class InstructionContextProvider(BaseContextProvider):
|
||||
class InstructionContextProvider(ContextProvider):
|
||||
def __init__(self):
|
||||
super().__init__(source_id="instruction-context")
|
||||
|
||||
@@ -1849,6 +2306,33 @@ async def test_chat_agent_context_provider_adds_instructions_when_agent_has_none
|
||||
assert options.get("instructions") == "Context-provided instructions"
|
||||
|
||||
|
||||
async def test_chat_agent_context_provider_adds_middleware_when_agent_has_none(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
"""Test that context provider middleware is collected during preparation."""
|
||||
|
||||
@chat_middleware
|
||||
async def context_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
class MiddlewareContextProvider(ContextProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="middleware-context")
|
||||
|
||||
async def before_run(self, *, agent, session, context, state) -> None:
|
||||
context.extend_middleware("middleware-context", context_chat_middleware)
|
||||
|
||||
agent = Agent(client=chat_client_base, context_providers=[MiddlewareContextProvider()])
|
||||
|
||||
session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage]
|
||||
session=None,
|
||||
input_messages=[Message(role="user", text="Hello")],
|
||||
)
|
||||
|
||||
assert session_context.middleware["middleware-context"] == [context_chat_middleware]
|
||||
assert session_context.get_middleware() == [context_chat_middleware]
|
||||
|
||||
|
||||
# region STORES_BY_DEFAULT tests
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from agent_framework import (
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
ContextProvider,
|
||||
FunctionInvocationContext,
|
||||
FunctionMiddleware,
|
||||
FunctionTool,
|
||||
@@ -464,6 +465,31 @@ class TestChatAgentMultipleMiddlewareOrdering:
|
||||
expected_order = ["class_agent_before", "function_agent_before", "function_agent_after", "class_agent_after"]
|
||||
assert execution_order == expected_order
|
||||
|
||||
async def test_provider_added_agent_middleware_is_rejected(self, chat_client_base: "MockBaseChatClient") -> None:
|
||||
"""Test provider-added agent middleware is rejected explicitly."""
|
||||
|
||||
@agent_middleware
|
||||
async def provider_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
class ProviderMiddlewareContextProvider(ContextProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="provider-middleware")
|
||||
|
||||
async def before_run(self, *, agent, session, context, state) -> None:
|
||||
context.extend_middleware(self.source_id, provider_middleware)
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
context_providers=[ProviderMiddlewareContextProvider()],
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
MiddlewareException,
|
||||
match="Context providers may only add chat or function middleware",
|
||||
):
|
||||
await agent.run([Message(role="user", text="test message")])
|
||||
|
||||
|
||||
# region Tool Functions for Testing
|
||||
|
||||
@@ -2066,6 +2092,121 @@ class TestChatAgentChatMiddleware:
|
||||
"agent_middleware_after",
|
||||
]
|
||||
|
||||
async def test_provider_added_chat_and_function_middleware_are_forwarded(
|
||||
self, chat_client_base: "MockBaseChatClient"
|
||||
) -> None:
|
||||
"""Test provider-added chat and function middleware forwarding and ordering."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def constructor_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("constructor_chat_before")
|
||||
await call_next()
|
||||
execution_order.append("constructor_chat_after")
|
||||
|
||||
@chat_middleware
|
||||
async def provider_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("provider_chat_before")
|
||||
await call_next()
|
||||
execution_order.append("provider_chat_after")
|
||||
|
||||
@chat_middleware
|
||||
async def run_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("run_chat_before")
|
||||
await call_next()
|
||||
execution_order.append("run_chat_after")
|
||||
|
||||
@function_middleware
|
||||
async def constructor_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("constructor_function_before")
|
||||
await call_next()
|
||||
execution_order.append("constructor_function_after")
|
||||
|
||||
@function_middleware
|
||||
async def provider_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("provider_function_before")
|
||||
await call_next()
|
||||
execution_order.append("provider_function_after")
|
||||
|
||||
@function_middleware
|
||||
async def run_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("run_function_before")
|
||||
await call_next()
|
||||
execution_order.append("run_function_after")
|
||||
|
||||
class ProviderMiddlewareContextProvider(ContextProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="provider-middleware")
|
||||
|
||||
async def before_run(self, *, agent, session, context, state) -> None:
|
||||
context.extend_middleware(
|
||||
self.source_id,
|
||||
[
|
||||
provider_chat_middleware,
|
||||
provider_function_middleware,
|
||||
],
|
||||
)
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_provider",
|
||||
name="sample_tool_function",
|
||||
arguments='{"location": "Seattle"}',
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
ChatResponse(messages=[Message(role="assistant", text="Final response")]),
|
||||
]
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
middleware=[constructor_chat_middleware, constructor_function_middleware],
|
||||
context_providers=[ProviderMiddlewareContextProvider()],
|
||||
tools=[sample_tool_function],
|
||||
)
|
||||
|
||||
response = await agent.run(
|
||||
[Message(role="user", text="Get weather for Seattle")],
|
||||
middleware=[run_chat_middleware, run_function_middleware],
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert chat_client_base.call_count == 2
|
||||
assert response.messages[-1].text == "Final response"
|
||||
assert execution_order == [
|
||||
"constructor_chat_before",
|
||||
"run_chat_before",
|
||||
"provider_chat_before",
|
||||
"provider_chat_after",
|
||||
"run_chat_after",
|
||||
"constructor_chat_after",
|
||||
"constructor_function_before",
|
||||
"run_function_before",
|
||||
"provider_function_before",
|
||||
"provider_function_after",
|
||||
"run_function_after",
|
||||
"constructor_function_after",
|
||||
"constructor_chat_before",
|
||||
"run_chat_before",
|
||||
"provider_chat_before",
|
||||
"provider_chat_after",
|
||||
"run_chat_after",
|
||||
"constructor_chat_after",
|
||||
]
|
||||
|
||||
async def test_agent_middleware_can_access_and_override_options(self) -> None:
|
||||
"""Test that agent middleware can access and override runtime options."""
|
||||
captured_options: dict[str, Any] = {}
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
|
||||
from agent_framework import Message
|
||||
from agent_framework._sessions import (
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
AgentContext,
|
||||
AgentSession,
|
||||
BaseContextProvider,
|
||||
BaseHistoryProvider,
|
||||
ChatContext,
|
||||
ContextProvider,
|
||||
HistoryProvider,
|
||||
InMemoryHistoryProvider,
|
||||
Message,
|
||||
SessionContext,
|
||||
agent_middleware,
|
||||
chat_middleware,
|
||||
)
|
||||
from agent_framework._sessions import LOCAL_HISTORY_CONVERSATION_ID, is_local_history_conversation_id
|
||||
from agent_framework.exceptions import MiddlewareException
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionContext tests
|
||||
@@ -102,6 +112,50 @@ class TestSessionContext:
|
||||
ctx.extend_instructions("sys", ["Be helpful", "Be concise"])
|
||||
assert ctx.instructions == ["Be helpful", "Be concise"]
|
||||
|
||||
def test_extend_middleware_creates_key_and_appends(self) -> None:
|
||||
ctx = SessionContext(input_messages=[])
|
||||
|
||||
@chat_middleware
|
||||
async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
@chat_middleware
|
||||
async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
ctx.extend_middleware("rag", first_middleware)
|
||||
ctx.extend_middleware("rag", [second_middleware])
|
||||
|
||||
assert ctx.middleware["rag"] == [first_middleware, second_middleware]
|
||||
assert ctx.get_middleware() == [first_middleware, second_middleware]
|
||||
|
||||
def test_extend_middleware_preserves_source_order(self) -> None:
|
||||
ctx = SessionContext(input_messages=[])
|
||||
|
||||
@chat_middleware
|
||||
async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
@chat_middleware
|
||||
async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
ctx.extend_middleware("a", first_middleware)
|
||||
ctx.extend_middleware("b", second_middleware)
|
||||
|
||||
assert list(ctx.middleware.keys()) == ["a", "b"]
|
||||
assert ctx.get_middleware() == [first_middleware, second_middleware]
|
||||
|
||||
def test_extend_middleware_rejects_agent_middleware(self) -> None:
|
||||
ctx = SessionContext(input_messages=[])
|
||||
|
||||
@agent_middleware
|
||||
async def provider_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
with pytest.raises(MiddlewareException, match="Context providers may only add chat or function middleware"):
|
||||
ctx.extend_middleware("rag", provider_agent_middleware)
|
||||
|
||||
def test_get_messages_all(self) -> None:
|
||||
ctx = SessionContext(input_messages=[])
|
||||
ctx.extend_messages("a", [Message(role="user", contents=["a"])])
|
||||
@@ -154,37 +208,58 @@ class TestSessionContext:
|
||||
ctx._response = resp
|
||||
assert ctx.response is resp
|
||||
|
||||
def test_local_history_conversation_id_sentinel(self) -> None:
|
||||
assert is_local_history_conversation_id(LOCAL_HISTORY_CONVERSATION_ID) is True
|
||||
assert is_local_history_conversation_id("some_other_id") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BaseContextProvider tests
|
||||
# ContextProvider tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestContextProviderBase:
|
||||
class TestContextProvider:
|
||||
def test_source_id_required(self) -> None:
|
||||
provider = BaseContextProvider(source_id="test")
|
||||
provider = ContextProvider(source_id="test")
|
||||
assert provider.source_id == "test"
|
||||
|
||||
async def test_before_run_is_noop(self) -> None:
|
||||
provider = BaseContextProvider(source_id="test")
|
||||
provider = ContextProvider(source_id="test")
|
||||
session = AgentSession()
|
||||
ctx = SessionContext(input_messages=[])
|
||||
# Should not raise
|
||||
await provider.before_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type]
|
||||
|
||||
async def test_after_run_is_noop(self) -> None:
|
||||
provider = BaseContextProvider(source_id="test")
|
||||
provider = ContextProvider(source_id="test")
|
||||
session = AgentSession()
|
||||
ctx = SessionContext(input_messages=[])
|
||||
await provider.after_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BaseHistoryProvider tests
|
||||
# Deprecated provider alias tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConcreteHistoryProvider(BaseHistoryProvider):
|
||||
class TestDeprecatedProviderAliases:
|
||||
def test_base_context_provider_warns_and_is_compatible(self) -> None:
|
||||
with pytest.warns(DeprecationWarning, match="BaseContextProvider is deprecated. Use ContextProvider instead."):
|
||||
provider = BaseContextProvider(source_id="test")
|
||||
|
||||
assert isinstance(provider, ContextProvider)
|
||||
|
||||
def test_base_provider_aliases_preserve_subtyping(self) -> None:
|
||||
assert issubclass(BaseContextProvider, ContextProvider)
|
||||
assert issubclass(BaseHistoryProvider, HistoryProvider)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HistoryProvider tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConcreteHistoryProvider(HistoryProvider):
|
||||
"""Concrete test implementation."""
|
||||
|
||||
def __init__(self, source_id: str, stored_messages: list[Message] | None = None, **kwargs) -> None:
|
||||
|
||||
@@ -17,9 +17,9 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, cast
|
||||
from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
AgentMiddlewareLayer,
|
||||
BaseContextProvider,
|
||||
ChatAndFunctionMiddlewareTypes,
|
||||
ChatMiddlewareLayer,
|
||||
ContextProvider,
|
||||
FunctionInvocationConfiguration,
|
||||
FunctionInvocationLayer,
|
||||
FunctionTool,
|
||||
@@ -50,8 +50,8 @@ else:
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import (
|
||||
Agent,
|
||||
BaseContextProvider,
|
||||
ChatAndFunctionMiddlewareTypes,
|
||||
ContextProvider,
|
||||
MiddlewareTypes,
|
||||
ToolTypes,
|
||||
)
|
||||
@@ -224,8 +224,9 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
|
||||
instructions: str | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
require_per_service_call_history_persistence: bool = False,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
@@ -246,6 +247,7 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
|
||||
tools=function_tools,
|
||||
context_providers=context_providers,
|
||||
middleware=middleware,
|
||||
require_per_service_call_history_persistence=require_per_service_call_history_persistence,
|
||||
client_type=cast(type[RawFoundryAgentChatClient], self.__class__),
|
||||
id=id,
|
||||
name=self.agent_name if name is None else name,
|
||||
@@ -468,7 +470,7 @@ class RawFoundryAgent( # type: ignore[misc]
|
||||
project_client: AIProjectClient | None = None,
|
||||
allow_preview: bool | None = None,
|
||||
tools: FunctionTool | Callable[..., Any] | Sequence[FunctionTool | Callable[..., Any]] | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
client_type: type[RawFoundryAgentChatClient] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
@@ -478,6 +480,7 @@ class RawFoundryAgent( # type: ignore[misc]
|
||||
description: str | None = None,
|
||||
instructions: str | None = None,
|
||||
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
|
||||
require_per_service_call_history_persistence: bool = False,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
@@ -507,6 +510,8 @@ class RawFoundryAgent( # type: ignore[misc]
|
||||
description: Optional local description for the local agent wrapper.
|
||||
instructions: Optional instructions for the local agent wrapper.
|
||||
default_options: Default chat options for the local agent wrapper.
|
||||
require_per_service_call_history_persistence: Whether to require per-service-call
|
||||
chat history persistence when using local history providers.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
compaction_strategy: Optional agent-level in-run compaction override.
|
||||
tokenizer: Optional agent-level tokenizer override.
|
||||
@@ -548,6 +553,7 @@ class RawFoundryAgent( # type: ignore[misc]
|
||||
default_options=cast(FoundryAgentOptionsT | None, default_options),
|
||||
context_providers=context_providers,
|
||||
middleware=middleware,
|
||||
require_per_service_call_history_persistence=require_per_service_call_history_persistence,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=dict(additional_properties) if additional_properties is not None else None,
|
||||
@@ -661,7 +667,7 @@ class FoundryAgent( # type: ignore[misc]
|
||||
project_client: AIProjectClient | None = None,
|
||||
allow_preview: bool | None = None,
|
||||
tools: FunctionTool | Callable[..., Any] | Sequence[FunctionTool | Callable[..., Any]] | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
client_type: type[RawFoundryAgentChatClient] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
@@ -671,6 +677,7 @@ class FoundryAgent( # type: ignore[misc]
|
||||
description: str | None = None,
|
||||
instructions: str | None = None,
|
||||
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
|
||||
require_per_service_call_history_persistence: bool = False,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
@@ -696,6 +703,8 @@ class FoundryAgent( # type: ignore[misc]
|
||||
description: Optional local description for the local agent wrapper.
|
||||
instructions: Optional instructions for the local agent wrapper.
|
||||
default_options: Default chat options for the local agent wrapper.
|
||||
require_per_service_call_history_persistence: Whether to require per-service-call
|
||||
chat history persistence when using local history providers.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
compaction_strategy: Optional agent-level in-run compaction override.
|
||||
tokenizer: Optional agent-level tokenizer override.
|
||||
@@ -719,6 +728,7 @@ class FoundryAgent( # type: ignore[misc]
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
default_options=default_options,
|
||||
require_per_service_call_history_persistence=require_per_service_call_history_persistence,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Foundry Memory Context Provider using BaseContextProvider.
|
||||
"""Foundry Memory Context Provider using ContextProvider.
|
||||
|
||||
This module provides ``FoundryMemoryProvider``, built on
|
||||
:class:`BaseContextProvider`.
|
||||
:class:`ContextProvider`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, ClassVar
|
||||
from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
AgentSession,
|
||||
BaseContextProvider,
|
||||
ContextProvider,
|
||||
Message,
|
||||
SessionContext,
|
||||
load_settings,
|
||||
@@ -46,8 +46,8 @@ class FoundryProjectSettings(TypedDict, total=False):
|
||||
project_endpoint: str | None
|
||||
|
||||
|
||||
class FoundryMemoryProvider(BaseContextProvider):
|
||||
"""Foundry Memory context provider using the new BaseContextProvider hooks pattern.
|
||||
class FoundryMemoryProvider(ContextProvider):
|
||||
"""Foundry Memory context provider using the new ContextProvider hooks pattern.
|
||||
|
||||
Integrates Azure AI Foundry Memory Store for persistent semantic memory,
|
||||
searching and storing memories via the Azure AI Projects SDK.
|
||||
|
||||
@@ -15,8 +15,8 @@ from agent_framework import (
|
||||
AgentResponseUpdate,
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
BaseContextProvider,
|
||||
Content,
|
||||
ContextProvider,
|
||||
Message,
|
||||
ResponseStream,
|
||||
normalize_messages,
|
||||
@@ -178,7 +178,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
middleware: Sequence[AgentMiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
default_options: OptionsT | None = None,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""New-pattern Mem0 context provider using BaseContextProvider.
|
||||
"""New-pattern Mem0 context provider using ContextProvider.
|
||||
|
||||
This module provides ``Mem0ContextProvider``, built on the new
|
||||
:class:`BaseContextProvider` hooks pattern.
|
||||
:class:`ContextProvider` hooks pattern.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -13,7 +13,7 @@ from contextlib import AbstractAsyncContextManager
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from agent_framework import Message
|
||||
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
|
||||
from agent_framework._sessions import AgentSession, ContextProvider, SessionContext
|
||||
from mem0 import AsyncMemory, AsyncMemoryClient
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
@@ -33,8 +33,8 @@ class _MemorySearchResponse_v1_1(TypedDict):
|
||||
_MemorySearchResponse_v2 = list[dict[str, Any]]
|
||||
|
||||
|
||||
class Mem0ContextProvider(BaseContextProvider):
|
||||
"""Mem0 context provider using the new BaseContextProvider hooks pattern.
|
||||
class Mem0ContextProvider(ContextProvider):
|
||||
"""Mem0 context provider using the new ContextProvider hooks pattern.
|
||||
|
||||
Integrates Mem0 for persistent semantic memory, searching and storing
|
||||
memories via the Mem0 API.
|
||||
|
||||
@@ -39,7 +39,7 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Agent, SupportsAgentRun
|
||||
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware
|
||||
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware, MiddlewareTermination
|
||||
from agent_framework._sessions import AgentSession
|
||||
from agent_framework._tools import FunctionTool, tool
|
||||
from agent_framework._types import AgentResponse, Content, Message
|
||||
@@ -138,8 +138,6 @@ class _AutoHandoffMiddleware(FunctionMiddleware):
|
||||
await call_next()
|
||||
return
|
||||
|
||||
from agent_framework._middleware import MiddlewareTermination
|
||||
|
||||
# Short-circuit execution and provide deterministic response payload for the tool call.
|
||||
# Parse the result using the default parser to ensure in a form that can be passed directly to LLM APIs.
|
||||
context.result = FunctionTool.parse_result({
|
||||
@@ -375,6 +373,7 @@ class HandoffAgentExecutor(AgentExecutor):
|
||||
description=agent.description,
|
||||
context_providers=agent.context_providers,
|
||||
middleware=agent.agent_middleware,
|
||||
require_per_service_call_history_persistence=agent.require_per_service_call_history_persistence,
|
||||
default_options=cloned_options, # type: ignore[assignment]
|
||||
)
|
||||
|
||||
|
||||
@@ -8,10 +8,11 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
from agent_framework import (
|
||||
Agent,
|
||||
BaseContextProvider,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
ContextProvider,
|
||||
InMemoryHistoryProvider,
|
||||
Message,
|
||||
ResponseStream,
|
||||
WorkflowEvent,
|
||||
@@ -695,6 +696,48 @@ def test_handoff_clone_disables_provider_side_storage() -> None:
|
||||
assert executor._agent.default_options.get("store") is False
|
||||
|
||||
|
||||
async def test_handoff_clone_preserves_per_service_call_history_persistence() -> None:
|
||||
"""Handoff clones should keep per-service-call history persistence active for auto-handoff termination."""
|
||||
triage_history = InMemoryHistoryProvider()
|
||||
triage = Agent(
|
||||
id="triage",
|
||||
name="triage",
|
||||
client=MockChatClient(name="triage", handoff_to="specialist"),
|
||||
context_providers=[triage_history],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
specialist = Agent(
|
||||
id="specialist",
|
||||
name="specialist",
|
||||
client=MockChatClient(name="specialist"),
|
||||
default_options={"tool_choice": "none"},
|
||||
)
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, specialist], termination_condition=lambda _: False)
|
||||
.with_start_agent(triage)
|
||||
.add_handoff(triage, [specialist])
|
||||
.add_handoff(specialist, [triage])
|
||||
.build()
|
||||
)
|
||||
|
||||
await _drain(workflow.run("start", stream=True))
|
||||
|
||||
executor = workflow.executors[resolve_agent_id(triage)]
|
||||
assert isinstance(executor, HandoffAgentExecutor)
|
||||
assert executor._agent.require_per_service_call_history_persistence is True
|
||||
|
||||
provider_state = executor._session.state[triage_history.source_id]
|
||||
stored_messages = await triage_history.get_messages(
|
||||
executor._session.session_id,
|
||||
state=provider_state,
|
||||
)
|
||||
|
||||
assert [message.role for message in stored_messages] == ["user", "assistant"]
|
||||
assert any(content.type == "function_call" for content in stored_messages[-1].contents)
|
||||
assert all(message.role != "tool" for message in stored_messages)
|
||||
|
||||
|
||||
async def test_handoff_clears_stale_service_session_id_before_run() -> None:
|
||||
"""Stale service session IDs must be dropped before each handoff agent turn."""
|
||||
triage = MockHandoffAgent(name="triage", handoff_to="specialist")
|
||||
@@ -997,7 +1040,7 @@ async def test_context_provider_preserved_during_handoff():
|
||||
# Track whether context provider methods were called
|
||||
provider_calls: list[str] = []
|
||||
|
||||
class TestContextProvider(BaseContextProvider):
|
||||
class TestContextProvider(ContextProvider):
|
||||
"""A test context provider that tracks its invocations."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""New-pattern Redis context provider using BaseContextProvider.
|
||||
"""New-pattern Redis context provider using ContextProvider.
|
||||
|
||||
This module provides ``RedisContextProvider``, built on the new
|
||||
:class:`BaseContextProvider` hooks pattern.
|
||||
:class:`ContextProvider` hooks pattern.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
||||
|
||||
import numpy as np
|
||||
from agent_framework import Message
|
||||
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
|
||||
from agent_framework._sessions import AgentSession, ContextProvider, SessionContext
|
||||
from agent_framework.exceptions import (
|
||||
AgentException,
|
||||
IntegrationInvalidRequestException,
|
||||
@@ -41,8 +41,8 @@ if TYPE_CHECKING:
|
||||
from agent_framework._agents import SupportsAgentRun
|
||||
|
||||
|
||||
class RedisContextProvider(BaseContextProvider):
|
||||
"""Redis context provider using the new BaseContextProvider hooks pattern.
|
||||
class RedisContextProvider(ContextProvider):
|
||||
"""Redis context provider using the new ContextProvider hooks pattern.
|
||||
|
||||
Stores context in Redis and retrieves scoped context via full-text or
|
||||
optional hybrid vector search.
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""New-pattern Redis history provider using BaseHistoryProvider.
|
||||
"""New-pattern Redis history provider using HistoryProvider.
|
||||
|
||||
This module provides ``RedisHistoryProvider``, built on the new
|
||||
:class:`BaseHistoryProvider` hooks pattern.
|
||||
:class:`HistoryProvider` hooks pattern.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -13,12 +13,12 @@ from typing import Any, ClassVar
|
||||
|
||||
import redis.asyncio as redis
|
||||
from agent_framework import Message
|
||||
from agent_framework._sessions import BaseHistoryProvider
|
||||
from agent_framework._sessions import HistoryProvider
|
||||
from redis.credentials import CredentialProvider
|
||||
|
||||
|
||||
class RedisHistoryProvider(BaseHistoryProvider):
|
||||
"""Redis-backed history provider using the new BaseHistoryProvider hooks pattern.
|
||||
class RedisHistoryProvider(HistoryProvider):
|
||||
"""Redis-backed history provider using the new HistoryProvider hooks pattern.
|
||||
|
||||
Stores conversation history in Redis Lists, with each session isolated by a
|
||||
unique Redis key.
|
||||
|
||||
@@ -475,7 +475,7 @@ class TestRedisHistoryProviderClear:
|
||||
|
||||
|
||||
class TestRedisHistoryProviderBeforeAfterRun:
|
||||
"""Test before_run/after_run integration via BaseHistoryProvider defaults."""
|
||||
"""Test before_run/after_run integration via HistoryProvider defaults."""
|
||||
|
||||
async def test_before_run_loads_history(self, mock_redis_client: MagicMock):
|
||||
msg = Message(role="user", contents=["old msg"])
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Agent, AgentSession, BaseContextProvider, SessionContext
|
||||
from agent_framework import Agent, AgentSession, ContextProvider, SessionContext
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
@@ -17,7 +17,7 @@ responses — the name persists across turns via the session.
|
||||
|
||||
|
||||
# <context_provider>
|
||||
class UserMemoryProvider(BaseContextProvider):
|
||||
class UserMemoryProvider(ContextProvider):
|
||||
"""A context provider that remembers user info in session state."""
|
||||
|
||||
DEFAULT_SOURCE_ID = "user_memory"
|
||||
|
||||
@@ -9,6 +9,7 @@ This folder contains examples for direct chat client usage patterns.
|
||||
| [`built_in_chat_clients.py`](built_in_chat_clients.py) | Consolidated sample for built-in chat clients. Uses `get_client()` to create the selected client and pass it to `main()`. |
|
||||
| [`chat_response_cancellation.py`](chat_response_cancellation.py) | Demonstrates how to cancel chat responses during streaming, showing proper cancellation handling and cleanup. |
|
||||
| [`custom_chat_client.py`](custom_chat_client.py) | Demonstrates how to create custom chat clients by extending the `BaseChatClient` class. Shows a `EchoingChatClient` implementation and how to integrate it with `Agent` using the `as_agent()` method. |
|
||||
| [`require_per_service_call_history_persistence.py`](require_per_service_call_history_persistence.py) | Compares two otherwise identical `FoundryChatClient` agents with `store=False`; the only difference is whether `require_per_service_call_history_persistence` is enabled, and only the run without it stores the synthesized tool result when middleware terminates the loop early. |
|
||||
|
||||
## Selecting a built-in client
|
||||
|
||||
@@ -35,6 +36,15 @@ Example:
|
||||
uv run samples/02-agents/chat_client/built_in_chat_clients.py
|
||||
```
|
||||
|
||||
The `require_per_service_call_history_persistence.py` sample uses `FoundryChatClient`, so set the usual Foundry settings first and sign in with the Azure CLI:
|
||||
|
||||
```bash
|
||||
export FOUNDRY_PROJECT_ENDPOINT="https://<your-project>.services.ai.azure.com/api/projects/<project-name>"
|
||||
export FOUNDRY_MODEL="<your-model-deployment-name>"
|
||||
az login
|
||||
uv run samples/02-agents/chat_client/require_per_service_call_history_persistence.py
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
Depending on the selected client, set the appropriate environment variables:
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
Agent,
|
||||
FunctionInvocationContext,
|
||||
FunctionMiddleware,
|
||||
InMemoryHistoryProvider,
|
||||
Message,
|
||||
MiddlewareTermination,
|
||||
)
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Compare Foundry agents with and without per-service-call chat history persistence.
|
||||
|
||||
This sample runs two otherwise identical Foundry agents with ``store=False`` so
|
||||
history stays local for both runs.
|
||||
|
||||
The sample adds a function middleware that raises ``MiddlewareTermination``
|
||||
immediately after the tool runs, so the request stops before a second model
|
||||
call.
|
||||
|
||||
That early termination is the important difference:
|
||||
|
||||
- Without per-service-call chat history persistence, the synthesized tool result is
|
||||
still written to local history.
|
||||
- With ``require_per_service_call_history_persistence=True``, that synthesized tool result is
|
||||
not written to local history.
|
||||
|
||||
The per-service-call persistence case matches service-side storage behavior. When a terminated
|
||||
request never sends the tool result back to the service, that result also never
|
||||
becomes part of the service-managed history.
|
||||
"""
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def lookup_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
) -> str:
|
||||
"""Return a deterministic weather result for the requested location."""
|
||||
return f"The weather in {location} is sunny."
|
||||
|
||||
|
||||
class TerminateAfterToolMiddleware(FunctionMiddleware):
|
||||
"""Stop the tool loop after the first tool finishes."""
|
||||
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Run the tool, then terminate the loop with that tool result."""
|
||||
await call_next()
|
||||
raise MiddlewareTermination(result=context.result)
|
||||
|
||||
|
||||
def _describe_message(message: Message) -> str:
|
||||
"""Render one stored message in a compact, readable format."""
|
||||
parts: list[str] = []
|
||||
for content in message.contents:
|
||||
if content.type == "text" and content.text:
|
||||
parts.append(content.text)
|
||||
elif content.type == "function_call":
|
||||
parts.append(f"function_call -> {content.name}({content.arguments})")
|
||||
elif content.type == "function_result":
|
||||
parts.append(f"function_result -> {content.result}")
|
||||
else:
|
||||
parts.append(content.type)
|
||||
|
||||
return f"{message.role}: {' | '.join(parts)}"
|
||||
|
||||
|
||||
def _includes_tool_result(messages: list[Message]) -> bool:
|
||||
"""Return whether any stored message contains a tool result."""
|
||||
return any(content.type == "function_result" for message in messages for content in message.contents)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Run both comparison scenarios."""
|
||||
print("=== require_per_service_call_history_persistence when middleware terminates the tool loop ===\n")
|
||||
|
||||
# 1. Create one Foundry chat client that both agents will share.
|
||||
client = FoundryChatClient(credential=AzureCliCredential())
|
||||
query = "What is the weather in Seattle, and should I bring sunglasses?"
|
||||
|
||||
# 2. Create and run the agent without per-service-call persistence.
|
||||
agent_without_persistence = Agent(
|
||||
client=client,
|
||||
instructions=(
|
||||
"You are a weather assistant. Call lookup_weather exactly once before answering "
|
||||
"any weather question, then summarize the tool result in one short paragraph."
|
||||
),
|
||||
tools=[lookup_weather],
|
||||
context_providers=[InMemoryHistoryProvider()],
|
||||
middleware=[TerminateAfterToolMiddleware()],
|
||||
default_options={"tool_choice": "required", "store": False},
|
||||
)
|
||||
session_without_persistence = agent_without_persistence.create_session()
|
||||
await agent_without_persistence.run(
|
||||
query,
|
||||
session=session_without_persistence,
|
||||
)
|
||||
stored_messages_without_persistence = session_without_persistence.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID][
|
||||
"messages"
|
||||
]
|
||||
|
||||
print("=== Without per-service-call persistence ===")
|
||||
print("Loop terminated immediately after the tool finished.")
|
||||
print(f"Stored synthesized tool result: {_includes_tool_result(stored_messages_without_persistence)}")
|
||||
print("Stored history:")
|
||||
for index, message in enumerate(stored_messages_without_persistence, start=1):
|
||||
print(f" {index}. {_describe_message(message)}")
|
||||
print()
|
||||
|
||||
# 3. Create and run the agent with per-service-call persistence enabled.
|
||||
agent_with_persistence = Agent(
|
||||
client=client,
|
||||
instructions=(
|
||||
"You are a weather assistant. Call lookup_weather exactly once before answering "
|
||||
"any weather question, then summarize the tool result in one short paragraph."
|
||||
),
|
||||
tools=[lookup_weather],
|
||||
context_providers=[InMemoryHistoryProvider()],
|
||||
middleware=[TerminateAfterToolMiddleware()],
|
||||
require_per_service_call_history_persistence=True,
|
||||
default_options={"tool_choice": "required", "store": False},
|
||||
)
|
||||
session_with_persistence = agent_with_persistence.create_session()
|
||||
await agent_with_persistence.run(
|
||||
query,
|
||||
session=session_with_persistence,
|
||||
)
|
||||
stored_messages_with_persistence = session_with_persistence.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID][
|
||||
"messages"
|
||||
]
|
||||
|
||||
print("=== With per-service-call persistence ===")
|
||||
print("Loop terminated immediately after the tool finished.")
|
||||
print(f"Stored synthesized tool result: {_includes_tool_result(stored_messages_with_persistence)}")
|
||||
print("Stored history:")
|
||||
for index, message in enumerate(stored_messages_with_persistence, start=1):
|
||||
print(f" {index}. {_describe_message(message)}")
|
||||
print()
|
||||
|
||||
# 4. Summarize the effect of the flag.
|
||||
print(
|
||||
"Both runs used FoundryChatClient with store=False and terminated right after the tool. "
|
||||
"Without per-service-call persistence, local history still stored the synthesized tool result. "
|
||||
"With per-service-call persistence, local history stopped at the assistant function-call message instead, "
|
||||
"which matches service-side storage because the terminated tool result is never sent back to the service."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== require_per_service_call_history_persistence when middleware terminates the tool loop ===
|
||||
|
||||
=== Without per-service-call persistence ===
|
||||
Loop terminated immediately after the tool finished.
|
||||
Stored synthesized tool result: True
|
||||
Stored history:
|
||||
1. user: What is the weather in Seattle, and should I bring sunglasses?
|
||||
2. assistant: function_call -> lookup_weather({"location":"Seattle"})
|
||||
3. tool: function_result -> The weather in Seattle is sunny.
|
||||
|
||||
=== With per-service-call persistence ===
|
||||
Loop terminated immediately after the tool finished.
|
||||
Stored synthesized tool result: False
|
||||
Stored history:
|
||||
1. user: What is the weather in Seattle, and should I bring sunglasses?
|
||||
2. assistant: function_call -> lookup_weather({"location":"Seattle"})
|
||||
|
||||
Both runs used FoundryChatClient with store=False and terminated right after
|
||||
the tool. Without per-service-call persistence, local history still stored the
|
||||
synthesized tool result. With per-service-call persistence, local history
|
||||
stopped at the assistant function-call message instead, which matches
|
||||
service-side storage because the terminated tool result is never sent back to
|
||||
the service.
|
||||
"""
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
from contextlib import suppress
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Agent, AgentSession, BaseContextProvider, SessionContext, SupportsChatGetResponse
|
||||
from agent_framework import Agent, AgentSession, ContextProvider, SessionContext, SupportsChatGetResponse
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
@@ -20,7 +20,7 @@ class UserInfo(BaseModel):
|
||||
age: int | None = None
|
||||
|
||||
|
||||
class UserInfoMemory(BaseContextProvider):
|
||||
class UserInfoMemory(ContextProvider):
|
||||
DEFAULT_SOURCE_ID = "user_info_memory"
|
||||
|
||||
def __init__(self, source_id: str = DEFAULT_SOURCE_ID, *, client: SupportsChatGetResponse, **kwargs: Any):
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Agent, AgentSession, BaseHistoryProvider, Message
|
||||
from agent_framework import Agent, AgentSession, HistoryProvider, Message
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -20,7 +20,7 @@ preferred storage solution (database, file system, etc.).
|
||||
"""
|
||||
|
||||
|
||||
class CustomHistoryProvider(BaseHistoryProvider):
|
||||
class CustomHistoryProvider(HistoryProvider):
|
||||
"""Implementation of custom history provider.
|
||||
In real applications, this can be an implementation of relational database or vector store."""
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ These samples demonstrate how to build and host AI agents in Python using the [A
|
||||
| Sample | Description |
|
||||
| ----------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------- |
|
||||
| [`agent_with_hosted_mcp`](./agent_with_hosted_mcp/) | Hosted MCP tool that connects to Microsoft Learn via `https://learn.microsoft.com/api/mcp` |
|
||||
| [`agent_with_text_search_rag`](./agent_with_text_search_rag/) | Retrieval-augmented generation using a custom `BaseContextProvider` with Contoso Outdoors sample data |
|
||||
| [`agent_with_text_search_rag`](./agent_with_text_search_rag/) | Retrieval-augmented generation using a custom `ContextProvider` with Contoso Outdoors sample data |
|
||||
| [`agents_in_workflow`](./agents_in_workflow/) | Concurrent workflow that combines researcher, marketer, and legal specialist agents |
|
||||
| [`agent_with_local_tools`](./agent_with_local_tools/) | Local Python tool execution for Seattle hotel search |
|
||||
| [`writer_reviewer_agents_in_workflow`](./writer_reviewer_agents_in_workflow/) | Writer/Reviewer workflow using `FoundryChatClient` |
|
||||
|
||||
@@ -5,7 +5,7 @@ import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Agent, AgentSession, BaseContextProvider, Message, SessionContext
|
||||
from agent_framework import Agent, AgentSession, ContextProvider, Message, SessionContext
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.ai.agentserver.agentframework import from_agent_framework # pyright: ignore[reportUnknownVariableType]
|
||||
from azure.identity import DefaultAzureCredential
|
||||
@@ -28,7 +28,7 @@ class TextSearchResult:
|
||||
text: str
|
||||
|
||||
|
||||
class TextSearchContextProvider(BaseContextProvider):
|
||||
class TextSearchContextProvider(ContextProvider):
|
||||
"""A simple context provider that simulates text search results based on keywords in the user's message."""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -7,7 +7,7 @@ This sample demonstrates a simple Weather Forecast Agent built with the Python M
|
||||
- Python 3.11+
|
||||
- [uv](https://github.com/astral-sh/uv) for fast dependency management
|
||||
- [devtunnel](https://learn.microsoft.com/azure/developer/dev-tunnels/get-started?tabs=windows)
|
||||
- [Microsoft 365 Agents Toolkit](https://github.com/OfficeDev/microsoft-365-agents-toolkit) for playground/testing
|
||||
- `agentsplayground` for playground/testing
|
||||
- Access to OpenAI or Azure OpenAI with a model like `gpt-4o-mini`
|
||||
|
||||
## Configuration
|
||||
|
||||
Reference in New Issue
Block a user