mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Added chat middleware and more examples (#883)
* Added example with stateful middleware * Added chat middleware * Updated middleware example with override scenario * Small revert * Small fixes * Added kwargs to context objects * Added README * Added function middleware to chat client * Small refactoring * Reverted example files * Made MiddlewareWrapper generic * Added Middleware exception * Small refactoring * Small fix
This commit is contained in:
committed by
GitHub
Unverified
parent
863c8d7471
commit
eec7f192eb
@@ -10,7 +10,13 @@ from pydantic import BaseModel, Field
|
||||
from ._logging import get_logger
|
||||
from ._mcp import MCPTool
|
||||
from ._memory import AggregateContextProvider, ContextProvider
|
||||
from ._middleware import Middleware
|
||||
from ._middleware import (
|
||||
ChatMiddleware,
|
||||
ChatMiddlewareCallable,
|
||||
FunctionMiddleware,
|
||||
FunctionMiddlewareCallable,
|
||||
Middleware,
|
||||
)
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._threads import ChatMessageStore
|
||||
from ._tools import ToolProtocol
|
||||
@@ -189,6 +195,14 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
"""Base class for chat clients."""
|
||||
|
||||
additional_properties: dict[str, Any] = Field(default_factory=dict)
|
||||
middleware: (
|
||||
ChatMiddleware
|
||||
| ChatMiddlewareCallable
|
||||
| FunctionMiddleware
|
||||
| FunctionMiddlewareCallable
|
||||
| list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable]
|
||||
| None
|
||||
) = None
|
||||
OTEL_PROVIDER_NAME: str = "unknown"
|
||||
# This is used for OTel setup, should be overridden in subclasses
|
||||
|
||||
@@ -346,13 +360,7 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
prepped_messages = self.prepare_messages(messages)
|
||||
self._prepare_tool_choice(chat_options=chat_options)
|
||||
|
||||
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
|
||||
if "_function_middleware_pipeline" in kwargs:
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
|
||||
else:
|
||||
filtered_kwargs = kwargs
|
||||
|
||||
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs)
|
||||
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)
|
||||
|
||||
async def get_streaming_response(
|
||||
self,
|
||||
@@ -432,14 +440,8 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
prepped_messages = self.prepare_messages(messages)
|
||||
self._prepare_tool_choice(chat_options=chat_options)
|
||||
|
||||
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
|
||||
if "_function_middleware_pipeline" in kwargs:
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
|
||||
else:
|
||||
filtered_kwargs = kwargs
|
||||
|
||||
async for update in self._inner_get_streaming_response(
|
||||
messages=prepped_messages, chat_options=chat_options, **filtered_kwargs
|
||||
messages=prepped_messages, chat_options=chat_options, **kwargs
|
||||
):
|
||||
yield update
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -607,6 +607,7 @@ async def _auto_invoke_function(
|
||||
middleware_context = FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments=args,
|
||||
kwargs=custom_args or {},
|
||||
)
|
||||
|
||||
async def final_function_handler(context_obj: Any) -> Any:
|
||||
@@ -721,8 +722,16 @@ def _handle_function_calls_response(
|
||||
**kwargs: Any,
|
||||
) -> "ChatResponse":
|
||||
from ._clients import prepare_messages
|
||||
from ._middleware import extract_and_merge_function_middleware
|
||||
from ._types import ChatMessage, ChatOptions, FunctionCallContent, FunctionResultContent
|
||||
|
||||
# Extract and merge function middleware from chat client with kwargs pipeline
|
||||
extract_and_merge_function_middleware(self, kwargs)
|
||||
|
||||
# Extract the middleware pipeline before calling the underlying function
|
||||
# because the underlying function may not preserve it in kwargs
|
||||
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
|
||||
prepped_messages = prepare_messages(messages)
|
||||
response: "ChatResponse | None" = None
|
||||
fcc_messages: "list[ChatMessage]" = []
|
||||
@@ -746,8 +755,9 @@ def _handle_function_calls_response(
|
||||
if not tools and (chat_options := kwargs.get("chat_options")) and isinstance(chat_options, ChatOptions):
|
||||
tools = chat_options.tools
|
||||
if function_calls and tools:
|
||||
# Extract function middleware pipeline from kwargs if available
|
||||
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
# Use the stored middleware pipeline instead of extracting from kwargs
|
||||
# because kwargs may have been modified by the underlying function
|
||||
middleware_pipeline = stored_middleware_pipeline
|
||||
function_results = await execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
@@ -820,8 +830,16 @@ def _handle_function_calls_streaming_response(
|
||||
) -> AsyncIterable["ChatResponseUpdate"]:
|
||||
"""Wrap the inner get streaming response method to handle tool calls."""
|
||||
from ._clients import prepare_messages
|
||||
from ._middleware import extract_and_merge_function_middleware
|
||||
from ._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, FunctionCallContent
|
||||
|
||||
# Extract and merge function middleware from chat client with kwargs pipeline
|
||||
extract_and_merge_function_middleware(self, kwargs)
|
||||
|
||||
# Extract the middleware pipeline before calling the underlying function
|
||||
# because the underlying function may not preserve it in kwargs
|
||||
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
|
||||
prepped_messages = prepare_messages(messages)
|
||||
for attempt_idx in range(max_iterations):
|
||||
all_updates: list["ChatResponseUpdate"] = []
|
||||
@@ -858,8 +876,9 @@ def _handle_function_calls_streaming_response(
|
||||
tools = chat_options.tools
|
||||
|
||||
if function_calls and tools:
|
||||
# Extract function middleware pipeline from kwargs if available
|
||||
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
# Use the stored middleware pipeline instead of extracting from kwargs
|
||||
# because kwargs may have been modified by the underlying function
|
||||
middleware_pipeline = stored_middleware_pipeline
|
||||
function_results = await execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
|
||||
@@ -13,16 +13,18 @@ from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from pydantic import SecretStr, ValidationError
|
||||
from pydantic.networks import AnyUrl
|
||||
|
||||
from .._tools import use_function_invocation
|
||||
from .._types import (
|
||||
from agent_framework import (
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
CitationAnnotation,
|
||||
TextContent,
|
||||
use_chat_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..observability import use_observability
|
||||
from ..openai._chat_client import OpenAIBaseChatClient
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import use_observability
|
||||
from agent_framework.openai._chat_client import OpenAIBaseChatClient
|
||||
|
||||
from ._shared import (
|
||||
AzureOpenAIConfigMixin,
|
||||
AzureOpenAISettings,
|
||||
@@ -41,6 +43,7 @@ TAzureOpenAIChatClient = TypeVar("TAzureOpenAIChatClient", bound="AzureOpenAICha
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class AzureOpenAIChatClient(AzureOpenAIConfigMixin, OpenAIBaseChatClient):
|
||||
"""Azure OpenAI Chat completion class."""
|
||||
|
||||
|
||||
@@ -9,10 +9,11 @@ from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
|
||||
from pydantic import SecretStr, ValidationError
|
||||
from pydantic.networks import AnyUrl
|
||||
|
||||
from .._tools import use_function_invocation
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..observability import use_observability
|
||||
from ..openai._responses_client import OpenAIBaseResponsesClient
|
||||
from agent_framework import use_chat_middleware, use_function_invocation
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import use_observability
|
||||
from agent_framework.openai._responses_client import OpenAIBaseResponsesClient
|
||||
|
||||
from ._shared import (
|
||||
AzureOpenAIConfigMixin,
|
||||
AzureOpenAISettings,
|
||||
@@ -23,6 +24,7 @@ TAzureOpenAIResponsesClient = TypeVar("TAzureOpenAIResponsesClient", bound="Azur
|
||||
|
||||
@use_observability
|
||||
@use_function_invocation
|
||||
@use_chat_middleware
|
||||
class AzureOpenAIResponsesClient(AzureOpenAIConfigMixin, OpenAIBaseResponsesClient):
|
||||
"""Azure Responses completion class."""
|
||||
|
||||
|
||||
@@ -128,3 +128,9 @@ class AdditionItemMismatch(AgentFrameworkException):
|
||||
"""An error occurred while adding two types."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MiddlewareException(AgentFrameworkException):
|
||||
"""An error occurred during middleware execution."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -21,6 +21,7 @@ from openai.types.beta.threads.runs import RunStep
|
||||
from pydantic import Field, PrivateAttr, SecretStr, ValidationError
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._middleware import use_chat_middleware
|
||||
from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, use_function_invocation
|
||||
from .._types import (
|
||||
ChatMessage,
|
||||
@@ -52,6 +53,7 @@ __all__ = ["OpenAIAssistantsClient"]
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient):
|
||||
"""OpenAI Assistants client."""
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import use_chat_middleware
|
||||
from .._tools import AIFunction, HostedWebSearchTool, ToolProtocol, use_function_invocation
|
||||
from .._types import (
|
||||
ChatMessage,
|
||||
@@ -452,6 +453,7 @@ TOpenAIChatClient = TypeVar("TOpenAIChatClient", bound="OpenAIChatClient")
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
|
||||
"""OpenAI Chat completion class."""
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import use_chat_middleware
|
||||
from .._tools import (
|
||||
AIFunction,
|
||||
HostedCodeInterpreterTool,
|
||||
@@ -933,6 +934,7 @@ TOpenAIResponsesClient = TypeVar("TOpenAIResponsesClient", bound="OpenAIResponse
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient):
|
||||
"""OpenAI Responses client class."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user