mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Merge branch 'main' into feature-session-statebag
This commit is contained in:
@@ -25,6 +25,7 @@
|
||||
"src\\Microsoft.Agents.AI.Purview\\Microsoft.Agents.AI.Purview.csproj",
|
||||
"src\\Microsoft.Agents.AI.Workflows.Declarative.AzureAI\\Microsoft.Agents.AI.Workflows.Declarative.AzureAI.csproj",
|
||||
"src\\Microsoft.Agents.AI.Workflows.Declarative\\Microsoft.Agents.AI.Workflows.Declarative.csproj",
|
||||
"src\\Microsoft.Agents.AI.Workflows.Generators\\Microsoft.Agents.AI.Workflows.Generators.csproj",
|
||||
"src\\Microsoft.Agents.AI.Workflows\\Microsoft.Agents.AI.Workflows.csproj",
|
||||
"src\\Microsoft.Agents.AI\\Microsoft.Agents.AI.csproj"
|
||||
]
|
||||
|
||||
@@ -24,7 +24,6 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"agent-framework-core>=1.0.0b260210",
|
||||
"azure-ai-projects >= 2.0.0b3",
|
||||
"azure-ai-agents == 1.2.0b5",
|
||||
"aiohttp",
|
||||
]
|
||||
|
||||
@@ -119,7 +119,7 @@ from agent_framework import Agent, AgentMiddleware, AgentContext
|
||||
class LoggingMiddleware(AgentMiddleware):
|
||||
async def process(self, context: AgentContext, call_next) -> None:
|
||||
print(f"Input: {context.messages}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print(f"Output: {context.result}")
|
||||
|
||||
agent = Agent(..., middleware=[LoggingMiddleware()])
|
||||
|
||||
@@ -145,7 +145,7 @@ class AgentContext:
|
||||
context.metadata["start_time"] = time.time()
|
||||
|
||||
# Continue execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Access result after execution
|
||||
print(f"Result: {context.result}")
|
||||
@@ -229,7 +229,7 @@ class FunctionInvocationContext:
|
||||
raise MiddlewareTermination("Validation failed")
|
||||
|
||||
# Continue execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -293,7 +293,7 @@ class ChatContext:
|
||||
context.metadata["input_tokens"] = self.count_tokens(context.messages)
|
||||
|
||||
# Continue execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Access result and count output tokens
|
||||
if context.result:
|
||||
@@ -365,7 +365,7 @@ class AgentMiddleware(ABC):
|
||||
|
||||
async def process(self, context: AgentContext, call_next):
|
||||
for attempt in range(self.max_retries):
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
if context.result and not context.result.is_error:
|
||||
break
|
||||
print(f"Retry {attempt + 1}/{self.max_retries}")
|
||||
@@ -379,7 +379,7 @@ class AgentMiddleware(ABC):
|
||||
async def process(
|
||||
self,
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Process an agent invocation.
|
||||
|
||||
@@ -431,7 +431,7 @@ class FunctionMiddleware(ABC):
|
||||
raise MiddlewareTermination()
|
||||
|
||||
# Execute function
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Cache result
|
||||
if context.result:
|
||||
@@ -446,7 +446,7 @@ class FunctionMiddleware(ABC):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Process a function invocation.
|
||||
|
||||
@@ -493,7 +493,7 @@ class ChatMiddleware(ABC):
|
||||
context.messages.insert(0, Message(role="system", text=self.system_prompt))
|
||||
|
||||
# Continue execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
# Use with an agent
|
||||
@@ -508,7 +508,7 @@ class ChatMiddleware(ABC):
|
||||
async def process(
|
||||
self,
|
||||
context: ChatContext,
|
||||
call_next: Callable[[ChatContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Process a chat client request.
|
||||
|
||||
@@ -531,15 +531,13 @@ class ChatMiddleware(ABC):
|
||||
|
||||
|
||||
# Pure function type definitions for convenience
|
||||
AgentMiddlewareCallable = Callable[[AgentContext, Callable[[AgentContext], Awaitable[None]]], Awaitable[None]]
|
||||
AgentMiddlewareCallable = Callable[[AgentContext, Callable[[], Awaitable[None]]], Awaitable[None]]
|
||||
AgentMiddlewareTypes: TypeAlias = AgentMiddleware | AgentMiddlewareCallable
|
||||
|
||||
FunctionMiddlewareCallable = Callable[
|
||||
[FunctionInvocationContext, Callable[[FunctionInvocationContext], Awaitable[None]]], Awaitable[None]
|
||||
]
|
||||
FunctionMiddlewareCallable = Callable[[FunctionInvocationContext, Callable[[], Awaitable[None]]], Awaitable[None]]
|
||||
FunctionMiddlewareTypes: TypeAlias = FunctionMiddleware | FunctionMiddlewareCallable
|
||||
|
||||
ChatMiddlewareCallable = Callable[[ChatContext, Callable[[ChatContext], Awaitable[None]]], Awaitable[None]]
|
||||
ChatMiddlewareCallable = Callable[[ChatContext, Callable[[], Awaitable[None]]], Awaitable[None]]
|
||||
ChatMiddlewareTypes: TypeAlias = ChatMiddleware | ChatMiddlewareCallable
|
||||
|
||||
ChatAndFunctionMiddlewareTypes: TypeAlias = (
|
||||
@@ -578,7 +576,7 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable:
|
||||
@agent_middleware
|
||||
async def logging_middleware(context: AgentContext, call_next):
|
||||
print(f"Before: {context.agent.name}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print(f"After: {context.result}")
|
||||
|
||||
|
||||
@@ -611,7 +609,7 @@ def function_middleware(func: FunctionMiddlewareCallable) -> FunctionMiddlewareC
|
||||
@function_middleware
|
||||
async def logging_middleware(context: FunctionInvocationContext, call_next):
|
||||
print(f"Calling: {context.function.name}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print(f"Result: {context.result}")
|
||||
|
||||
|
||||
@@ -644,7 +642,7 @@ def chat_middleware(func: ChatMiddlewareCallable) -> ChatMiddlewareCallable:
|
||||
@chat_middleware
|
||||
async def logging_middleware(context: ChatContext, call_next):
|
||||
print(f"Messages: {len(context.messages)}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print(f"Response: {context.result}")
|
||||
|
||||
|
||||
@@ -666,10 +664,10 @@ class MiddlewareWrapper(Generic[ContextT]):
|
||||
ContextT: The type of context object this middleware operates on.
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable[[ContextT, Callable[[ContextT], Awaitable[None]]], Awaitable[None]]) -> None:
|
||||
def __init__(self, func: Callable[[ContextT, Callable[[], Awaitable[None]]], Awaitable[None]]) -> None:
|
||||
self.func = func
|
||||
|
||||
async def process(self, context: ContextT, call_next: Callable[[ContextT], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ContextT, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await self.func(context, call_next)
|
||||
|
||||
|
||||
@@ -772,25 +770,25 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
|
||||
context.result = await context.result
|
||||
return context.result
|
||||
|
||||
def create_next_handler(index: int) -> Callable[[AgentContext], Awaitable[None]]:
|
||||
def create_next_handler(index: int) -> Callable[[], Awaitable[None]]:
|
||||
if index >= len(self._middleware):
|
||||
|
||||
async def final_wrapper(c: AgentContext) -> None:
|
||||
c.result = final_handler(c) # type: ignore[assignment]
|
||||
if inspect.isawaitable(c.result):
|
||||
c.result = await c.result
|
||||
async def final_wrapper() -> None:
|
||||
context.result = final_handler(context) # type: ignore[assignment]
|
||||
if inspect.isawaitable(context.result):
|
||||
context.result = await context.result
|
||||
|
||||
return final_wrapper
|
||||
|
||||
async def current_handler(c: AgentContext) -> None:
|
||||
async def current_handler() -> None:
|
||||
# MiddlewareTermination bubbles up to execute() to skip post-processing
|
||||
await self._middleware[index].process(c, create_next_handler(index + 1))
|
||||
await self._middleware[index].process(context, create_next_handler(index + 1))
|
||||
|
||||
return current_handler
|
||||
|
||||
first_handler = create_next_handler(0)
|
||||
with contextlib.suppress(MiddlewareTermination):
|
||||
await first_handler(context)
|
||||
await first_handler()
|
||||
|
||||
if context.result and isinstance(context.result, ResponseStream):
|
||||
for hook in context.stream_transform_hooks:
|
||||
@@ -847,25 +845,25 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline):
|
||||
if not self._middleware:
|
||||
return await final_handler(context)
|
||||
|
||||
def create_next_handler(index: int) -> Callable[[FunctionInvocationContext], Awaitable[None]]:
|
||||
def create_next_handler(index: int) -> Callable[[], Awaitable[None]]:
|
||||
if index >= len(self._middleware):
|
||||
|
||||
async def final_wrapper(c: FunctionInvocationContext) -> None:
|
||||
c.result = final_handler(c)
|
||||
if inspect.isawaitable(c.result):
|
||||
c.result = await c.result
|
||||
async def final_wrapper() -> None:
|
||||
context.result = final_handler(context)
|
||||
if inspect.isawaitable(context.result):
|
||||
context.result = await context.result
|
||||
|
||||
return final_wrapper
|
||||
|
||||
async def current_handler(c: FunctionInvocationContext) -> None:
|
||||
async def current_handler() -> None:
|
||||
# MiddlewareTermination bubbles up to execute() to skip post-processing
|
||||
await self._middleware[index].process(c, create_next_handler(index + 1))
|
||||
await self._middleware[index].process(context, create_next_handler(index + 1))
|
||||
|
||||
return current_handler
|
||||
|
||||
first_handler = create_next_handler(0)
|
||||
# Don't suppress MiddlewareTermination - let it propagate to signal loop termination
|
||||
await first_handler(context)
|
||||
await first_handler()
|
||||
|
||||
return context.result
|
||||
|
||||
@@ -922,25 +920,25 @@ class ChatMiddlewarePipeline(BaseMiddlewarePipeline):
|
||||
raise ValueError("Streaming agent middleware requires a ResponseStream result.")
|
||||
return context.result
|
||||
|
||||
def create_next_handler(index: int) -> Callable[[ChatContext], Awaitable[None]]:
|
||||
def create_next_handler(index: int) -> Callable[[], Awaitable[None]]:
|
||||
if index >= len(self._middleware):
|
||||
|
||||
async def final_wrapper(c: ChatContext) -> None:
|
||||
c.result = final_handler(c) # type: ignore[assignment]
|
||||
if inspect.isawaitable(c.result):
|
||||
c.result = await c.result
|
||||
async def final_wrapper() -> None:
|
||||
context.result = final_handler(context) # type: ignore[assignment]
|
||||
if inspect.isawaitable(context.result):
|
||||
context.result = await context.result
|
||||
|
||||
return final_wrapper
|
||||
|
||||
async def current_handler(c: ChatContext) -> None:
|
||||
async def current_handler() -> None:
|
||||
# MiddlewareTermination bubbles up to execute() to skip post-processing
|
||||
await self._middleware[index].process(c, create_next_handler(index + 1))
|
||||
await self._middleware[index].process(context, create_next_handler(index + 1))
|
||||
|
||||
return current_handler
|
||||
|
||||
first_handler = create_next_handler(0)
|
||||
with contextlib.suppress(MiddlewareTermination):
|
||||
await first_handler(context)
|
||||
await first_handler()
|
||||
|
||||
if context.result and isinstance(context.result, ResponseStream):
|
||||
for hook in context.stream_transform_hooks:
|
||||
|
||||
@@ -7,11 +7,14 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
from azure.core.credentials import TokenCredential
|
||||
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
|
||||
from openai import AsyncOpenAI
|
||||
from openai.lib.azure import AsyncAzureADTokenProvider
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .._middleware import ChatMiddlewareLayer
|
||||
from .._telemetry import AGENT_FRAMEWORK_USER_AGENT
|
||||
from .._tools import FunctionInvocationConfiguration, FunctionInvocationLayer
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..observability import ChatTelemetryLayer
|
||||
@@ -72,7 +75,9 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
|
||||
token_endpoint: str | None = None,
|
||||
credential: TokenCredential | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncAzureOpenAI | None = None,
|
||||
async_client: AsyncOpenAI | None = None,
|
||||
project_client: Any | None = None,
|
||||
project_endpoint: str | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
instruction_role: str | None = None,
|
||||
@@ -82,6 +87,14 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
|
||||
) -> None:
|
||||
"""Initialize an Azure OpenAI Responses client.
|
||||
|
||||
The client can be created in two ways:
|
||||
|
||||
1. **Direct Azure OpenAI** (default): Provide endpoint, api_key, or credential
|
||||
to connect directly to an Azure OpenAI deployment.
|
||||
2. **Foundry project endpoint**: Provide a ``project_client`` or ``project_endpoint``
|
||||
(with ``credential``) to create the client via an Azure AI Foundry project.
|
||||
This requires the ``azure-ai-projects`` package to be installed.
|
||||
|
||||
Keyword Args:
|
||||
api_key: The API key. If provided, will override the value in the env vars or .env file.
|
||||
Can also be set via environment variable AZURE_OPENAI_API_KEY.
|
||||
@@ -105,6 +118,12 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
|
||||
default_headers: The default headers mapping of string keys to
|
||||
string values for HTTP requests.
|
||||
async_client: An existing client to use.
|
||||
project_client: An existing ``AIProjectClient`` (from ``azure.ai.projects.aio``) to use.
|
||||
The OpenAI client will be obtained via ``project_client.get_openai_client()``.
|
||||
Requires the ``azure-ai-projects`` package.
|
||||
project_endpoint: The Azure AI Foundry project endpoint URL.
|
||||
When provided with ``credential``, an ``AIProjectClient`` will be created
|
||||
and used to obtain the OpenAI client. Requires the ``azure-ai-projects`` package.
|
||||
env_file_path: Use the environment settings file as a fallback to using env vars.
|
||||
env_file_encoding: The encoding of the environment settings file, defaults to 'utf-8'.
|
||||
instruction_role: The role to use for 'instruction' messages, for example, summarization
|
||||
@@ -132,6 +151,27 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
|
||||
# Or loading from a .env file
|
||||
client = AzureOpenAIResponsesClient(env_file_path="path/to/.env")
|
||||
|
||||
# Using a Foundry project endpoint
|
||||
from azure.identity import DefaultAzureCredential
|
||||
|
||||
client = AzureOpenAIResponsesClient(
|
||||
project_endpoint="https://your-project.services.ai.azure.com",
|
||||
deployment_name="gpt-4o",
|
||||
credential=DefaultAzureCredential(),
|
||||
)
|
||||
|
||||
# Or using an existing AIProjectClient
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
|
||||
project_client = AIProjectClient(
|
||||
endpoint="https://your-project.services.ai.azure.com",
|
||||
credential=DefaultAzureCredential(),
|
||||
)
|
||||
client = AzureOpenAIResponsesClient(
|
||||
project_client=project_client,
|
||||
deployment_name="gpt-4o",
|
||||
)
|
||||
|
||||
# Using custom ChatOptions with type safety:
|
||||
from typing import TypedDict
|
||||
from agent_framework.azure import AzureOpenAIResponsesOptions
|
||||
@@ -146,6 +186,15 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
|
||||
"""
|
||||
if model_id := kwargs.pop("model_id", None) and not deployment_name:
|
||||
deployment_name = str(model_id)
|
||||
|
||||
# Project client path: create OpenAI client from an Azure AI Foundry project
|
||||
if async_client is None and (project_client is not None or project_endpoint is not None):
|
||||
async_client = self._create_client_from_project(
|
||||
project_client=project_client,
|
||||
project_endpoint=project_endpoint,
|
||||
credential=credential,
|
||||
)
|
||||
|
||||
try:
|
||||
azure_openai_settings = AzureOpenAISettings(
|
||||
# pydantic settings will see if there is a value, if not, will try the env var or .env file
|
||||
@@ -195,9 +244,48 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_client_from_project(
|
||||
*,
|
||||
project_client: AIProjectClient | None,
|
||||
project_endpoint: str | None,
|
||||
credential: TokenCredential | None,
|
||||
) -> AsyncOpenAI:
|
||||
"""Create an AsyncOpenAI client from an Azure AI Foundry project.
|
||||
|
||||
Args:
|
||||
project_client: An existing AIProjectClient to use.
|
||||
project_endpoint: The Azure AI Foundry project endpoint URL.
|
||||
credential: Azure credential for authentication.
|
||||
|
||||
Returns:
|
||||
An AsyncAzureOpenAI client obtained from the project client.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If required parameters are missing or
|
||||
the azure-ai-projects package is not installed.
|
||||
"""
|
||||
if project_client is not None:
|
||||
return project_client.get_openai_client()
|
||||
|
||||
if not project_endpoint:
|
||||
raise ServiceInitializationError(
|
||||
"Azure AI project endpoint is required when project_client is not provided."
|
||||
)
|
||||
if not credential:
|
||||
raise ServiceInitializationError(
|
||||
"Azure credential is required when using project_endpoint without a project_client."
|
||||
)
|
||||
project_client = AIProjectClient(
|
||||
endpoint=project_endpoint,
|
||||
credential=credential, # type: ignore[arg-type]
|
||||
user_agent=AGENT_FRAMEWORK_USER_AGENT,
|
||||
)
|
||||
return project_client.get_openai_client()
|
||||
|
||||
@override
|
||||
def _check_model_presence(self, run_options: dict[str, Any]) -> None:
|
||||
if not run_options.get("model"):
|
||||
def _check_model_presence(self, options: dict[str, Any]) -> None:
|
||||
if not options.get("model"):
|
||||
if not self.model_id:
|
||||
raise ValueError("deployment_name must be a non-empty string")
|
||||
run_options["model"] = self.model_id
|
||||
options["model"] = self.model_id
|
||||
|
||||
@@ -9,6 +9,7 @@ from copy import copy
|
||||
from typing import Any, ClassVar, Final
|
||||
|
||||
from azure.core.credentials import TokenCredential
|
||||
from openai import AsyncOpenAI
|
||||
from openai.lib.azure import AsyncAzureOpenAI
|
||||
from pydantic import SecretStr, model_validator
|
||||
|
||||
@@ -162,7 +163,7 @@ class AzureOpenAIConfigMixin(OpenAIBase):
|
||||
token_endpoint: str | None = None,
|
||||
credential: TokenCredential | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
client: AsyncAzureOpenAI | None = None,
|
||||
client: AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
||||
@@ -901,6 +901,7 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
"""Prepare a chat message for the OpenAI Responses API format."""
|
||||
all_messages: list[dict[str, Any]] = []
|
||||
args: dict[str, Any] = {
|
||||
"type": "message",
|
||||
"role": message.role,
|
||||
}
|
||||
for content in message.contents:
|
||||
@@ -911,16 +912,22 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
case "function_result":
|
||||
new_args: dict[str, Any] = {}
|
||||
new_args.update(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore[arg-type]
|
||||
all_messages.append(new_args)
|
||||
if new_args:
|
||||
all_messages.append(new_args)
|
||||
case "function_call":
|
||||
function_call = self._prepare_content_for_openai(message.role, content, call_id_to_id) # type: ignore[arg-type]
|
||||
all_messages.append(function_call) # type: ignore
|
||||
if function_call:
|
||||
all_messages.append(function_call) # type: ignore
|
||||
case "function_approval_response" | "function_approval_request":
|
||||
all_messages.append(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore
|
||||
prepared = self._prepare_content_for_openai(Role(message.role), content, call_id_to_id)
|
||||
if prepared:
|
||||
all_messages.append(prepared) # type: ignore
|
||||
case _:
|
||||
if "content" not in args:
|
||||
args["content"] = []
|
||||
args["content"].append(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore
|
||||
prepared_content = self._prepare_content_for_openai(message.role, content, call_id_to_id) # type: ignore
|
||||
if prepared_content:
|
||||
if "content" not in args:
|
||||
args["content"] = []
|
||||
args["content"].append(prepared_content) # type: ignore
|
||||
if "content" in args or "tool_calls" in args:
|
||||
all_messages.append(args)
|
||||
return all_messages
|
||||
|
||||
@@ -34,6 +34,7 @@ dependencies = [
|
||||
# connectors and functions
|
||||
"openai>=1.99.0",
|
||||
"azure-identity>=1,<2",
|
||||
"azure-ai-projects >= 2.0.0b3",
|
||||
"mcp[ws]>=1.24.0,<2",
|
||||
"packaging>=24.1",
|
||||
]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Annotated, Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from azure.identity import AzureCliCredential
|
||||
@@ -115,6 +116,119 @@ def test_init_with_empty_model_id(azure_openai_unit_test_env: dict[str, str]) ->
|
||||
)
|
||||
|
||||
|
||||
def test_init_with_project_client(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test initialization with an existing AIProjectClient."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Create a mock AIProjectClient that returns a mock AsyncOpenAI client
|
||||
mock_openai_client = MagicMock(spec=AsyncOpenAI)
|
||||
mock_openai_client.default_headers = {}
|
||||
|
||||
mock_project_client = MagicMock()
|
||||
mock_project_client.get_openai_client.return_value = mock_openai_client
|
||||
|
||||
with patch(
|
||||
"agent_framework.azure._responses_client.AzureOpenAIResponsesClient._create_client_from_project",
|
||||
return_value=mock_openai_client,
|
||||
):
|
||||
azure_responses_client = AzureOpenAIResponsesClient(
|
||||
project_client=mock_project_client,
|
||||
deployment_name="gpt-4o",
|
||||
)
|
||||
|
||||
assert azure_responses_client.model_id == "gpt-4o"
|
||||
assert azure_responses_client.client is mock_openai_client
|
||||
assert isinstance(azure_responses_client, SupportsChatGetResponse)
|
||||
|
||||
|
||||
def test_init_with_project_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test initialization with a project endpoint and credential."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
mock_openai_client = MagicMock(spec=AsyncOpenAI)
|
||||
mock_openai_client.default_headers = {}
|
||||
|
||||
with patch(
|
||||
"agent_framework.azure._responses_client.AzureOpenAIResponsesClient._create_client_from_project",
|
||||
return_value=mock_openai_client,
|
||||
):
|
||||
azure_responses_client = AzureOpenAIResponsesClient(
|
||||
project_endpoint="https://test-project.services.ai.azure.com",
|
||||
deployment_name="gpt-4o",
|
||||
credential=AzureCliCredential(),
|
||||
)
|
||||
|
||||
assert azure_responses_client.model_id == "gpt-4o"
|
||||
assert azure_responses_client.client is mock_openai_client
|
||||
assert isinstance(azure_responses_client, SupportsChatGetResponse)
|
||||
|
||||
|
||||
def test_create_client_from_project_with_project_client() -> None:
|
||||
"""Test _create_client_from_project with an existing project client."""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
mock_openai_client = MagicMock(spec=AsyncOpenAI)
|
||||
mock_project_client = MagicMock()
|
||||
mock_project_client.get_openai_client.return_value = mock_openai_client
|
||||
|
||||
result = AzureOpenAIResponsesClient._create_client_from_project(
|
||||
project_client=mock_project_client,
|
||||
project_endpoint=None,
|
||||
credential=None,
|
||||
)
|
||||
|
||||
assert result is mock_openai_client
|
||||
mock_project_client.get_openai_client.assert_called_once()
|
||||
|
||||
|
||||
def test_create_client_from_project_with_endpoint() -> None:
|
||||
"""Test _create_client_from_project with a project endpoint."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
mock_openai_client = MagicMock(spec=AsyncOpenAI)
|
||||
mock_credential = MagicMock()
|
||||
|
||||
with patch("agent_framework.azure._responses_client.AIProjectClient") as MockAIProjectClient:
|
||||
mock_instance = MockAIProjectClient.return_value
|
||||
mock_instance.get_openai_client.return_value = mock_openai_client
|
||||
|
||||
result = AzureOpenAIResponsesClient._create_client_from_project(
|
||||
project_client=None,
|
||||
project_endpoint="https://test-project.services.ai.azure.com",
|
||||
credential=mock_credential,
|
||||
)
|
||||
|
||||
assert result is mock_openai_client
|
||||
MockAIProjectClient.assert_called_once()
|
||||
mock_instance.get_openai_client.assert_called_once()
|
||||
|
||||
|
||||
def test_create_client_from_project_missing_endpoint() -> None:
|
||||
"""Test _create_client_from_project raises error when endpoint is missing."""
|
||||
with pytest.raises(ServiceInitializationError, match="project endpoint is required"):
|
||||
AzureOpenAIResponsesClient._create_client_from_project(
|
||||
project_client=None,
|
||||
project_endpoint=None,
|
||||
credential=MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
def test_create_client_from_project_missing_credential() -> None:
|
||||
"""Test _create_client_from_project raises error when credential is missing."""
|
||||
with pytest.raises(ServiceInitializationError, match="credential is required"):
|
||||
AzureOpenAIResponsesClient._create_client_from_project(
|
||||
project_client=None,
|
||||
project_endpoint="https://test-project.services.ai.azure.com",
|
||||
credential=None,
|
||||
)
|
||||
|
||||
|
||||
def test_serialize(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
default_headers = {"X-Unit-Test": "test-guid"}
|
||||
|
||||
|
||||
@@ -19,12 +19,10 @@ class TestAsToolKwargsPropagation:
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Capture kwargs passed to the sub-agent
|
||||
captured_kwargs.update(context.kwargs)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Setup mock response
|
||||
client.responses = [
|
||||
@@ -62,11 +60,9 @@ class TestAsToolKwargsPropagation:
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
captured_kwargs.update(context.kwargs)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Setup mock response
|
||||
client.responses = [
|
||||
@@ -99,12 +95,10 @@ class TestAsToolKwargsPropagation:
|
||||
captured_kwargs_list: list[dict[str, Any]] = []
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Capture kwargs at each level
|
||||
captured_kwargs_list.append(dict(context.kwargs))
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Setup mock responses to trigger nested tool invocation: B calls tool C, then completes.
|
||||
client.responses = [
|
||||
@@ -162,11 +156,9 @@ class TestAsToolKwargsPropagation:
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
captured_kwargs.update(context.kwargs)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Setup mock streaming responses
|
||||
from agent_framework import ChatResponseUpdate
|
||||
@@ -224,11 +216,9 @@ class TestAsToolKwargsPropagation:
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
captured_kwargs.update(context.kwargs)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Setup mock response
|
||||
client.responses = [
|
||||
@@ -266,16 +256,14 @@ class TestAsToolKwargsPropagation:
|
||||
call_count = 0
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
first_call_kwargs.update(context.kwargs)
|
||||
elif call_count == 2:
|
||||
second_call_kwargs.update(context.kwargs)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Setup mock responses for both calls
|
||||
client.responses = [
|
||||
@@ -318,11 +306,9 @@ class TestAsToolKwargsPropagation:
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
captured_kwargs.update(context.kwargs)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Setup mock response
|
||||
client.responses = [
|
||||
|
||||
@@ -2298,9 +2298,7 @@ async def test_streaming_error_recovery_resets_counter(chat_client_base: Support
|
||||
class TerminateLoopMiddleware(FunctionMiddleware):
|
||||
"""Middleware that raises MiddlewareTermination to exit the function calling loop."""
|
||||
|
||||
async def process(
|
||||
self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None:
|
||||
# Set result to a simple value - the framework will wrap it in FunctionResultContent
|
||||
context.result = "terminated by middleware"
|
||||
raise MiddlewareTermination
|
||||
@@ -2355,14 +2353,12 @@ async def test_terminate_loop_single_function_call(chat_client_base: SupportsCha
|
||||
class SelectiveTerminateMiddleware(FunctionMiddleware):
|
||||
"""Only terminates for terminating_function."""
|
||||
|
||||
async def process(
|
||||
self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None:
|
||||
if context.function.name == "terminating_function":
|
||||
# Set result to a simple value - the framework will wrap it in FunctionResultContent
|
||||
context.result = "terminated by middleware"
|
||||
raise MiddlewareTermination
|
||||
await next_handler(context)
|
||||
await next_handler()
|
||||
|
||||
|
||||
async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client_base: SupportsChatGetResponse):
|
||||
|
||||
@@ -135,12 +135,12 @@ class TestAgentMiddlewarePipeline:
|
||||
"""Test cases for AgentMiddlewarePipeline."""
|
||||
|
||||
class PreNextTerminateMiddleware(AgentMiddleware):
|
||||
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
raise MiddlewareTermination
|
||||
|
||||
class PostNextTerminateMiddleware(AgentMiddleware):
|
||||
async def process(self, context: AgentContext, call_next: Any) -> None:
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
raise MiddlewareTermination
|
||||
|
||||
def test_init_empty(self) -> None:
|
||||
@@ -157,8 +157,8 @@ class TestAgentMiddlewarePipeline:
|
||||
def test_init_with_function_middleware(self) -> None:
|
||||
"""Test AgentMiddlewarePipeline initialization with function-based middleware."""
|
||||
|
||||
async def test_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
await call_next(context)
|
||||
async def test_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
pipeline = AgentMiddlewarePipeline(test_middleware)
|
||||
assert pipeline.has_middlewares
|
||||
@@ -185,11 +185,9 @@ class TestAgentMiddlewarePipeline:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
middleware = OrderTrackingMiddleware("test")
|
||||
@@ -238,11 +236,9 @@ class TestAgentMiddlewarePipeline:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
middleware = StreamOrderTrackingMiddleware("test")
|
||||
@@ -367,12 +363,10 @@ class TestAgentMiddlewarePipeline:
|
||||
captured_thread = None
|
||||
|
||||
class ThreadCapturingMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
nonlocal captured_thread
|
||||
captured_thread = context.thread
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
middleware = ThreadCapturingMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline(middleware)
|
||||
@@ -394,12 +388,10 @@ class TestAgentMiddlewarePipeline:
|
||||
captured_thread = "not_none" # Use string to distinguish from None
|
||||
|
||||
class ThreadCapturingMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
nonlocal captured_thread
|
||||
captured_thread = context.thread
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
middleware = ThreadCapturingMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline(middleware)
|
||||
@@ -425,7 +417,7 @@ class TestFunctionMiddlewarePipeline:
|
||||
|
||||
class PostNextTerminateFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(self, context: FunctionInvocationContext, call_next: Any) -> None:
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
raise MiddlewareTermination
|
||||
|
||||
async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool[Any, Any]) -> None:
|
||||
@@ -482,10 +474,8 @@ class TestFunctionMiddlewarePipeline:
|
||||
def test_init_with_function_middleware(self) -> None:
|
||||
"""Test FunctionMiddlewarePipeline initialization with function-based middleware."""
|
||||
|
||||
async def test_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
await call_next(context)
|
||||
async def test_middleware(context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
pipeline = FunctionMiddlewarePipeline(test_middleware)
|
||||
assert pipeline.has_middlewares
|
||||
@@ -515,10 +505,10 @@ class TestFunctionMiddlewarePipeline:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
middleware = OrderTrackingFunctionMiddleware("test")
|
||||
@@ -541,12 +531,12 @@ class TestChatMiddlewarePipeline:
|
||||
"""Test cases for ChatMiddlewarePipeline."""
|
||||
|
||||
class PreNextTerminateChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
raise MiddlewareTermination
|
||||
|
||||
class PostNextTerminateChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
await call_next(context)
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
raise MiddlewareTermination
|
||||
|
||||
def test_init_empty(self) -> None:
|
||||
@@ -563,8 +553,8 @@ class TestChatMiddlewarePipeline:
|
||||
def test_init_with_function_middleware(self) -> None:
|
||||
"""Test ChatMiddlewarePipeline initialization with function-based middleware."""
|
||||
|
||||
async def test_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
await call_next(context)
|
||||
async def test_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
pipeline = ChatMiddlewarePipeline(test_middleware)
|
||||
assert pipeline.has_middlewares
|
||||
@@ -592,9 +582,9 @@ class TestChatMiddlewarePipeline:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
middleware = OrderTrackingChatMiddleware("test")
|
||||
@@ -644,9 +634,9 @@ class TestChatMiddlewarePipeline:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
middleware = StreamOrderTrackingChatMiddleware("test")
|
||||
@@ -774,12 +764,10 @@ class TestClassBasedMiddleware:
|
||||
metadata_updates: list[str] = []
|
||||
|
||||
class MetadataAgentMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
context.metadata["before"] = True
|
||||
metadata_updates.append("before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
context.metadata["after"] = True
|
||||
metadata_updates.append("after")
|
||||
|
||||
@@ -807,11 +795,11 @@ class TestClassBasedMiddleware:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
context.metadata["before"] = True
|
||||
metadata_updates.append("before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
context.metadata["after"] = True
|
||||
metadata_updates.append("after")
|
||||
|
||||
@@ -839,12 +827,10 @@ class TestFunctionBasedMiddleware:
|
||||
"""Test function-based agent middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def test_agent_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def test_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("function_before")
|
||||
context.metadata["function_middleware"] = True
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_after")
|
||||
|
||||
pipeline = AgentMiddlewarePipeline(test_agent_middleware)
|
||||
@@ -866,11 +852,11 @@ class TestFunctionBasedMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def test_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_before")
|
||||
context.metadata["function_middleware"] = True
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_after")
|
||||
|
||||
pipeline = FunctionMiddlewarePipeline(test_function_middleware)
|
||||
@@ -896,18 +882,14 @@ class TestMixedMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class ClassMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("class_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("class_after")
|
||||
|
||||
async def function_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def function_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("function_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_after")
|
||||
|
||||
pipeline = AgentMiddlewarePipeline(ClassMiddleware(), function_middleware)
|
||||
@@ -931,17 +913,17 @@ class TestMixedMiddleware:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("class_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("class_after")
|
||||
|
||||
async def function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_after")
|
||||
|
||||
pipeline = FunctionMiddlewarePipeline(ClassMiddleware(), function_middleware)
|
||||
@@ -962,16 +944,14 @@ class TestMixedMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class ClassChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("class_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("class_after")
|
||||
|
||||
async def function_chat_middleware(
|
||||
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def function_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("function_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_after")
|
||||
|
||||
pipeline = ChatMiddlewarePipeline(ClassChatMiddleware(), function_chat_middleware)
|
||||
@@ -997,27 +977,21 @@ class TestMultipleMiddlewareOrdering:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class FirstMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("first_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("first_after")
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("second_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("second_after")
|
||||
|
||||
class ThirdMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("third_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("third_after")
|
||||
|
||||
middleware = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()]
|
||||
@@ -1051,20 +1025,20 @@ class TestMultipleMiddlewareOrdering:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("first_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("first_after")
|
||||
|
||||
class SecondMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("second_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("second_after")
|
||||
|
||||
middleware = [FirstMiddleware(), SecondMiddleware()]
|
||||
@@ -1087,21 +1061,21 @@ class TestMultipleMiddlewareOrdering:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class FirstChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("first_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("first_after")
|
||||
|
||||
class SecondChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("second_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("second_after")
|
||||
|
||||
class ThirdChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("third_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("third_after")
|
||||
|
||||
middleware = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()]
|
||||
@@ -1136,9 +1110,7 @@ class TestContextContentValidation:
|
||||
"""Test that agent context contains expected data."""
|
||||
|
||||
class ContextValidationMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Verify context has all expected attributes
|
||||
assert hasattr(context, "agent")
|
||||
assert hasattr(context, "messages")
|
||||
@@ -1156,7 +1128,7 @@ class TestContextContentValidation:
|
||||
# Add custom metadata
|
||||
context.metadata["validated"] = True
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
middleware = ContextValidationMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline(middleware)
|
||||
@@ -1178,7 +1150,7 @@ class TestContextContentValidation:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
# Verify context has all expected attributes
|
||||
assert hasattr(context, "function")
|
||||
@@ -1194,7 +1166,7 @@ class TestContextContentValidation:
|
||||
# Add custom metadata
|
||||
context.metadata["validated"] = True
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
middleware = ContextValidationMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline(middleware)
|
||||
@@ -1213,7 +1185,7 @@ class TestContextContentValidation:
|
||||
"""Test that chat context contains expected data."""
|
||||
|
||||
class ChatContextValidationMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Verify context has all expected attributes
|
||||
assert hasattr(context, "client")
|
||||
assert hasattr(context, "messages")
|
||||
@@ -1235,7 +1207,7 @@ class TestContextContentValidation:
|
||||
# Add custom metadata
|
||||
context.metadata["validated"] = True
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
middleware = ChatContextValidationMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline(middleware)
|
||||
@@ -1260,11 +1232,9 @@ class TestStreamingScenarios:
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
class StreamingFlagMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
streaming_flags.append(context.stream)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
middleware = StreamingFlagMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline(middleware)
|
||||
@@ -1302,11 +1272,9 @@ class TestStreamingScenarios:
|
||||
chunks_processed: list[str] = []
|
||||
|
||||
class StreamProcessingMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
chunks_processed.append("before_stream")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
chunks_processed.append("after_stream")
|
||||
|
||||
middleware = StreamProcessingMiddleware()
|
||||
@@ -1345,9 +1313,9 @@ class TestStreamingScenarios:
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
class ChatStreamingFlagMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
streaming_flags.append(context.stream)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
middleware = ChatStreamingFlagMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline(middleware)
|
||||
@@ -1386,9 +1354,9 @@ class TestStreamingScenarios:
|
||||
chunks_processed: list[str] = []
|
||||
|
||||
class ChatStreamProcessingMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
chunks_processed.append("before_stream")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
chunks_processed.append("after_stream")
|
||||
|
||||
middleware = ChatStreamProcessingMiddleware()
|
||||
@@ -1436,24 +1404,22 @@ class FunctionTestArgs(BaseModel):
|
||||
class TestAgentMiddleware(AgentMiddleware):
|
||||
"""Test implementation of AgentMiddleware."""
|
||||
|
||||
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
await call_next(context)
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
|
||||
class TestFunctionMiddleware(FunctionMiddleware):
|
||||
"""Test implementation of FunctionMiddleware."""
|
||||
|
||||
async def process(
|
||||
self, context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
await call_next(context)
|
||||
async def process(self, context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
|
||||
class TestChatMiddleware(ChatMiddleware):
|
||||
"""Test implementation of ChatMiddleware."""
|
||||
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
await call_next(context)
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
await call_next()
|
||||
|
||||
|
||||
class MockFunctionArgs(BaseModel):
|
||||
@@ -1469,9 +1435,7 @@ class TestMiddlewareExecutionControl:
|
||||
"""Test that when agent middleware doesn't call next(), no execution happens."""
|
||||
|
||||
class NoNextMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Don't call next() - this should prevent any execution
|
||||
pass
|
||||
|
||||
@@ -1498,9 +1462,7 @@ class TestMiddlewareExecutionControl:
|
||||
"""Test that when agent middleware doesn't call next(), no streaming execution happens."""
|
||||
|
||||
class NoNextStreamingMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Don't call next() - this should prevent any execution
|
||||
pass
|
||||
|
||||
@@ -1537,7 +1499,7 @@ class TestMiddlewareExecutionControl:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
# Don't call next() - this should prevent any execution
|
||||
pass
|
||||
@@ -1566,18 +1528,14 @@ class TestMiddlewareExecutionControl:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class FirstMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("first")
|
||||
# Don't call next() - this should stop the pipeline
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("second")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
pipeline = AgentMiddlewarePipeline(FirstMiddleware(), SecondMiddleware())
|
||||
messages = [Message(role="user", text="test")]
|
||||
@@ -1601,7 +1559,7 @@ class TestMiddlewareExecutionControl:
|
||||
"""Test that when chat middleware doesn't call next(), no execution happens."""
|
||||
|
||||
class NoNextChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Don't call next() - this should prevent any execution
|
||||
pass
|
||||
|
||||
@@ -1629,7 +1587,7 @@ class TestMiddlewareExecutionControl:
|
||||
"""Test that when chat middleware doesn't call next(), no streaming execution happens."""
|
||||
|
||||
class NoNextStreamingChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Don't call next() - this should prevent any execution
|
||||
pass
|
||||
|
||||
@@ -1670,14 +1628,14 @@ class TestMiddlewareExecutionControl:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class FirstChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("first")
|
||||
# Don't call next() - this should stop the pipeline
|
||||
|
||||
class SecondChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("second")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
pipeline = ChatMiddlewarePipeline(FirstChatMiddleware(), SecondChatMiddleware())
|
||||
messages = [Message(role="user", text="test")]
|
||||
|
||||
@@ -43,11 +43,9 @@ class TestResultOverrideMiddleware:
|
||||
override_response = AgentResponse(messages=[Message(role="assistant", text="overridden response")])
|
||||
|
||||
class ResponseOverrideMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Execute the pipeline first, then override the response
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
context.result = override_response
|
||||
|
||||
middleware = ResponseOverrideMiddleware()
|
||||
@@ -79,11 +77,9 @@ class TestResultOverrideMiddleware:
|
||||
yield AgentResponseUpdate(contents=[Content.from_text(text=" stream")])
|
||||
|
||||
class StreamResponseOverrideMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Execute the pipeline first, then override the response stream
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
context.result = ResponseStream(override_stream())
|
||||
|
||||
middleware = StreamResponseOverrideMiddleware()
|
||||
@@ -115,10 +111,10 @@ class TestResultOverrideMiddleware:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
# Execute the pipeline first, then override the result
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
context.result = override_result
|
||||
|
||||
middleware = ResultOverrideMiddleware()
|
||||
@@ -145,11 +141,9 @@ class TestResultOverrideMiddleware:
|
||||
mock_chat_client = MockChatClient()
|
||||
|
||||
class ChatAgentResponseOverrideMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Always call next() first to allow execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
# Then conditionally override based on content
|
||||
if any("special" in msg.text for msg in context.messages if msg.text):
|
||||
context.result = AgentResponse(
|
||||
@@ -184,15 +178,13 @@ class TestResultOverrideMiddleware:
|
||||
yield AgentResponseUpdate(contents=[Content.from_text(text=" response!")])
|
||||
|
||||
class ChatAgentStreamOverrideMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Check if we want to override BEFORE calling next to avoid creating unused streams
|
||||
if any("custom stream" in msg.text for msg in context.messages if msg.text):
|
||||
context.result = ResponseStream(custom_stream())
|
||||
return # Don't call next() - we're overriding the entire result
|
||||
# Normal case - let the agent handle it
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Create Agent with override middleware
|
||||
middleware = ChatAgentStreamOverrideMiddleware()
|
||||
@@ -223,12 +215,10 @@ class TestResultOverrideMiddleware:
|
||||
"""Test that when agent middleware conditionally doesn't call next(), no execution happens."""
|
||||
|
||||
class ConditionalNoNextMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Only call next() if message contains "execute"
|
||||
if any("execute" in msg.text for msg in context.messages if msg.text):
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
# Otherwise, don't call next() - no execution should happen
|
||||
|
||||
middleware = ConditionalNoNextMiddleware()
|
||||
@@ -269,13 +259,13 @@ class TestResultOverrideMiddleware:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
# Only call next() if argument name contains "execute"
|
||||
args = context.arguments
|
||||
assert isinstance(args, FunctionTestArgs)
|
||||
if "execute" in args.name:
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
# Otherwise, don't call next() - no execution should happen
|
||||
|
||||
middleware = ConditionalNoNextFunctionMiddleware()
|
||||
@@ -318,14 +308,12 @@ class TestResultObservability:
|
||||
observed_responses: list[AgentResponse] = []
|
||||
|
||||
class ObservabilityMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Context should be empty before next()
|
||||
assert context.result is None
|
||||
|
||||
# Call next to execute
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Context should now contain the response for observability
|
||||
assert context.result is not None
|
||||
@@ -355,13 +343,13 @@ class TestResultObservability:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
# Context should be empty before next()
|
||||
assert context.result is None
|
||||
|
||||
# Call next to execute
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Context should now contain the result for observability
|
||||
assert context.result is not None
|
||||
@@ -386,11 +374,9 @@ class TestResultObservability:
|
||||
"""Test that middleware can override response after observing execution."""
|
||||
|
||||
class PostExecutionOverrideMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Call next to execute first
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Now observe and conditionally override
|
||||
assert context.result is not None
|
||||
@@ -423,10 +409,10 @@ class TestResultObservability:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
# Call next to execute first
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Now observe and conditionally override
|
||||
assert context.result is not None
|
||||
|
||||
@@ -44,11 +44,9 @@ class TestChatAgentClassBasedMiddleware:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
# Create Agent with middleware
|
||||
@@ -76,9 +74,9 @@ class TestChatAgentClassBasedMiddleware:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
middleware = TrackingFunctionMiddleware()
|
||||
Agent(client=client, middleware=[middleware])
|
||||
@@ -96,10 +94,10 @@ class TestChatAgentClassBasedMiddleware:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
middleware = TrackingFunctionMiddleware("function_middleware")
|
||||
@@ -122,13 +120,11 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class PreTerminationMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
raise MiddlewareTermination
|
||||
# Code after raise is unreachable
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("middleware_after")
|
||||
|
||||
# Create Agent with terminating middleware
|
||||
@@ -153,11 +149,9 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class PostTerminationMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("middleware_after")
|
||||
context.terminate = True
|
||||
|
||||
@@ -193,12 +187,12 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
context.terminate = True
|
||||
# We call next() but since terminate=True, subsequent middleware and handler should not execute
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("middleware_after")
|
||||
|
||||
Agent(client=client, middleware=[PreTerminationFunctionMiddleware()], tools=[])
|
||||
@@ -211,10 +205,10 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("middleware_after")
|
||||
context.terminate = True
|
||||
|
||||
@@ -224,11 +218,9 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
"""Test function-based agent middleware with Agent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def tracking_agent_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def tracking_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("agent_function_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("agent_function_after")
|
||||
|
||||
# Create Agent with function middleware
|
||||
@@ -252,9 +244,9 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
"""Test function-based function middleware with Agent."""
|
||||
|
||||
async def tracking_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
Agent(client=client, middleware=[tracking_function_middleware])
|
||||
|
||||
@@ -265,10 +257,10 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def tracking_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_function_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_function_after")
|
||||
|
||||
agent = Agent(client=chat_client_base, middleware=[tracking_function_middleware])
|
||||
@@ -290,12 +282,10 @@ class TestChatAgentStreamingMiddleware:
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
class StreamingTrackingMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
streaming_flags.append(context.stream)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("middleware_after")
|
||||
|
||||
# Create Agent with middleware
|
||||
@@ -334,11 +324,9 @@ class TestChatAgentStreamingMiddleware:
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
class FlagTrackingMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
streaming_flags.append(context.stream)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Create Agent with middleware
|
||||
middleware = FlagTrackingMiddleware()
|
||||
@@ -368,11 +356,9 @@ class TestChatAgentMultipleMiddlewareOrdering:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
# Create multiple middleware
|
||||
@@ -400,35 +386,31 @@ class TestChatAgentMultipleMiddlewareOrdering:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class ClassAgentMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("class_agent_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("class_agent_after")
|
||||
|
||||
async def function_agent_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def function_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("function_agent_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_agent_after")
|
||||
|
||||
class ClassFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("class_function_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("class_function_after")
|
||||
|
||||
async def function_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_function_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_function_after")
|
||||
|
||||
agent = Agent(
|
||||
@@ -447,25 +429,21 @@ class TestChatAgentMultipleMiddlewareOrdering:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class ClassAgentMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("class_agent_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("class_agent_after")
|
||||
|
||||
async def function_agent_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def function_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("function_agent_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_agent_after")
|
||||
|
||||
async def function_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_function_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_function_after")
|
||||
|
||||
agent = Agent(
|
||||
@@ -521,10 +499,10 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
# Set up mock to return a function call first, then a regular response
|
||||
@@ -583,10 +561,10 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def tracking_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_middleware_after")
|
||||
|
||||
# Set up mock to return a function call first, then a regular response
|
||||
@@ -647,20 +625,20 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
async def process(
|
||||
self,
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("agent_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("agent_middleware_after")
|
||||
|
||||
class TrackingFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("function_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_middleware_after")
|
||||
|
||||
# Set up mock to return a function call first, then a regular response
|
||||
@@ -728,7 +706,7 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
|
||||
@function_middleware
|
||||
async def kwargs_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
nonlocal middleware_called
|
||||
middleware_called = True
|
||||
@@ -748,7 +726,7 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
modified_kwargs["new_param"] = context.kwargs.get("new_param")
|
||||
modified_kwargs["custom_param"] = context.kwargs.get("custom_param")
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
@@ -801,9 +779,9 @@ class TestMiddlewareDynamicRebuild:
|
||||
self.name = name
|
||||
self.execution_log = execution_log
|
||||
|
||||
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
self.execution_log.append(f"{self.name}_start")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
self.execution_log.append(f"{self.name}_end")
|
||||
|
||||
async def test_middleware_dynamic_rebuild_non_streaming(self, client: "MockChatClient") -> None:
|
||||
@@ -924,9 +902,9 @@ class TestRunLevelMiddleware:
|
||||
self.name = name
|
||||
self.execution_log = execution_log
|
||||
|
||||
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
self.execution_log.append(f"{self.name}_start")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
self.execution_log.append(f"{self.name}_end")
|
||||
|
||||
async def test_run_level_middleware_isolation(self, client: "MockChatClient") -> None:
|
||||
@@ -976,29 +954,25 @@ class TestRunLevelMiddleware:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_log.append(f"{self.name}_start")
|
||||
# Set metadata to pass information to run middleware
|
||||
context.metadata[f"{self.name}_key"] = f"{self.name}_value"
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_log.append(f"{self.name}_end")
|
||||
|
||||
class MetadataRunMiddleware(AgentMiddleware):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_log.append(f"{self.name}_start")
|
||||
# Read metadata set by agent middleware
|
||||
for key, value in context.metadata.items():
|
||||
metadata_log.append(f"{self.name}_reads_{key}:{value}")
|
||||
# Set run-level metadata
|
||||
context.metadata[f"{self.name}_key"] = f"{self.name}_value"
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_log.append(f"{self.name}_end")
|
||||
|
||||
# Create agent with agent-level middleware
|
||||
@@ -1049,12 +1023,10 @@ class TestRunLevelMiddleware:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_log.append(f"{self.name}_start")
|
||||
streaming_flags.append(context.stream)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_log.append(f"{self.name}_end")
|
||||
|
||||
# Create agent without agent-level middleware
|
||||
@@ -1093,48 +1065,44 @@ class TestRunLevelMiddleware:
|
||||
|
||||
# Agent-level middleware
|
||||
class AgentLevelAgentMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_log.append("agent_level_agent_start")
|
||||
context.metadata["agent_level_agent"] = "processed"
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_log.append("agent_level_agent_end")
|
||||
|
||||
class AgentLevelFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_log.append("agent_level_function_start")
|
||||
context.metadata["agent_level_function"] = "processed"
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_log.append("agent_level_function_end")
|
||||
|
||||
# Run-level middleware
|
||||
class RunLevelAgentMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_log.append("run_level_agent_start")
|
||||
# Verify agent-level middleware metadata is available
|
||||
assert "agent_level_agent" in context.metadata
|
||||
context.metadata["run_level_agent"] = "processed"
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_log.append("run_level_agent_end")
|
||||
|
||||
class RunLevelFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_log.append("run_level_function_start")
|
||||
# Verify agent-level function middleware metadata is available
|
||||
assert "agent_level_function" in context.metadata
|
||||
context.metadata["run_level_function"] = "processed"
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_log.append("run_level_function_end")
|
||||
|
||||
# Create tool function for testing function middleware
|
||||
@@ -1217,18 +1185,16 @@ class TestMiddlewareDecoratorLogic:
|
||||
execution_order: list[str] = []
|
||||
|
||||
@agent_middleware
|
||||
async def matching_agent_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def matching_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("decorator_type_match_agent")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
@function_middleware
|
||||
async def matching_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("decorator_type_match_function")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Create tool function for testing function middleware
|
||||
def custom_tool(message: str) -> str:
|
||||
@@ -1282,7 +1248,7 @@ class TestMiddlewareDecoratorLogic:
|
||||
context: FunctionInvocationContext, # Wrong type for @agent_middleware
|
||||
call_next: Any,
|
||||
) -> None:
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
agent = Agent(client=client, middleware=[mismatched_middleware])
|
||||
await agent.run([Message(role="user", text="test")])
|
||||
@@ -1294,12 +1260,12 @@ class TestMiddlewareDecoratorLogic:
|
||||
@agent_middleware
|
||||
async def decorator_only_agent(context: Any, call_next: Any) -> None: # No type annotation
|
||||
execution_order.append("decorator_only_agent")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
@function_middleware
|
||||
async def decorator_only_function(context: Any, call_next: Any) -> None: # No type annotation
|
||||
execution_order.append("decorator_only_function")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Create tool function for testing function middleware
|
||||
def custom_tool(message: str) -> str:
|
||||
@@ -1346,16 +1312,16 @@ class TestMiddlewareDecoratorLogic:
|
||||
execution_order: list[str] = []
|
||||
|
||||
# No decorator
|
||||
async def type_only_agent(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def type_only_agent(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("type_only_agent")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# No decorator
|
||||
async def type_only_function(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("type_only_function")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Create tool function for testing function middleware
|
||||
def custom_tool(message: str) -> str:
|
||||
@@ -1399,7 +1365,7 @@ class TestMiddlewareDecoratorLogic:
|
||||
"""Neither decorator nor parameter type specified - should throw exception."""
|
||||
|
||||
async def no_info_middleware(context: Any, call_next: Any) -> None: # No decorator, no type
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Should raise MiddlewareException
|
||||
with pytest.raises(MiddlewareException, match="Cannot determine middleware type"):
|
||||
@@ -1447,9 +1413,7 @@ class TestChatAgentThreadBehavior:
|
||||
thread_states: list[dict[str, Any]] = []
|
||||
|
||||
class ThreadTrackingMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Capture state before next() call
|
||||
thread_messages = []
|
||||
if context.thread and context.thread.message_store:
|
||||
@@ -1464,7 +1428,7 @@ class TestChatAgentThreadBehavior:
|
||||
}
|
||||
thread_states.append(before_state)
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Capture state after next() call
|
||||
thread_messages_after = []
|
||||
@@ -1560,9 +1524,9 @@ class TestChatAgentChatMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class TrackingChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("chat_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("chat_middleware_after")
|
||||
|
||||
# Create Agent with chat middleware
|
||||
@@ -1588,11 +1552,9 @@ class TestChatAgentChatMiddleware:
|
||||
"""Test function-based chat middleware with Agent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def tracking_chat_middleware(
|
||||
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def tracking_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("chat_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("chat_middleware_after")
|
||||
|
||||
# Create Agent with function-based chat middleware
|
||||
@@ -1617,9 +1579,7 @@ class TestChatAgentChatMiddleware:
|
||||
"""Test that chat middleware can modify messages before sending to model."""
|
||||
|
||||
@chat_middleware
|
||||
async def message_modifier_middleware(
|
||||
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def message_modifier_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Modify the first message by adding a prefix
|
||||
if context.messages:
|
||||
for idx, msg in enumerate(context.messages):
|
||||
@@ -1628,7 +1588,7 @@ class TestChatAgentChatMiddleware:
|
||||
original_text = msg.text or ""
|
||||
context.messages[idx] = Message(role=msg.role, text=f"MODIFIED: {original_text}")
|
||||
break
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Create Agent with message-modifying middleware
|
||||
client = MockBaseChatClient()
|
||||
@@ -1646,9 +1606,7 @@ class TestChatAgentChatMiddleware:
|
||||
"""Test that chat middleware can override the response."""
|
||||
|
||||
@chat_middleware
|
||||
async def response_override_middleware(
|
||||
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def response_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Override the response without calling next()
|
||||
context.result = ChatResponse(
|
||||
messages=[Message(role="assistant", text="MiddlewareTypes overridden response")],
|
||||
@@ -1675,15 +1633,15 @@ class TestChatAgentChatMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("first_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("first_after")
|
||||
|
||||
@chat_middleware
|
||||
async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("second_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("second_after")
|
||||
|
||||
# Create Agent with multiple chat middleware
|
||||
@@ -1709,10 +1667,10 @@ class TestChatAgentChatMiddleware:
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
class StreamingTrackingChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("streaming_chat_before")
|
||||
streaming_flags.append(context.stream)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("streaming_chat_after")
|
||||
|
||||
# Create Agent with chat middleware
|
||||
@@ -1749,13 +1707,13 @@ class TestChatAgentChatMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class PreTerminationChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
# Set a custom response since we're terminating
|
||||
context.result = ChatResponse(messages=[Message(role="assistant", text="Terminated by middleware")])
|
||||
raise MiddlewareTermination
|
||||
# We call next() but since terminate=True, execution should stop
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("middleware_after")
|
||||
|
||||
# Create Agent with terminating middleware
|
||||
@@ -1777,9 +1735,9 @@ class TestChatAgentChatMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
class PostTerminationChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("middleware_after")
|
||||
context.terminate = True
|
||||
|
||||
@@ -1804,21 +1762,21 @@ class TestChatAgentChatMiddleware:
|
||||
"""Test Agent with combined middleware types."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def agent_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("agent_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("agent_middleware_after")
|
||||
|
||||
async def chat_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("chat_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("chat_middleware_after")
|
||||
|
||||
async def function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_middleware_after")
|
||||
|
||||
# Create Agent with function middleware and tools
|
||||
@@ -1842,9 +1800,7 @@ class TestChatAgentChatMiddleware:
|
||||
modified_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def kwargs_middleware(
|
||||
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def kwargs_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Capture the original kwargs
|
||||
captured_kwargs.update(context.kwargs)
|
||||
|
||||
@@ -1856,7 +1812,7 @@ class TestChatAgentChatMiddleware:
|
||||
# Store modified kwargs for verification
|
||||
modified_kwargs.update(context.kwargs)
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Create Agent with agent middleware
|
||||
client = MockBaseChatClient()
|
||||
@@ -1895,10 +1851,10 @@ class TestChatAgentChatMiddleware:
|
||||
|
||||
# class TrackingMiddleware(AgentMiddleware):
|
||||
# async def process(
|
||||
# self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
|
||||
# self, context: AgentContext, call_next: Callable[[], Awaitable[None]]
|
||||
# ) -> None:
|
||||
# execution_order.append("before")
|
||||
# await call_next(context)
|
||||
# await call_next()
|
||||
# execution_order.append("after")
|
||||
|
||||
# @use_agent_middleware
|
||||
|
||||
@@ -32,10 +32,10 @@ class TestChatMiddleware:
|
||||
async def process(
|
||||
self,
|
||||
context: ChatContext,
|
||||
call_next: Callable[[ChatContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("chat_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("chat_middleware_after")
|
||||
|
||||
# Add middleware to chat client
|
||||
@@ -58,11 +58,9 @@ class TestChatMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def logging_chat_middleware(
|
||||
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def logging_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("function_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("function_middleware_after")
|
||||
|
||||
# Add middleware to chat client
|
||||
@@ -84,14 +82,12 @@ class TestChatMiddleware:
|
||||
"""Test that chat middleware can modify messages before sending to model."""
|
||||
|
||||
@chat_middleware
|
||||
async def message_modifier_middleware(
|
||||
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def message_modifier_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Modify the first message by adding a prefix
|
||||
if context.messages and len(context.messages) > 0:
|
||||
original_text = context.messages[0].text or ""
|
||||
context.messages[0] = Message(role=context.messages[0].role, text=f"MODIFIED: {original_text}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Add middleware to chat client
|
||||
chat_client_base.chat_middleware = [message_modifier_middleware]
|
||||
@@ -110,9 +106,7 @@ class TestChatMiddleware:
|
||||
"""Test that chat middleware can override the response."""
|
||||
|
||||
@chat_middleware
|
||||
async def response_override_middleware(
|
||||
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def response_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Override the response without calling next()
|
||||
context.result = ChatResponse(
|
||||
messages=[Message(role="assistant", text="MiddlewareTypes overridden response")],
|
||||
@@ -138,15 +132,15 @@ class TestChatMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("first_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("first_after")
|
||||
|
||||
@chat_middleware
|
||||
async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("second_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("second_after")
|
||||
|
||||
# Add middleware to chat client (order should be preserved)
|
||||
@@ -173,11 +167,9 @@ class TestChatMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def agent_level_chat_middleware(
|
||||
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def agent_level_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("agent_chat_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("agent_chat_middleware_after")
|
||||
|
||||
client = MockBaseChatClient()
|
||||
@@ -205,15 +197,15 @@ class TestChatMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("first_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("first_after")
|
||||
|
||||
@chat_middleware
|
||||
async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("second_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("second_after")
|
||||
|
||||
# Create Agent with multiple chat middleware
|
||||
@@ -240,9 +232,7 @@ class TestChatMiddleware:
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def streaming_middleware(
|
||||
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def streaming_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_order.append("streaming_before")
|
||||
# Verify it's a streaming context
|
||||
assert context.stream is True
|
||||
@@ -254,7 +244,7 @@ class TestChatMiddleware:
|
||||
return update
|
||||
|
||||
context.stream_transform_hooks.append(upper_case_update)
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("streaming_after")
|
||||
|
||||
# Add middleware to chat client
|
||||
@@ -278,11 +268,9 @@ class TestChatMiddleware:
|
||||
execution_count = {"count": 0}
|
||||
|
||||
@chat_middleware
|
||||
async def counting_middleware(
|
||||
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
async def counting_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
execution_count["count"] += 1
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# First call with run-level middleware
|
||||
messages = [Message(role="user", text="first message")]
|
||||
@@ -310,7 +298,7 @@ class TestChatMiddleware:
|
||||
modified_kwargs: dict[str, Any] = {}
|
||||
|
||||
@chat_middleware
|
||||
async def kwargs_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def kwargs_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Capture the original kwargs
|
||||
captured_kwargs.update(context.kwargs)
|
||||
|
||||
@@ -322,7 +310,7 @@ class TestChatMiddleware:
|
||||
# Store modified kwargs for verification
|
||||
modified_kwargs.update(context.kwargs)
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Add middleware to chat client
|
||||
chat_client_base.chat_middleware = [kwargs_middleware]
|
||||
@@ -355,11 +343,11 @@ class TestChatMiddleware:
|
||||
|
||||
@function_middleware
|
||||
async def test_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
nonlocal execution_order
|
||||
execution_order.append(f"function_middleware_before_{context.function.name}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append(f"function_middleware_after_{context.function.name}")
|
||||
|
||||
# Define a simple tool function
|
||||
@@ -421,10 +409,10 @@ class TestChatMiddleware:
|
||||
|
||||
@function_middleware
|
||||
async def run_level_function_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("run_level_function_middleware_before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
execution_order.append("run_level_function_middleware_after")
|
||||
|
||||
# Define a simple tool function
|
||||
|
||||
@@ -798,12 +798,8 @@ def test_chat_message_with_error_content() -> None:
|
||||
|
||||
result = client._prepare_message_for_openai(message, call_id_to_id)
|
||||
|
||||
# Message should be prepared with empty content list since ErrorContent returns {}
|
||||
assert len(result) == 1
|
||||
prepared_message = result[0]
|
||||
assert prepared_message["role"] == "assistant"
|
||||
# Content should be a list with empty dict since ErrorContent returns {}
|
||||
assert prepared_message.get("content") == [{}]
|
||||
# Message should be empty since ErrorContent is filtered out
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
def test_chat_message_with_usage_content() -> None:
|
||||
@@ -823,12 +819,8 @@ def test_chat_message_with_usage_content() -> None:
|
||||
|
||||
result = client._prepare_message_for_openai(message, call_id_to_id)
|
||||
|
||||
# Message should be prepared with empty content list since UsageContent returns {}
|
||||
assert len(result) == 1
|
||||
prepared_message = result[0]
|
||||
assert prepared_message["role"] == "assistant"
|
||||
# Content should be a list with empty dict since UsageContent returns {}
|
||||
assert prepared_message.get("content") == [{}]
|
||||
# Message should be empty since UsageContent is filtered out
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
def test_hosted_file_content_preparation() -> None:
|
||||
|
||||
@@ -207,7 +207,7 @@ def test_serialize(ollama_unit_test_env: dict[str, str]) -> None:
|
||||
def test_chat_middleware(ollama_unit_test_env: dict[str, str]) -> None:
|
||||
@chat_middleware
|
||||
async def sample_middleware(context, call_next):
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
ollama_chat_client = OllamaChatClient(middleware=[sample_middleware])
|
||||
assert len(ollama_chat_client.middleware) == 1
|
||||
|
||||
@@ -129,11 +129,11 @@ class _AutoHandoffMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Intercept matching handoff tool calls and inject synthetic results."""
|
||||
if context.function.name not in self._handoff_functions:
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
return
|
||||
|
||||
from agent_framework._middleware import MiddlewareTermination
|
||||
|
||||
@@ -65,7 +65,7 @@ class PurviewPolicyMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None: # type: ignore[override]
|
||||
resolved_user_id: str | None = None
|
||||
try:
|
||||
@@ -92,7 +92,7 @@ class PurviewPolicyMiddleware(AgentMiddleware):
|
||||
if not self._settings.ignore_exceptions:
|
||||
raise
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
try:
|
||||
# Post (response) check only if we have a normal AgentResponse
|
||||
@@ -162,7 +162,7 @@ class PurviewChatPolicyMiddleware(ChatMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: ChatContext,
|
||||
call_next: Callable[[ChatContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None: # type: ignore[override]
|
||||
resolved_user_id: str | None = None
|
||||
try:
|
||||
@@ -187,7 +187,7 @@ class PurviewChatPolicyMiddleware(ChatMiddleware):
|
||||
if not self._settings.ignore_exceptions:
|
||||
raise
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
try:
|
||||
# Post (response) evaluation only if non-streaming and we have messages result shape
|
||||
|
||||
@@ -49,7 +49,7 @@ class TestPurviewChatPolicyMiddleware:
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
next_called = False
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
nonlocal next_called
|
||||
next_called = True
|
||||
|
||||
@@ -57,7 +57,7 @@ class TestPurviewChatPolicyMiddleware:
|
||||
def __init__(self):
|
||||
self.messages = [Message(role="assistant", text="Hi there")]
|
||||
|
||||
ctx.result = Result()
|
||||
chat_context.result = Result()
|
||||
|
||||
await middleware.process(chat_context, mock_next)
|
||||
assert next_called
|
||||
@@ -67,7 +67,7 @@ class TestPurviewChatPolicyMiddleware:
|
||||
async def test_blocks_prompt(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None:
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None: # should not run
|
||||
async def mock_next() -> None: # should not run
|
||||
raise AssertionError("next should not be called when prompt blocked")
|
||||
|
||||
with pytest.raises(MiddlewareTermination):
|
||||
@@ -88,12 +88,12 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
class Result:
|
||||
def __init__(self):
|
||||
self.messages = [Message(role="assistant", text="Sensitive output")] # pragma: no cover
|
||||
|
||||
ctx.result = Result()
|
||||
chat_context.result = Result()
|
||||
|
||||
await middleware.process(chat_context, mock_next)
|
||||
assert call_state["count"] == 2
|
||||
@@ -114,8 +114,8 @@ class TestPurviewChatPolicyMiddleware:
|
||||
)
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
ctx.result = MagicMock()
|
||||
async def mock_next() -> None:
|
||||
streaming_context.result = MagicMock()
|
||||
|
||||
await middleware.process(streaming_context, mock_next)
|
||||
assert mock_proc.call_count == 1
|
||||
@@ -138,10 +138,10 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
result = MagicMock()
|
||||
result.messages = [Message(role="assistant", text="Response")]
|
||||
ctx.result = result
|
||||
chat_context.result = result
|
||||
|
||||
await middleware.process(chat_context, mock_next)
|
||||
|
||||
@@ -162,10 +162,10 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
result = MagicMock()
|
||||
result.messages = [Message(role="assistant", text="Response")]
|
||||
ctx.result = result
|
||||
chat_context.result = result
|
||||
|
||||
await middleware.process(chat_context, mock_next)
|
||||
|
||||
@@ -194,7 +194,7 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
raise AssertionError("next should not be called")
|
||||
|
||||
# Should raise the exception
|
||||
@@ -224,10 +224,10 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
result = MagicMock()
|
||||
result.messages = [Message(role="assistant", text="OK")]
|
||||
ctx.result = result
|
||||
context.result = result
|
||||
|
||||
with pytest.raises(PurviewPaymentRequiredError):
|
||||
await middleware.process(context, mock_next)
|
||||
@@ -249,7 +249,7 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
result = MagicMock()
|
||||
result.messages = [Message(role="assistant", text="Response")]
|
||||
context.result = result
|
||||
@@ -265,9 +265,9 @@ class TestPurviewChatPolicyMiddleware:
|
||||
"""Test middleware handles result that doesn't have messages attribute."""
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
# Set result to something without messages attribute
|
||||
ctx.result = "Some string result"
|
||||
chat_context.result = "Some string result"
|
||||
|
||||
await middleware.process(chat_context, mock_next)
|
||||
|
||||
@@ -289,7 +289,7 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
result = MagicMock()
|
||||
result.messages = [Message(role="assistant", text="Response")]
|
||||
context.result = result
|
||||
@@ -313,7 +313,7 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=ValueError("boom")):
|
||||
|
||||
async def mock_next(_: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
raise AssertionError("next should not be called")
|
||||
|
||||
with pytest.raises(ValueError, match="boom"):
|
||||
@@ -342,10 +342,10 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
result = MagicMock()
|
||||
result.messages = [Message(role="assistant", text="OK")]
|
||||
ctx.result = result
|
||||
context.result = result
|
||||
|
||||
with pytest.raises(ValueError, match="post"):
|
||||
await middleware.process(context, mock_next)
|
||||
@@ -361,10 +361,10 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
result = MagicMock()
|
||||
result.messages = [Message(role="assistant", text="Hi")]
|
||||
ctx.result = result
|
||||
context.result = result
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -382,10 +382,10 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
result = MagicMock()
|
||||
result.messages = [Message(role="assistant", text="Hi")]
|
||||
ctx.result = result
|
||||
context.result = result
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -401,10 +401,10 @@ class TestPurviewChatPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
result = MagicMock()
|
||||
result.messages = [Message(role="assistant", text="Response")]
|
||||
ctx.result = result
|
||||
context.result = result
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
|
||||
@@ -55,10 +55,10 @@ class TestPurviewPolicyMiddleware:
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
|
||||
next_called = False
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
nonlocal next_called
|
||||
next_called = True
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="I'm good, thanks!")])
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="I'm good, thanks!")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -74,7 +74,7 @@ class TestPurviewPolicyMiddleware:
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):
|
||||
next_called = False
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
nonlocal next_called
|
||||
next_called = True
|
||||
|
||||
@@ -101,8 +101,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(
|
||||
messages=[Message(role="assistant", text="Here's some sensitive information")]
|
||||
)
|
||||
|
||||
@@ -125,8 +125,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = "Some non-standard result"
|
||||
async def mock_next() -> None:
|
||||
context.result = "Some non-standard result"
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -142,8 +142,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process:
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -160,8 +160,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="streaming")])
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="streaming")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -181,7 +181,7 @@ class TestPurviewPolicyMiddleware:
|
||||
side_effect=PurviewPaymentRequiredError("Payment required"),
|
||||
):
|
||||
|
||||
async def mock_next(_: AgentContext) -> None:
|
||||
async def mock_next() -> None:
|
||||
raise AssertionError("next should not be called")
|
||||
|
||||
with pytest.raises(PurviewPaymentRequiredError):
|
||||
@@ -206,8 +206,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="OK")])
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="OK")])
|
||||
|
||||
with pytest.raises(PurviewPaymentRequiredError):
|
||||
await middleware.process(context, mock_next)
|
||||
@@ -231,8 +231,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="OK")])
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="OK")])
|
||||
|
||||
with pytest.raises(ValueError, match="Post-check blew up"):
|
||||
await middleware.process(context, mock_next)
|
||||
@@ -250,8 +250,8 @@ class TestPurviewPolicyMiddleware:
|
||||
middleware._processor, "process_messages", side_effect=Exception("Pre-check error")
|
||||
) as mock_process:
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -280,8 +280,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -306,8 +306,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx):
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
|
||||
async def mock_next():
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
|
||||
|
||||
# Should not raise, just log
|
||||
await middleware.process(context, mock_next)
|
||||
@@ -330,7 +330,7 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx):
|
||||
async def mock_next():
|
||||
pass
|
||||
|
||||
# Should raise the exception
|
||||
@@ -346,8 +346,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -364,8 +364,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -383,8 +383,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -399,8 +399,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
@@ -416,8 +416,8 @@ class TestPurviewPolicyMiddleware:
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
|
||||
async def mock_next(ctx: AgentContext) -> None:
|
||||
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
|
||||
async def mock_next() -> None:
|
||||
context.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
"**/autogen-migration/**",
|
||||
"**/semantic-kernel-migration/**",
|
||||
"**/demos/**",
|
||||
"**/agent_with_foundry_tracing.py"
|
||||
"**/agent_with_foundry_tracing.py",
|
||||
"**/azure_responses_client_with_foundry.py"
|
||||
],
|
||||
"typeCheckingMode": "off",
|
||||
"reportMissingImports": "error",
|
||||
|
||||
@@ -78,6 +78,7 @@ This directory contains samples demonstrating the capabilities of Microsoft Agen
|
||||
| [`getting_started/agents/azure_openai/azure_responses_client_image_analysis.py`](./getting_started/agents/azure_openai/azure_responses_client_image_analysis.py) | Azure OpenAI Responses Client with Image Analysis Example |
|
||||
| [`getting_started/agents/azure_openai/azure_responses_client_with_code_interpreter.py`](./getting_started/agents/azure_openai/azure_responses_client_with_code_interpreter.py) | Azure OpenAI Responses Client with Code Interpreter Example |
|
||||
| [`getting_started/agents/azure_openai/azure_responses_client_with_explicit_settings.py`](./getting_started/agents/azure_openai/azure_responses_client_with_explicit_settings.py) | Azure OpenAI Responses Client with Explicit Settings Example |
|
||||
| [`getting_started/agents/azure_openai/azure_responses_client_with_foundry.py`](./getting_started/agents/azure_openai/azure_responses_client_with_foundry.py) | Azure OpenAI Responses Client with Foundry Project Example |
|
||||
| [`getting_started/agents/azure_openai/azure_responses_client_with_function_tools.py`](./getting_started/agents/azure_openai/azure_responses_client_with_function_tools.py) | Azure OpenAI Responses Client with Function Tools Example |
|
||||
| [`getting_started/agents/azure_openai/azure_responses_client_with_hosted_mcp.py`](./getting_started/agents/azure_openai/azure_responses_client_with_hosted_mcp.py) | Azure OpenAI Responses Client with Hosted Model Context Protocol (MCP) Example |
|
||||
| [`getting_started/agents/azure_openai/azure_responses_client_with_local_mcp.py`](./getting_started/agents/azure_openai/azure_responses_client_with_local_mcp.py) | Azure OpenAI Responses Client with local Model Context Protocol (MCP) Example |
|
||||
|
||||
@@ -267,7 +267,7 @@ class TerminatingMiddleware(FunctionMiddleware):
|
||||
if self.should_terminate(context):
|
||||
context.result = "terminated by middleware"
|
||||
raise MiddlewareTermination # Exit function invocation loop
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
```
|
||||
|
||||
## Arguments Added/Altered at Each Layer
|
||||
@@ -347,7 +347,7 @@ class CachingMiddleware(FunctionMiddleware):
|
||||
return # Upstream post-processing still runs
|
||||
|
||||
# Option B: Call call_next, then return normally
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
self.cache[context.function.name] = context.result
|
||||
return # Normal completion
|
||||
```
|
||||
@@ -362,7 +362,7 @@ class BlockedFunctionMiddleware(FunctionMiddleware):
|
||||
if context.function.name in self.blocked_functions:
|
||||
context.result = "Function blocked by policy"
|
||||
raise MiddlewareTermination("Blocked") # Skips ALL post-processing
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
```
|
||||
|
||||
### 3. Raise Any Other Exception
|
||||
@@ -374,7 +374,7 @@ class ValidationMiddleware(FunctionMiddleware):
|
||||
async def process(self, context: FunctionInvocationContext, call_next):
|
||||
if not self.is_valid(context.arguments):
|
||||
raise ValueError("Invalid arguments") # Bubbles up to user
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
```
|
||||
|
||||
## `return` vs `raise MiddlewareTermination`
|
||||
@@ -385,7 +385,7 @@ The key difference is what happens to **upstream middleware's post-processing**:
|
||||
class MiddlewareA(AgentMiddleware):
|
||||
async def process(self, context, call_next):
|
||||
print("A: before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print("A: after") # Does this run?
|
||||
|
||||
class MiddlewareB(AgentMiddleware):
|
||||
@@ -410,7 +410,7 @@ With middleware registered as `[MiddlewareA, MiddlewareB]`:
|
||||
|
||||
## Calling `call_next()` or Not
|
||||
|
||||
The decision to call `call_next(context)` determines whether downstream middleware and the actual operation execute:
|
||||
The decision to call `call_next()` determines whether downstream middleware and the actual operation execute:
|
||||
|
||||
### Without calling `call_next()` - Skip downstream
|
||||
|
||||
@@ -430,7 +430,7 @@ async def process(self, context, call_next):
|
||||
```python
|
||||
async def process(self, context, call_next):
|
||||
# Pre-processing
|
||||
await call_next(context) # Execute downstream + actual operation
|
||||
await call_next() # Execute downstream + actual operation
|
||||
# Post-processing (context.result now contains real result)
|
||||
return
|
||||
```
|
||||
@@ -450,7 +450,7 @@ async def process(self, context, call_next):
|
||||
| `raise MiddlewareTermination` | Yes | ✅ | ✅ | ❌ No |
|
||||
| `raise OtherException` | Either | Depends | Depends | ❌ No (exception propagates) |
|
||||
|
||||
> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next(context)` without an explicit `return` statement achieves this pattern.
|
||||
> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next()` without an explicit `return` statement achieves this pattern.
|
||||
|
||||
## Streaming vs Non-Streaming
|
||||
|
||||
|
||||
@@ -20,13 +20,13 @@ multiple specialized agents, each focusing on specific tasks.
|
||||
|
||||
async def logging_middleware(
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""MiddlewareTypes that logs tool invocations to show the delegation flow."""
|
||||
print(f"[Calling tool: {context.function.name}]")
|
||||
print(f"[Request: {context.arguments}]")
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
print(f"[Response: {context.result}]")
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ This folder contains examples demonstrating different ways to create and use age
|
||||
| [`azure_responses_client_with_code_interpreter.py`](azure_responses_client_with_code_interpreter.py) | Shows how to use `AzureOpenAIResponsesClient.get_code_interpreter_tool()` with Azure agents to write and execute Python code. Includes helper methods for accessing code interpreter data from response chunks. |
|
||||
| [`azure_responses_client_with_explicit_settings.py`](azure_responses_client_with_explicit_settings.py) | Shows how to initialize an agent with a specific responses client, configuring settings explicitly including endpoint and deployment name. |
|
||||
| [`azure_responses_client_with_file_search.py`](azure_responses_client_with_file_search.py) | Demonstrates using `AzureOpenAIResponsesClient.get_file_search_tool()` with Azure OpenAI Responses Client for direct document-based question answering and information retrieval from vector stores. |
|
||||
| [`azure_responses_client_with_foundry.py`](azure_responses_client_with_foundry.py) | Shows how to create an agent using an Azure AI Foundry project endpoint instead of a direct Azure OpenAI endpoint. Requires the `azure-ai-projects` package. |
|
||||
| [`azure_responses_client_with_function_tools.py`](azure_responses_client_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). |
|
||||
| [`azure_responses_client_with_hosted_mcp.py`](azure_responses_client_with_hosted_mcp.py) | Shows how to integrate Azure OpenAI Responses Client with hosted Model Context Protocol (MCP) servers using `AzureOpenAIResponsesClient.get_mcp_tool()` for extended functionality. |
|
||||
| [`azure_responses_client_with_local_mcp.py`](azure_responses_client_with_local_mcp.py) | Shows how to integrate Azure OpenAI Responses Client with local Model Context Protocol (MCP) servers using MCPStreamableHTTPTool for extended functionality. |
|
||||
@@ -35,6 +36,9 @@ Make sure to set the following environment variables before running the examples
|
||||
- `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME`: The name of your Azure OpenAI chat model deployment
|
||||
- `AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME`: The name of your Azure OpenAI Responses deployment
|
||||
|
||||
For the Foundry project sample (`azure_responses_client_with_foundry.py`), also set:
|
||||
- `AZURE_AI_PROJECT_ENDPOINT`: Your Azure AI Foundry project endpoint
|
||||
|
||||
Optionally, you can set:
|
||||
- `AZURE_OPENAI_API_VERSION`: The API version to use (default is `2024-02-15-preview`)
|
||||
- `AZURE_OPENAI_API_KEY`: Your Azure OpenAI API key (if not using `AzureCliCredential`)
|
||||
|
||||
+113
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import tool
|
||||
from agent_framework.azure import AzureOpenAIResponsesClient
|
||||
from azure.identity import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Azure OpenAI Responses Client with Foundry Project Example
|
||||
|
||||
This sample demonstrates how to create an AzureOpenAIResponsesClient using an
|
||||
Azure AI Foundry project endpoint. Instead of providing an Azure OpenAI endpoint
|
||||
directly, you provide a Foundry project endpoint and the client is created via
|
||||
the Azure AI Foundry project SDK.
|
||||
|
||||
This requires:
|
||||
- The `azure-ai-projects` package to be installed.
|
||||
- The `AZURE_AI_PROJECT_ENDPOINT` environment variable set to your Foundry project endpoint.
|
||||
- The `AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME` environment variable set to the model deployment name.
|
||||
"""
|
||||
|
||||
load_dotenv() # Load environment variables from .env file if present
|
||||
|
||||
|
||||
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/getting_started/tools/function_tool_with_approval.py and samples/getting_started/tools/function_tool_with_approval_and_threads.py.
|
||||
@tool(approval_mode="never_require")
|
||||
def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
conditions = ["sunny", "cloudy", "rainy", "stormy"]
|
||||
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
|
||||
|
||||
|
||||
async def non_streaming_example() -> None:
|
||||
"""Example of non-streaming response (get the complete result at once)."""
|
||||
print("=== Non-streaming Response Example ===")
|
||||
|
||||
# 1. Create the AzureOpenAIResponsesClient using a Foundry project endpoint.
|
||||
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
|
||||
# authentication option.
|
||||
credential = AzureCliCredential()
|
||||
agent = AzureOpenAIResponsesClient(
|
||||
project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"],
|
||||
deployment_name=os.environ["AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME"],
|
||||
credential=credential,
|
||||
).as_agent(
|
||||
instructions="You are a helpful weather agent.",
|
||||
tools=get_weather,
|
||||
)
|
||||
|
||||
# 2. Run a query and print the result.
|
||||
query = "What's the weather like in Seattle?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Result: {result}\n")
|
||||
|
||||
|
||||
async def streaming_example() -> None:
|
||||
"""Example of streaming response (get results as they are generated)."""
|
||||
print("=== Streaming Response Example ===")
|
||||
|
||||
# 1. Create the AzureOpenAIResponsesClient using a Foundry project endpoint.
|
||||
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
|
||||
# authentication option.
|
||||
credential = AzureCliCredential()
|
||||
agent = AzureOpenAIResponsesClient(
|
||||
project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"],
|
||||
deployment_name=os.environ["AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME"],
|
||||
credential=credential,
|
||||
).as_agent(
|
||||
instructions="You are a helpful weather agent.",
|
||||
tools=get_weather,
|
||||
)
|
||||
|
||||
# 2. Stream the response and print each chunk as it arrives.
|
||||
query = "What's the weather like in Portland?"
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
async for chunk in agent.run(query, stream=True):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
print("\n")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("=== Azure OpenAI Responses Client with Foundry Project Example ===")
|
||||
|
||||
await non_streaming_example()
|
||||
await streaming_example()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== Azure OpenAI Responses Client with Foundry Project Example ===
|
||||
=== Non-streaming Response Example ===
|
||||
User: What's the weather like in Seattle?
|
||||
Result: The weather in Seattle is cloudy with a high of 18°C.
|
||||
|
||||
=== Streaming Response Example ===
|
||||
User: What's the weather like in Portland?
|
||||
Agent: The weather in Portland is sunny with a high of 25°C.
|
||||
"""
|
||||
@@ -29,7 +29,7 @@ response generation, showing both streaming and non-streaming responses.
|
||||
@chat_middleware
|
||||
async def security_and_override_middleware(
|
||||
context: ChatContext,
|
||||
call_next: Callable[[ChatContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function-based middleware that implements security filtering and response override."""
|
||||
print("[SecurityMiddleware] Processing input...")
|
||||
@@ -60,7 +60,7 @@ async def security_and_override_middleware(
|
||||
raise MiddlewareTermination(result=context.result)
|
||||
|
||||
# Continue to next middleware or AI execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
print("[SecurityMiddleware] Response generated.")
|
||||
print(type(context.result))
|
||||
|
||||
+2
-2
@@ -19,13 +19,13 @@ multiple specialized agents, each focusing on specific tasks.
|
||||
|
||||
async def logging_middleware(
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""MiddlewareTypes that logs tool invocations to show the delegation flow."""
|
||||
print(f"[Calling tool: {context.function.name}]")
|
||||
print(f"[Request: {context.arguments}]")
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
print(f"[Response: {context.result}]")
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ def cleanup_resources():
|
||||
@chat_middleware
|
||||
async def security_filter_middleware(
|
||||
context: ChatContext,
|
||||
call_next: Callable[[ChatContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Chat middleware that blocks requests containing sensitive information."""
|
||||
blocked_terms = ["password", "secret", "api_key", "token"]
|
||||
@@ -80,13 +80,13 @@ async def security_filter_middleware(
|
||||
|
||||
raise MiddlewareTermination(result=context.result)
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
@function_middleware
|
||||
async def atlantis_location_filter_middleware(
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function middleware that blocks weather requests for Atlantis."""
|
||||
# Check if location parameter is "atlantis"
|
||||
@@ -98,7 +98,7 @@ async def atlantis_location_filter_middleware(
|
||||
)
|
||||
raise MiddlewareTermination(result=context.result)
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/getting_started/tools/function_tool_with_approval.py and samples/getting_started/tools/function_tool_with_approval_and_threads.py.
|
||||
|
||||
@@ -68,7 +68,7 @@ def get_weather(
|
||||
class SecurityAgentMiddleware(AgentMiddleware):
|
||||
"""Agent-level security middleware that validates all requests."""
|
||||
|
||||
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
print("[SecurityMiddleware] Checking security for all requests...")
|
||||
|
||||
# Check for security violations in the last user message
|
||||
@@ -81,18 +81,18 @@ class SecurityAgentMiddleware(AgentMiddleware):
|
||||
|
||||
print("[SecurityMiddleware] Security check passed.")
|
||||
context.metadata["security_validated"] = True
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
async def performance_monitor_middleware(
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Agent-level performance monitoring for all runs."""
|
||||
print("[PerformanceMonitor] Starting performance monitoring...")
|
||||
start_time = time.time()
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
@@ -104,7 +104,7 @@ async def performance_monitor_middleware(
|
||||
class HighPriorityMiddleware(AgentMiddleware):
|
||||
"""Run-level middleware for high priority requests."""
|
||||
|
||||
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
print("[HighPriority] Processing high priority request with expedited handling...")
|
||||
|
||||
# Read metadata set by agent-level middleware
|
||||
@@ -115,13 +115,13 @@ class HighPriorityMiddleware(AgentMiddleware):
|
||||
context.metadata["priority"] = "high"
|
||||
context.metadata["expedited"] = True
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print("[HighPriority] High priority processing completed")
|
||||
|
||||
|
||||
async def debugging_middleware(
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Run-level debugging middleware for troubleshooting specific runs."""
|
||||
print("[Debug] Debug mode enabled for this run")
|
||||
@@ -134,7 +134,7 @@ async def debugging_middleware(
|
||||
|
||||
context.metadata["debug_enabled"] = True
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
print("[Debug] Debug information collected")
|
||||
|
||||
@@ -145,7 +145,7 @@ class CachingMiddleware(AgentMiddleware):
|
||||
def __init__(self) -> None:
|
||||
self.cache: dict[str, AgentResponse] = {}
|
||||
|
||||
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Create a simple cache key from the last message
|
||||
last_message = context.messages[-1] if context.messages else None
|
||||
cache_key: str = last_message.text if last_message and last_message.text else "no_message"
|
||||
@@ -158,7 +158,7 @@ class CachingMiddleware(AgentMiddleware):
|
||||
print(f"[Cache] Cache MISS for: '{cache_key[:30]}...'")
|
||||
context.metadata["cache_key"] = cache_key
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Cache the result if we have one
|
||||
if context.result:
|
||||
@@ -168,14 +168,14 @@ class CachingMiddleware(AgentMiddleware):
|
||||
|
||||
async def function_logging_middleware(
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function middleware that logs all function calls."""
|
||||
function_name = context.function.name
|
||||
args = context.arguments
|
||||
print(f"[FunctionLog] Calling function: {function_name} with args: {args}")
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
print(f"[FunctionLog] Function {function_name} completed")
|
||||
|
||||
@@ -275,7 +275,7 @@ async def main() -> None:
|
||||
query = "What's the secret weather password for Berlin?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text if result.text else 'Request was blocked by security middleware'}")
|
||||
print(f"Agent: {result.text if result and result.text else 'Request was blocked by security middleware'}")
|
||||
print()
|
||||
|
||||
# Run 7: Normal query again (no run-level middleware interference)
|
||||
|
||||
@@ -57,7 +57,7 @@ class InputObserverMiddleware(ChatMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: ChatContext,
|
||||
call_next: Callable[[ChatContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Observe and modify input messages before they are sent to AI."""
|
||||
print("[InputObserverMiddleware] Observing input messages:")
|
||||
@@ -91,7 +91,7 @@ class InputObserverMiddleware(ChatMiddleware):
|
||||
context.messages[:] = modified_messages
|
||||
|
||||
# Continue to next middleware or AI execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Observe that processing is complete
|
||||
print("[InputObserverMiddleware] Processing completed")
|
||||
@@ -100,7 +100,7 @@ class InputObserverMiddleware(ChatMiddleware):
|
||||
@chat_middleware
|
||||
async def security_and_override_middleware(
|
||||
context: ChatContext,
|
||||
call_next: Callable[[ChatContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function-based middleware that implements security filtering and response override."""
|
||||
print("[SecurityMiddleware] Processing input...")
|
||||
@@ -131,7 +131,7 @@ async def security_and_override_middleware(
|
||||
raise MiddlewareTermination
|
||||
|
||||
# Continue to next middleware or AI execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
async def class_based_chat_middleware() -> None:
|
||||
|
||||
@@ -50,7 +50,7 @@ class SecurityAgentMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
# Check for potential security violations in the query
|
||||
# Look at the last user message
|
||||
@@ -67,7 +67,7 @@ class SecurityAgentMiddleware(AgentMiddleware):
|
||||
return
|
||||
|
||||
print("[SecurityAgentMiddleware] Security check passed.")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
class LoggingFunctionMiddleware(FunctionMiddleware):
|
||||
@@ -76,14 +76,14 @@ class LoggingFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
function_name = context.function.name
|
||||
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
|
||||
@@ -53,7 +53,7 @@ def get_current_time() -> str:
|
||||
async def simple_agent_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
|
||||
"""Agent middleware that runs before and after agent execution."""
|
||||
print("[Agent MiddlewareTypes] Before agent execution")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print("[Agent MiddlewareTypes] After agent execution")
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ async def simple_agent_middleware(context, call_next): # type: ignore - paramet
|
||||
async def simple_function_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
|
||||
"""Function middleware that runs before and after function calls."""
|
||||
print(f"[Function MiddlewareTypes] Before calling: {context.function.name}") # type: ignore
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print(f"[Function MiddlewareTypes] After calling: {context.function.name}") # type: ignore
|
||||
|
||||
|
||||
|
||||
@@ -35,13 +35,13 @@ def unstable_data_service(
|
||||
|
||||
|
||||
async def exception_handling_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
function_name = context.function.name
|
||||
|
||||
try:
|
||||
print(f"[ExceptionHandlingMiddleware] Executing function: {function_name}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print(f"[ExceptionHandlingMiddleware] Function {function_name} completed successfully.")
|
||||
except TimeoutError as e:
|
||||
print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}")
|
||||
|
||||
@@ -43,7 +43,7 @@ def get_weather(
|
||||
|
||||
async def security_agent_middleware(
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Agent middleware that checks for security violations."""
|
||||
# Check for potential security violations in the query
|
||||
@@ -57,12 +57,12 @@ async def security_agent_middleware(
|
||||
return
|
||||
|
||||
print("[SecurityAgentMiddleware] Security check passed.")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
async def logging_function_middleware(
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function middleware that logs function calls."""
|
||||
function_name = context.function.name
|
||||
@@ -70,7 +70,7 @@ async def logging_function_middleware(
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
@@ -105,7 +105,7 @@ async def main() -> None:
|
||||
query = "What's the secret weather password?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text if result.text else 'No response'}\n")
|
||||
print(f"Agent: {result.text if result and result.text else 'No response'}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -49,7 +49,7 @@ class PreTerminationMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
# Check if the user message contains any blocked words
|
||||
last_message = context.messages[-1] if context.messages else None
|
||||
@@ -75,7 +75,7 @@ class PreTerminationMiddleware(AgentMiddleware):
|
||||
# Terminate to prevent further processing
|
||||
raise MiddlewareTermination(result=context.result)
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
class PostTerminationMiddleware(AgentMiddleware):
|
||||
@@ -88,7 +88,7 @@ class PostTerminationMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})")
|
||||
|
||||
@@ -101,7 +101,7 @@ class PostTerminationMiddleware(AgentMiddleware):
|
||||
raise MiddlewareTermination
|
||||
|
||||
# Allow the agent to process normally
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Increment response count after processing
|
||||
self.response_count += 1
|
||||
@@ -158,14 +158,14 @@ async def post_termination_middleware() -> None:
|
||||
query = "What about the weather in London?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text if result.text else 'No response (terminated)'}")
|
||||
print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}")
|
||||
|
||||
# Third run (should also be terminated)
|
||||
print("\n3. Third run (should also be terminated):")
|
||||
query = "And New York?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text if result.text else 'No response (terminated)'}")
|
||||
print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
|
||||
@@ -49,11 +49,11 @@ def get_weather(
|
||||
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
|
||||
|
||||
|
||||
async def weather_override_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def weather_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
"""Chat middleware that overrides weather results for both streaming and non-streaming cases."""
|
||||
|
||||
# Let the original agent execution complete first
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Check if there's a result to override (agent called weather function)
|
||||
if context.result is not None:
|
||||
@@ -84,9 +84,9 @@ async def weather_override_middleware(context: ChatContext, call_next: Callable[
|
||||
context.result = ChatResponse(messages=[Message(role=Role.ASSISTANT, text=custom_message)])
|
||||
|
||||
|
||||
async def validate_weather_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def validate_weather_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
"""Chat middleware that simulates result validation for both streaming and non-streaming cases."""
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
validation_note = "Validation: weather data verified."
|
||||
|
||||
@@ -104,9 +104,9 @@ async def validate_weather_middleware(context: ChatContext, call_next: Callable[
|
||||
context.result.messages.append(Message(role=Role.ASSISTANT, text=validation_note))
|
||||
|
||||
|
||||
async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
"""Agent middleware that validates chat middleware effects and cleans the result."""
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
if context.result is None:
|
||||
return
|
||||
|
||||
@@ -54,7 +54,7 @@ class SessionContextContainer:
|
||||
async def inject_context_middleware(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""MiddlewareTypes that extracts runtime context from kwargs and stores in container.
|
||||
|
||||
@@ -74,7 +74,7 @@ class SessionContextContainer:
|
||||
print(f" - Session Metadata Keys: {list(self.session_metadata.keys())}")
|
||||
|
||||
# Continue to tool execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
# Create a container instance that will be shared via closure
|
||||
@@ -278,19 +278,19 @@ async def pattern_2_hierarchical_with_kwargs_propagation() -> None:
|
||||
|
||||
@function_middleware
|
||||
async def email_kwargs_tracker(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
email_agent_kwargs.update(context.kwargs)
|
||||
print(f"[EmailAgent] Received runtime context: {list(context.kwargs.keys())}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
@function_middleware
|
||||
async def sms_kwargs_tracker(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
sms_agent_kwargs.update(context.kwargs)
|
||||
print(f"[SMSAgent] Received runtime context: {list(context.kwargs.keys())}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
client = OpenAIChatClient(model_id="gpt-4o-mini")
|
||||
|
||||
@@ -359,7 +359,7 @@ class AuthContextMiddleware:
|
||||
self.validated_tokens: list[str] = []
|
||||
|
||||
async def validate_and_track(
|
||||
self, context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
self, context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Validate API token and track usage."""
|
||||
api_token = context.kwargs.get("api_token")
|
||||
@@ -375,7 +375,7 @@ class AuthContextMiddleware:
|
||||
else:
|
||||
print("[AuthMiddleware] No API token provided")
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
|
||||
@@ -57,7 +57,7 @@ class MiddlewareContainer:
|
||||
async def call_counter_middleware(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""First middleware: increments call count in shared state."""
|
||||
# Increment the shared call count
|
||||
@@ -66,18 +66,18 @@ class MiddlewareContainer:
|
||||
print(f"[CallCounter] This is function call #{self.call_count}")
|
||||
|
||||
# Call the next middleware/function
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
async def result_enhancer_middleware(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Second middleware: uses shared call count to enhance function results."""
|
||||
print(f"[ResultEnhancer] Current total calls so far: {self.call_count}")
|
||||
|
||||
# Call the next middleware/function
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# After function execution, enhance the result using shared state
|
||||
if context.result:
|
||||
|
||||
@@ -46,7 +46,7 @@ def get_weather(
|
||||
|
||||
async def thread_tracking_middleware(
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""MiddlewareTypes that tracks and logs thread behavior across runs."""
|
||||
thread_messages = []
|
||||
@@ -57,7 +57,7 @@ async def thread_tracking_middleware(
|
||||
print(f"[MiddlewareTypes pre-execution] Thread history messages: {len(thread_messages)}")
|
||||
|
||||
# Call call_next to execute the agent
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Check thread state after agent execution
|
||||
updated_thread_messages = []
|
||||
|
||||
Generated
+2
-2
@@ -209,7 +209,6 @@ dependencies = [
|
||||
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "azure-ai-agents", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "azure-ai-projects", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
@@ -217,7 +216,6 @@ requires-dist = [
|
||||
{ name = "agent-framework-core", editable = "packages/core" },
|
||||
{ name = "aiohttp" },
|
||||
{ name = "azure-ai-agents", specifier = "==1.2.0b5" },
|
||||
{ name = "azure-ai-projects", specifier = ">=2.0.0b3" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -324,6 +322,7 @@ name = "agent-framework-core"
|
||||
version = "1.0.0b260210"
|
||||
source = { editable = "packages/core" }
|
||||
dependencies = [
|
||||
{ name = "azure-ai-projects", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "azure-identity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "mcp", extra = ["ws"], marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
@@ -378,6 +377,7 @@ requires-dist = [
|
||||
{ name = "agent-framework-orchestrations", marker = "extra == 'all'", editable = "packages/orchestrations" },
|
||||
{ name = "agent-framework-purview", marker = "extra == 'all'", editable = "packages/purview" },
|
||||
{ name = "agent-framework-redis", marker = "extra == 'all'", editable = "packages/redis" },
|
||||
{ name = "azure-ai-projects", specifier = ">=2.0.0b3" },
|
||||
{ name = "azure-identity", specifier = ">=1,<2" },
|
||||
{ name = "mcp", extras = ["ws"], specifier = ">=1.24.0,<2" },
|
||||
{ name = "openai", specifier = ">=1.99.0" },
|
||||
|
||||
Reference in New Issue
Block a user