Merge branch 'main' into feature-session-statebag

This commit is contained in:
westey
2026-02-11 16:27:36 +00:00
committed by GitHub
Unverified
42 changed files with 801 additions and 611 deletions
+1
View File
@@ -25,6 +25,7 @@
"src\\Microsoft.Agents.AI.Purview\\Microsoft.Agents.AI.Purview.csproj",
"src\\Microsoft.Agents.AI.Workflows.Declarative.AzureAI\\Microsoft.Agents.AI.Workflows.Declarative.AzureAI.csproj",
"src\\Microsoft.Agents.AI.Workflows.Declarative\\Microsoft.Agents.AI.Workflows.Declarative.csproj",
"src\\Microsoft.Agents.AI.Workflows.Generators\\Microsoft.Agents.AI.Workflows.Generators.csproj",
"src\\Microsoft.Agents.AI.Workflows\\Microsoft.Agents.AI.Workflows.csproj",
"src\\Microsoft.Agents.AI\\Microsoft.Agents.AI.csproj"
]
-1
View File
@@ -24,7 +24,6 @@ classifiers = [
]
dependencies = [
"agent-framework-core>=1.0.0b260210",
"azure-ai-projects >= 2.0.0b3",
"azure-ai-agents == 1.2.0b5",
"aiohttp",
]
+1 -1
View File
@@ -119,7 +119,7 @@ from agent_framework import Agent, AgentMiddleware, AgentContext
class LoggingMiddleware(AgentMiddleware):
async def process(self, context: AgentContext, call_next) -> None:
print(f"Input: {context.messages}")
await call_next(context)
await call_next()
print(f"Output: {context.result}")
agent = Agent(..., middleware=[LoggingMiddleware()])
@@ -145,7 +145,7 @@ class AgentContext:
context.metadata["start_time"] = time.time()
# Continue execution
await call_next(context)
await call_next()
# Access result after execution
print(f"Result: {context.result}")
@@ -229,7 +229,7 @@ class FunctionInvocationContext:
raise MiddlewareTermination("Validation failed")
# Continue execution
await call_next(context)
await call_next()
"""
def __init__(
@@ -293,7 +293,7 @@ class ChatContext:
context.metadata["input_tokens"] = self.count_tokens(context.messages)
# Continue execution
await call_next(context)
await call_next()
# Access result and count output tokens
if context.result:
@@ -365,7 +365,7 @@ class AgentMiddleware(ABC):
async def process(self, context: AgentContext, call_next):
for attempt in range(self.max_retries):
await call_next(context)
await call_next()
if context.result and not context.result.is_error:
break
print(f"Retry {attempt + 1}/{self.max_retries}")
@@ -379,7 +379,7 @@ class AgentMiddleware(ABC):
async def process(
self,
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Process an agent invocation.
@@ -431,7 +431,7 @@ class FunctionMiddleware(ABC):
raise MiddlewareTermination()
# Execute function
await call_next(context)
await call_next()
# Cache result
if context.result:
@@ -446,7 +446,7 @@ class FunctionMiddleware(ABC):
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Process a function invocation.
@@ -493,7 +493,7 @@ class ChatMiddleware(ABC):
context.messages.insert(0, Message(role="system", text=self.system_prompt))
# Continue execution
await call_next(context)
await call_next()
# Use with an agent
@@ -508,7 +508,7 @@ class ChatMiddleware(ABC):
async def process(
self,
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Process a chat client request.
@@ -531,15 +531,13 @@ class ChatMiddleware(ABC):
# Pure function type definitions for convenience
AgentMiddlewareCallable = Callable[[AgentContext, Callable[[AgentContext], Awaitable[None]]], Awaitable[None]]
AgentMiddlewareCallable = Callable[[AgentContext, Callable[[], Awaitable[None]]], Awaitable[None]]
AgentMiddlewareTypes: TypeAlias = AgentMiddleware | AgentMiddlewareCallable
FunctionMiddlewareCallable = Callable[
[FunctionInvocationContext, Callable[[FunctionInvocationContext], Awaitable[None]]], Awaitable[None]
]
FunctionMiddlewareCallable = Callable[[FunctionInvocationContext, Callable[[], Awaitable[None]]], Awaitable[None]]
FunctionMiddlewareTypes: TypeAlias = FunctionMiddleware | FunctionMiddlewareCallable
ChatMiddlewareCallable = Callable[[ChatContext, Callable[[ChatContext], Awaitable[None]]], Awaitable[None]]
ChatMiddlewareCallable = Callable[[ChatContext, Callable[[], Awaitable[None]]], Awaitable[None]]
ChatMiddlewareTypes: TypeAlias = ChatMiddleware | ChatMiddlewareCallable
ChatAndFunctionMiddlewareTypes: TypeAlias = (
@@ -578,7 +576,7 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable:
@agent_middleware
async def logging_middleware(context: AgentContext, call_next):
print(f"Before: {context.agent.name}")
await call_next(context)
await call_next()
print(f"After: {context.result}")
@@ -611,7 +609,7 @@ def function_middleware(func: FunctionMiddlewareCallable) -> FunctionMiddlewareC
@function_middleware
async def logging_middleware(context: FunctionInvocationContext, call_next):
print(f"Calling: {context.function.name}")
await call_next(context)
await call_next()
print(f"Result: {context.result}")
@@ -644,7 +642,7 @@ def chat_middleware(func: ChatMiddlewareCallable) -> ChatMiddlewareCallable:
@chat_middleware
async def logging_middleware(context: ChatContext, call_next):
print(f"Messages: {len(context.messages)}")
await call_next(context)
await call_next()
print(f"Response: {context.result}")
@@ -666,10 +664,10 @@ class MiddlewareWrapper(Generic[ContextT]):
ContextT: The type of context object this middleware operates on.
"""
def __init__(self, func: Callable[[ContextT, Callable[[ContextT], Awaitable[None]]], Awaitable[None]]) -> None:
def __init__(self, func: Callable[[ContextT, Callable[[], Awaitable[None]]], Awaitable[None]]) -> None:
self.func = func
async def process(self, context: ContextT, call_next: Callable[[ContextT], Awaitable[None]]) -> None:
async def process(self, context: ContextT, call_next: Callable[[], Awaitable[None]]) -> None:
await self.func(context, call_next)
@@ -772,25 +770,25 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
context.result = await context.result
return context.result
def create_next_handler(index: int) -> Callable[[AgentContext], Awaitable[None]]:
def create_next_handler(index: int) -> Callable[[], Awaitable[None]]:
if index >= len(self._middleware):
async def final_wrapper(c: AgentContext) -> None:
c.result = final_handler(c) # type: ignore[assignment]
if inspect.isawaitable(c.result):
c.result = await c.result
async def final_wrapper() -> None:
context.result = final_handler(context) # type: ignore[assignment]
if inspect.isawaitable(context.result):
context.result = await context.result
return final_wrapper
async def current_handler(c: AgentContext) -> None:
async def current_handler() -> None:
# MiddlewareTermination bubbles up to execute() to skip post-processing
await self._middleware[index].process(c, create_next_handler(index + 1))
await self._middleware[index].process(context, create_next_handler(index + 1))
return current_handler
first_handler = create_next_handler(0)
with contextlib.suppress(MiddlewareTermination):
await first_handler(context)
await first_handler()
if context.result and isinstance(context.result, ResponseStream):
for hook in context.stream_transform_hooks:
@@ -847,25 +845,25 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline):
if not self._middleware:
return await final_handler(context)
def create_next_handler(index: int) -> Callable[[FunctionInvocationContext], Awaitable[None]]:
def create_next_handler(index: int) -> Callable[[], Awaitable[None]]:
if index >= len(self._middleware):
async def final_wrapper(c: FunctionInvocationContext) -> None:
c.result = final_handler(c)
if inspect.isawaitable(c.result):
c.result = await c.result
async def final_wrapper() -> None:
context.result = final_handler(context)
if inspect.isawaitable(context.result):
context.result = await context.result
return final_wrapper
async def current_handler(c: FunctionInvocationContext) -> None:
async def current_handler() -> None:
# MiddlewareTermination bubbles up to execute() to skip post-processing
await self._middleware[index].process(c, create_next_handler(index + 1))
await self._middleware[index].process(context, create_next_handler(index + 1))
return current_handler
first_handler = create_next_handler(0)
# Don't suppress MiddlewareTermination - let it propagate to signal loop termination
await first_handler(context)
await first_handler()
return context.result
@@ -922,25 +920,25 @@ class ChatMiddlewarePipeline(BaseMiddlewarePipeline):
raise ValueError("Streaming agent middleware requires a ResponseStream result.")
return context.result
def create_next_handler(index: int) -> Callable[[ChatContext], Awaitable[None]]:
def create_next_handler(index: int) -> Callable[[], Awaitable[None]]:
if index >= len(self._middleware):
async def final_wrapper(c: ChatContext) -> None:
c.result = final_handler(c) # type: ignore[assignment]
if inspect.isawaitable(c.result):
c.result = await c.result
async def final_wrapper() -> None:
context.result = final_handler(context) # type: ignore[assignment]
if inspect.isawaitable(context.result):
context.result = await context.result
return final_wrapper
async def current_handler(c: ChatContext) -> None:
async def current_handler() -> None:
# MiddlewareTermination bubbles up to execute() to skip post-processing
await self._middleware[index].process(c, create_next_handler(index + 1))
await self._middleware[index].process(context, create_next_handler(index + 1))
return current_handler
first_handler = create_next_handler(0)
with contextlib.suppress(MiddlewareTermination):
await first_handler(context)
await first_handler()
if context.result and isinstance(context.result, ResponseStream):
for hook in context.stream_transform_hooks:
@@ -7,11 +7,14 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic
from urllib.parse import urljoin
from azure.ai.projects.aio import AIProjectClient
from azure.core.credentials import TokenCredential
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
from openai import AsyncOpenAI
from openai.lib.azure import AsyncAzureADTokenProvider
from pydantic import ValidationError
from .._middleware import ChatMiddlewareLayer
from .._telemetry import AGENT_FRAMEWORK_USER_AGENT
from .._tools import FunctionInvocationConfiguration, FunctionInvocationLayer
from ..exceptions import ServiceInitializationError
from ..observability import ChatTelemetryLayer
@@ -72,7 +75,9 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
token_endpoint: str | None = None,
credential: TokenCredential | None = None,
default_headers: Mapping[str, str] | None = None,
async_client: AsyncAzureOpenAI | None = None,
async_client: AsyncOpenAI | None = None,
project_client: Any | None = None,
project_endpoint: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
instruction_role: str | None = None,
@@ -82,6 +87,14 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
) -> None:
"""Initialize an Azure OpenAI Responses client.
The client can be created in two ways:
1. **Direct Azure OpenAI** (default): Provide endpoint, api_key, or credential
to connect directly to an Azure OpenAI deployment.
2. **Foundry project endpoint**: Provide a ``project_client`` or ``project_endpoint``
(with ``credential``) to create the client via an Azure AI Foundry project.
This requires the ``azure-ai-projects`` package to be installed.
Keyword Args:
api_key: The API key. If provided, will override the value in the env vars or .env file.
Can also be set via environment variable AZURE_OPENAI_API_KEY.
@@ -105,6 +118,12 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
default_headers: The default headers mapping of string keys to
string values for HTTP requests.
async_client: An existing client to use.
project_client: An existing ``AIProjectClient`` (from ``azure.ai.projects.aio``) to use.
The OpenAI client will be obtained via ``project_client.get_openai_client()``.
Requires the ``azure-ai-projects`` package.
project_endpoint: The Azure AI Foundry project endpoint URL.
When provided with ``credential``, an ``AIProjectClient`` will be created
and used to obtain the OpenAI client. Requires the ``azure-ai-projects`` package.
env_file_path: Use the environment settings file as a fallback to using env vars.
env_file_encoding: The encoding of the environment settings file, defaults to 'utf-8'.
instruction_role: The role to use for 'instruction' messages, for example, summarization
@@ -132,6 +151,27 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
# Or loading from a .env file
client = AzureOpenAIResponsesClient(env_file_path="path/to/.env")
# Using a Foundry project endpoint
from azure.identity import DefaultAzureCredential
client = AzureOpenAIResponsesClient(
project_endpoint="https://your-project.services.ai.azure.com",
deployment_name="gpt-4o",
credential=DefaultAzureCredential(),
)
# Or using an existing AIProjectClient
from azure.ai.projects.aio import AIProjectClient
project_client = AIProjectClient(
endpoint="https://your-project.services.ai.azure.com",
credential=DefaultAzureCredential(),
)
client = AzureOpenAIResponsesClient(
project_client=project_client,
deployment_name="gpt-4o",
)
# Using custom ChatOptions with type safety:
from typing import TypedDict
from agent_framework.azure import AzureOpenAIResponsesOptions
@@ -146,6 +186,15 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
"""
# Accept ``model_id`` as an alias for ``deployment_name``.
# The walrus target must be parenthesized: ``:=`` binds looser than ``and``,
# so without the parentheses ``model_id`` would capture the boolean result of
# the whole ``and`` expression and ``deployment_name`` would become "True".
if (model_id := kwargs.pop("model_id", None)) and not deployment_name:
    deployment_name = str(model_id)
# Project client path: create OpenAI client from an Azure AI Foundry project
if async_client is None and (project_client is not None or project_endpoint is not None):
async_client = self._create_client_from_project(
project_client=project_client,
project_endpoint=project_endpoint,
credential=credential,
)
try:
azure_openai_settings = AzureOpenAISettings(
# pydantic settings will see if there is a value, if not, will try the env var or .env file
@@ -195,9 +244,48 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
function_invocation_configuration=function_invocation_configuration,
)
@staticmethod
def _create_client_from_project(
    *,
    project_client: AIProjectClient | None,
    project_endpoint: str | None,
    credential: TokenCredential | None,
) -> AsyncOpenAI:
    """Create an AsyncOpenAI client from an Azure AI Foundry project.

    An existing ``project_client`` takes precedence. Otherwise a new
    ``AIProjectClient`` is built from ``project_endpoint`` and ``credential``.

    Args:
        project_client: An existing AIProjectClient to use.
        project_endpoint: The Azure AI Foundry project endpoint URL.
        credential: Azure credential for authentication.

    Returns:
        The client returned by the project client's ``get_openai_client()``.

    Raises:
        ServiceInitializationError: If ``project_client`` is not provided and
            either ``project_endpoint`` or ``credential`` is missing.
    """
    # A caller-supplied project client wins; endpoint and credential are ignored.
    if project_client is not None:
        return project_client.get_openai_client()

    # Treat an empty endpoint string the same as a missing one.
    if not project_endpoint:
        raise ServiceInitializationError(
            "Azure AI project endpoint is required when project_client is not provided."
        )
    if credential is None:
        raise ServiceInitializationError(
            "Azure credential is required when using project_endpoint without a project_client."
        )

    project_client = AIProjectClient(
        endpoint=project_endpoint,
        credential=credential,  # type: ignore[arg-type]
        user_agent=AGENT_FRAMEWORK_USER_AGENT,
    )
    return project_client.get_openai_client()
@override
def _check_model_presence(self, run_options: dict[str, Any]) -> None:
if not run_options.get("model"):
def _check_model_presence(self, options: dict[str, Any]) -> None:
if not options.get("model"):
if not self.model_id:
raise ValueError("deployment_name must be a non-empty string")
run_options["model"] = self.model_id
options["model"] = self.model_id
@@ -9,6 +9,7 @@ from copy import copy
from typing import Any, ClassVar, Final
from azure.core.credentials import TokenCredential
from openai import AsyncOpenAI
from openai.lib.azure import AsyncAzureOpenAI
from pydantic import SecretStr, model_validator
@@ -162,7 +163,7 @@ class AzureOpenAIConfigMixin(OpenAIBase):
token_endpoint: str | None = None,
credential: TokenCredential | None = None,
default_headers: Mapping[str, str] | None = None,
client: AsyncAzureOpenAI | None = None,
client: AsyncOpenAI | None = None,
instruction_role: str | None = None,
**kwargs: Any,
) -> None:
@@ -901,6 +901,7 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
"""Prepare a chat message for the OpenAI Responses API format."""
all_messages: list[dict[str, Any]] = []
args: dict[str, Any] = {
"type": "message",
"role": message.role,
}
for content in message.contents:
@@ -911,16 +912,22 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
case "function_result":
new_args: dict[str, Any] = {}
new_args.update(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore[arg-type]
all_messages.append(new_args)
if new_args:
all_messages.append(new_args)
case "function_call":
function_call = self._prepare_content_for_openai(message.role, content, call_id_to_id) # type: ignore[arg-type]
all_messages.append(function_call) # type: ignore
if function_call:
all_messages.append(function_call) # type: ignore
case "function_approval_response" | "function_approval_request":
all_messages.append(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore
prepared = self._prepare_content_for_openai(Role(message.role), content, call_id_to_id)
if prepared:
all_messages.append(prepared) # type: ignore
case _:
if "content" not in args:
args["content"] = []
args["content"].append(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore
prepared_content = self._prepare_content_for_openai(message.role, content, call_id_to_id) # type: ignore
if prepared_content:
if "content" not in args:
args["content"] = []
args["content"].append(prepared_content) # type: ignore
if "content" in args or "tool_calls" in args:
all_messages.append(args)
return all_messages
+1
View File
@@ -34,6 +34,7 @@ dependencies = [
# connectors and functions
"openai>=1.99.0",
"azure-identity>=1,<2",
"azure-ai-projects >= 2.0.0b3",
"mcp[ws]>=1.24.0,<2",
"packaging>=24.1",
]
@@ -3,6 +3,7 @@
import json
import os
from typing import Annotated, Any
from unittest.mock import MagicMock
import pytest
from azure.identity import AzureCliCredential
@@ -115,6 +116,119 @@ def test_init_with_empty_model_id(azure_openai_unit_test_env: dict[str, str]) ->
)
def test_init_with_project_client(azure_openai_unit_test_env: dict[str, str]) -> None:
    """Test initialization with an existing AIProjectClient.

    The project-client helper is deliberately left unpatched: a pre-built
    project client needs no network access, so the test can check that the
    constructor really obtains its OpenAI client from ``get_openai_client()``.
    """
    from openai import AsyncOpenAI

    # Mock AIProjectClient whose get_openai_client() yields a mock AsyncOpenAI client.
    mock_openai_client = MagicMock(spec=AsyncOpenAI)
    mock_openai_client.default_headers = {}
    mock_project_client = MagicMock()
    mock_project_client.get_openai_client.return_value = mock_openai_client

    azure_responses_client = AzureOpenAIResponsesClient(
        project_client=mock_project_client,
        deployment_name="gpt-4o",
    )

    mock_project_client.get_openai_client.assert_called_once()
    assert azure_responses_client.model_id == "gpt-4o"
    assert azure_responses_client.client is mock_openai_client
    assert isinstance(azure_responses_client, SupportsChatGetResponse)
def test_init_with_project_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None:
    """Verify construction from a Foundry project endpoint plus a credential."""
    from unittest.mock import patch

    from openai import AsyncOpenAI

    # Stand-in for the client the project helper would hand back.
    fake_openai = MagicMock(spec=AsyncOpenAI)
    fake_openai.default_headers = {}

    # Patch the helper so no real AIProjectClient is created from the endpoint.
    helper_path = "agent_framework.azure._responses_client.AzureOpenAIResponsesClient._create_client_from_project"
    with patch(helper_path, return_value=fake_openai):
        client_under_test = AzureOpenAIResponsesClient(
            project_endpoint="https://test-project.services.ai.azure.com",
            deployment_name="gpt-4o",
            credential=AzureCliCredential(),
        )

    assert client_under_test.model_id == "gpt-4o"
    assert client_under_test.client is fake_openai
    assert isinstance(client_under_test, SupportsChatGetResponse)
def test_create_client_from_project_with_project_client() -> None:
    """A supplied project client is asked once for its OpenAI client, which is returned as-is."""
    from openai import AsyncOpenAI

    expected_client = MagicMock(spec=AsyncOpenAI)
    project = MagicMock()
    project.get_openai_client.return_value = expected_client

    actual_client = AzureOpenAIResponsesClient._create_client_from_project(
        project_client=project,
        project_endpoint=None,
        credential=None,
    )

    assert actual_client is expected_client
    project.get_openai_client.assert_called_once()
def test_create_client_from_project_with_endpoint() -> None:
    """With only an endpoint and credential, a new AIProjectClient is built and queried once."""
    from unittest.mock import patch

    from openai import AsyncOpenAI

    expected_client = MagicMock(spec=AsyncOpenAI)
    fake_credential = MagicMock()

    # Replace the AIProjectClient class as seen by the module under test.
    with patch("agent_framework.azure._responses_client.AIProjectClient") as project_cls:
        built_project = project_cls.return_value
        built_project.get_openai_client.return_value = expected_client

        actual_client = AzureOpenAIResponsesClient._create_client_from_project(
            project_client=None,
            project_endpoint="https://test-project.services.ai.azure.com",
            credential=fake_credential,
        )

    assert actual_client is expected_client
    project_cls.assert_called_once()
    built_project.get_openai_client.assert_called_once()
def test_create_client_from_project_missing_endpoint() -> None:
    """Without a project client or an endpoint, the helper must refuse to build a client."""
    helper_kwargs = {
        "project_client": None,
        "project_endpoint": None,
        "credential": MagicMock(),
    }
    with pytest.raises(ServiceInitializationError, match="project endpoint is required"):
        AzureOpenAIResponsesClient._create_client_from_project(**helper_kwargs)
def test_create_client_from_project_missing_credential() -> None:
    """An endpoint without a credential (and no project client) must be rejected."""
    helper_kwargs = {
        "project_client": None,
        "project_endpoint": "https://test-project.services.ai.azure.com",
        "credential": None,
    }
    with pytest.raises(ServiceInitializationError, match="credential is required"):
        AzureOpenAIResponsesClient._create_client_from_project(**helper_kwargs)
def test_serialize(azure_openai_unit_test_env: dict[str, str]) -> None:
default_headers = {"X-Unit-Test": "test-guid"}
@@ -19,12 +19,10 @@ class TestAsToolKwargsPropagation:
captured_kwargs: dict[str, Any] = {}
@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture kwargs passed to the sub-agent
captured_kwargs.update(context.kwargs)
await call_next(context)
await call_next()
# Setup mock response
client.responses = [
@@ -62,11 +60,9 @@ class TestAsToolKwargsPropagation:
captured_kwargs: dict[str, Any] = {}
@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await call_next(context)
await call_next()
# Setup mock response
client.responses = [
@@ -99,12 +95,10 @@ class TestAsToolKwargsPropagation:
captured_kwargs_list: list[dict[str, Any]] = []
@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture kwargs at each level
captured_kwargs_list.append(dict(context.kwargs))
await call_next(context)
await call_next()
# Setup mock responses to trigger nested tool invocation: B calls tool C, then completes.
client.responses = [
@@ -162,11 +156,9 @@ class TestAsToolKwargsPropagation:
captured_kwargs: dict[str, Any] = {}
@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await call_next(context)
await call_next()
# Setup mock streaming responses
from agent_framework import ChatResponseUpdate
@@ -224,11 +216,9 @@ class TestAsToolKwargsPropagation:
captured_kwargs: dict[str, Any] = {}
@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await call_next(context)
await call_next()
# Setup mock response
client.responses = [
@@ -266,16 +256,14 @@ class TestAsToolKwargsPropagation:
call_count = 0
@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
nonlocal call_count
call_count += 1
if call_count == 1:
first_call_kwargs.update(context.kwargs)
elif call_count == 2:
second_call_kwargs.update(context.kwargs)
await call_next(context)
await call_next()
# Setup mock responses for both calls
client.responses = [
@@ -318,11 +306,9 @@ class TestAsToolKwargsPropagation:
captured_kwargs: dict[str, Any] = {}
@agent_middleware
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
await call_next(context)
await call_next()
# Setup mock response
client.responses = [
@@ -2298,9 +2298,7 @@ async def test_streaming_error_recovery_resets_counter(chat_client_base: Support
class TerminateLoopMiddleware(FunctionMiddleware):
"""Middleware that raises MiddlewareTermination to exit the function calling loop."""
async def process(
self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None:
# Set result to a simple value - the framework will wrap it in FunctionResultContent
context.result = "terminated by middleware"
raise MiddlewareTermination
@@ -2355,14 +2353,12 @@ async def test_terminate_loop_single_function_call(chat_client_base: SupportsCha
class SelectiveTerminateMiddleware(FunctionMiddleware):
"""Only terminates for terminating_function."""
async def process(
self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None:
if context.function.name == "terminating_function":
# Set result to a simple value - the framework will wrap it in FunctionResultContent
context.result = "terminated by middleware"
raise MiddlewareTermination
await next_handler(context)
await next_handler()
async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client_base: SupportsChatGetResponse):
@@ -135,12 +135,12 @@ class TestAgentMiddlewarePipeline:
"""Test cases for AgentMiddlewarePipeline."""
class PreNextTerminateMiddleware(AgentMiddleware):
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
raise MiddlewareTermination
class PostNextTerminateMiddleware(AgentMiddleware):
async def process(self, context: AgentContext, call_next: Any) -> None:
await call_next(context)
await call_next()
raise MiddlewareTermination
def test_init_empty(self) -> None:
@@ -157,8 +157,8 @@ class TestAgentMiddlewarePipeline:
def test_init_with_function_middleware(self) -> None:
"""Test AgentMiddlewarePipeline initialization with function-based middleware."""
async def test_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
await call_next(context)
async def test_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
pipeline = AgentMiddlewarePipeline(test_middleware)
assert pipeline.has_middlewares
@@ -185,11 +185,9 @@ class TestAgentMiddlewarePipeline:
def __init__(self, name: str):
self.name = name
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append(f"{self.name}_before")
await call_next(context)
await call_next()
execution_order.append(f"{self.name}_after")
middleware = OrderTrackingMiddleware("test")
@@ -238,11 +236,9 @@ class TestAgentMiddlewarePipeline:
def __init__(self, name: str):
self.name = name
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append(f"{self.name}_before")
await call_next(context)
await call_next()
execution_order.append(f"{self.name}_after")
middleware = StreamOrderTrackingMiddleware("test")
@@ -367,12 +363,10 @@ class TestAgentMiddlewarePipeline:
captured_thread = None
class ThreadCapturingMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
nonlocal captured_thread
captured_thread = context.thread
await call_next(context)
await call_next()
middleware = ThreadCapturingMiddleware()
pipeline = AgentMiddlewarePipeline(middleware)
@@ -394,12 +388,10 @@ class TestAgentMiddlewarePipeline:
captured_thread = "not_none" # Use string to distinguish from None
class ThreadCapturingMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
nonlocal captured_thread
captured_thread = context.thread
await call_next(context)
await call_next()
middleware = ThreadCapturingMiddleware()
pipeline = AgentMiddlewarePipeline(middleware)
@@ -425,7 +417,7 @@ class TestFunctionMiddlewarePipeline:
class PostNextTerminateFunctionMiddleware(FunctionMiddleware):
async def process(self, context: FunctionInvocationContext, call_next: Any) -> None:
await call_next(context)
await call_next()
raise MiddlewareTermination
async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool[Any, Any]) -> None:
@@ -482,10 +474,8 @@ class TestFunctionMiddlewarePipeline:
def test_init_with_function_middleware(self) -> None:
"""Test FunctionMiddlewarePipeline initialization with function-based middleware."""
async def test_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
await call_next(context)
async def test_middleware(context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
pipeline = FunctionMiddlewarePipeline(test_middleware)
assert pipeline.has_middlewares
@@ -515,10 +505,10 @@ class TestFunctionMiddlewarePipeline:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append(f"{self.name}_before")
await call_next(context)
await call_next()
execution_order.append(f"{self.name}_after")
middleware = OrderTrackingFunctionMiddleware("test")
@@ -541,12 +531,12 @@ class TestChatMiddlewarePipeline:
"""Test cases for ChatMiddlewarePipeline."""
class PreNextTerminateChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
raise MiddlewareTermination
class PostNextTerminateChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
await call_next(context)
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
raise MiddlewareTermination
def test_init_empty(self) -> None:
@@ -563,8 +553,8 @@ class TestChatMiddlewarePipeline:
def test_init_with_function_middleware(self) -> None:
"""Test ChatMiddlewarePipeline initialization with function-based middleware."""
async def test_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
await call_next(context)
async def test_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
pipeline = ChatMiddlewarePipeline(test_middleware)
assert pipeline.has_middlewares
@@ -592,9 +582,9 @@ class TestChatMiddlewarePipeline:
def __init__(self, name: str):
self.name = name
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append(f"{self.name}_before")
await call_next(context)
await call_next()
execution_order.append(f"{self.name}_after")
middleware = OrderTrackingChatMiddleware("test")
@@ -644,9 +634,9 @@ class TestChatMiddlewarePipeline:
def __init__(self, name: str):
self.name = name
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append(f"{self.name}_before")
await call_next(context)
await call_next()
execution_order.append(f"{self.name}_after")
middleware = StreamOrderTrackingChatMiddleware("test")
@@ -774,12 +764,10 @@ class TestClassBasedMiddleware:
metadata_updates: list[str] = []
class MetadataAgentMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
context.metadata["before"] = True
metadata_updates.append("before")
await call_next(context)
await call_next()
context.metadata["after"] = True
metadata_updates.append("after")
@@ -807,11 +795,11 @@ class TestClassBasedMiddleware:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
context.metadata["before"] = True
metadata_updates.append("before")
await call_next(context)
await call_next()
context.metadata["after"] = True
metadata_updates.append("after")
@@ -839,12 +827,10 @@ class TestFunctionBasedMiddleware:
"""Test function-based agent middleware."""
execution_order: list[str] = []
async def test_agent_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def test_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("function_before")
context.metadata["function_middleware"] = True
await call_next(context)
await call_next()
execution_order.append("function_after")
pipeline = AgentMiddlewarePipeline(test_agent_middleware)
@@ -866,11 +852,11 @@ class TestFunctionBasedMiddleware:
execution_order: list[str] = []
async def test_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("function_before")
context.metadata["function_middleware"] = True
await call_next(context)
await call_next()
execution_order.append("function_after")
pipeline = FunctionMiddlewarePipeline(test_function_middleware)
@@ -896,18 +882,14 @@ class TestMixedMiddleware:
execution_order: list[str] = []
class ClassMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("class_before")
await call_next(context)
await call_next()
execution_order.append("class_after")
async def function_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def function_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("function_before")
await call_next(context)
await call_next()
execution_order.append("function_after")
pipeline = AgentMiddlewarePipeline(ClassMiddleware(), function_middleware)
@@ -931,17 +913,17 @@ class TestMixedMiddleware:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("class_before")
await call_next(context)
await call_next()
execution_order.append("class_after")
async def function_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("function_before")
await call_next(context)
await call_next()
execution_order.append("function_after")
pipeline = FunctionMiddlewarePipeline(ClassMiddleware(), function_middleware)
@@ -962,16 +944,14 @@ class TestMixedMiddleware:
execution_order: list[str] = []
class ClassChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("class_before")
await call_next(context)
await call_next()
execution_order.append("class_after")
async def function_chat_middleware(
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
) -> None:
async def function_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("function_before")
await call_next(context)
await call_next()
execution_order.append("function_after")
pipeline = ChatMiddlewarePipeline(ClassChatMiddleware(), function_chat_middleware)
@@ -997,27 +977,21 @@ class TestMultipleMiddlewareOrdering:
execution_order: list[str] = []
class FirstMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("first_before")
await call_next(context)
await call_next()
execution_order.append("first_after")
class SecondMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("second_before")
await call_next(context)
await call_next()
execution_order.append("second_after")
class ThirdMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("third_before")
await call_next(context)
await call_next()
execution_order.append("third_after")
middleware = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()]
@@ -1051,20 +1025,20 @@ class TestMultipleMiddlewareOrdering:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("first_before")
await call_next(context)
await call_next()
execution_order.append("first_after")
class SecondMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("second_before")
await call_next(context)
await call_next()
execution_order.append("second_after")
middleware = [FirstMiddleware(), SecondMiddleware()]
@@ -1087,21 +1061,21 @@ class TestMultipleMiddlewareOrdering:
execution_order: list[str] = []
class FirstChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("first_before")
await call_next(context)
await call_next()
execution_order.append("first_after")
class SecondChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("second_before")
await call_next(context)
await call_next()
execution_order.append("second_after")
class ThirdChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("third_before")
await call_next(context)
await call_next()
execution_order.append("third_after")
middleware = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()]
@@ -1136,9 +1110,7 @@ class TestContextContentValidation:
"""Test that agent context contains expected data."""
class ContextValidationMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Verify context has all expected attributes
assert hasattr(context, "agent")
assert hasattr(context, "messages")
@@ -1156,7 +1128,7 @@ class TestContextContentValidation:
# Add custom metadata
context.metadata["validated"] = True
await call_next(context)
await call_next()
middleware = ContextValidationMiddleware()
pipeline = AgentMiddlewarePipeline(middleware)
@@ -1178,7 +1150,7 @@ class TestContextContentValidation:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
# Verify context has all expected attributes
assert hasattr(context, "function")
@@ -1194,7 +1166,7 @@ class TestContextContentValidation:
# Add custom metadata
context.metadata["validated"] = True
await call_next(context)
await call_next()
middleware = ContextValidationMiddleware()
pipeline = FunctionMiddlewarePipeline(middleware)
@@ -1213,7 +1185,7 @@ class TestContextContentValidation:
"""Test that chat context contains expected data."""
class ChatContextValidationMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Verify context has all expected attributes
assert hasattr(context, "client")
assert hasattr(context, "messages")
@@ -1235,7 +1207,7 @@ class TestContextContentValidation:
# Add custom metadata
context.metadata["validated"] = True
await call_next(context)
await call_next()
middleware = ChatContextValidationMiddleware()
pipeline = ChatMiddlewarePipeline(middleware)
@@ -1260,11 +1232,9 @@ class TestStreamingScenarios:
streaming_flags: list[bool] = []
class StreamingFlagMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
streaming_flags.append(context.stream)
await call_next(context)
await call_next()
middleware = StreamingFlagMiddleware()
pipeline = AgentMiddlewarePipeline(middleware)
@@ -1302,11 +1272,9 @@ class TestStreamingScenarios:
chunks_processed: list[str] = []
class StreamProcessingMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
chunks_processed.append("before_stream")
await call_next(context)
await call_next()
chunks_processed.append("after_stream")
middleware = StreamProcessingMiddleware()
@@ -1345,9 +1313,9 @@ class TestStreamingScenarios:
streaming_flags: list[bool] = []
class ChatStreamingFlagMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
streaming_flags.append(context.stream)
await call_next(context)
await call_next()
middleware = ChatStreamingFlagMiddleware()
pipeline = ChatMiddlewarePipeline(middleware)
@@ -1386,9 +1354,9 @@ class TestStreamingScenarios:
chunks_processed: list[str] = []
class ChatStreamProcessingMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
chunks_processed.append("before_stream")
await call_next(context)
await call_next()
chunks_processed.append("after_stream")
middleware = ChatStreamProcessingMiddleware()
@@ -1436,24 +1404,22 @@ class FunctionTestArgs(BaseModel):
class TestAgentMiddleware(AgentMiddleware):
"""Test implementation of AgentMiddleware."""
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
await call_next(context)
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
class TestFunctionMiddleware(FunctionMiddleware):
"""Test implementation of FunctionMiddleware."""
async def process(
self, context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
await call_next(context)
async def process(self, context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
class TestChatMiddleware(ChatMiddleware):
"""Test implementation of ChatMiddleware."""
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
await call_next(context)
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
class MockFunctionArgs(BaseModel):
@@ -1469,9 +1435,7 @@ class TestMiddlewareExecutionControl:
"""Test that when agent middleware doesn't call next(), no execution happens."""
class NoNextMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Don't call next() - this should prevent any execution
pass
@@ -1498,9 +1462,7 @@ class TestMiddlewareExecutionControl:
"""Test that when agent middleware doesn't call next(), no streaming execution happens."""
class NoNextStreamingMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Don't call next() - this should prevent any execution
pass
@@ -1537,7 +1499,7 @@ class TestMiddlewareExecutionControl:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
# Don't call next() - this should prevent any execution
pass
@@ -1566,18 +1528,14 @@ class TestMiddlewareExecutionControl:
execution_order: list[str] = []
class FirstMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("first")
# Don't call next() - this should stop the pipeline
class SecondMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("second")
await call_next(context)
await call_next()
pipeline = AgentMiddlewarePipeline(FirstMiddleware(), SecondMiddleware())
messages = [Message(role="user", text="test")]
@@ -1601,7 +1559,7 @@ class TestMiddlewareExecutionControl:
"""Test that when chat middleware doesn't call next(), no execution happens."""
class NoNextChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Don't call next() - this should prevent any execution
pass
@@ -1629,7 +1587,7 @@ class TestMiddlewareExecutionControl:
"""Test that when chat middleware doesn't call next(), no streaming execution happens."""
class NoNextStreamingChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Don't call next() - this should prevent any execution
pass
@@ -1670,14 +1628,14 @@ class TestMiddlewareExecutionControl:
execution_order: list[str] = []
class FirstChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("first")
# Don't call next() - this should stop the pipeline
class SecondChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("second")
await call_next(context)
await call_next()
pipeline = ChatMiddlewarePipeline(FirstChatMiddleware(), SecondChatMiddleware())
messages = [Message(role="user", text="test")]
@@ -43,11 +43,9 @@ class TestResultOverrideMiddleware:
override_response = AgentResponse(messages=[Message(role="assistant", text="overridden response")])
class ResponseOverrideMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Execute the pipeline first, then override the response
await call_next(context)
await call_next()
context.result = override_response
middleware = ResponseOverrideMiddleware()
@@ -79,11 +77,9 @@ class TestResultOverrideMiddleware:
yield AgentResponseUpdate(contents=[Content.from_text(text=" stream")])
class StreamResponseOverrideMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Execute the pipeline first, then override the response stream
await call_next(context)
await call_next()
context.result = ResponseStream(override_stream())
middleware = StreamResponseOverrideMiddleware()
@@ -115,10 +111,10 @@ class TestResultOverrideMiddleware:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
# Execute the pipeline first, then override the result
await call_next(context)
await call_next()
context.result = override_result
middleware = ResultOverrideMiddleware()
@@ -145,11 +141,9 @@ class TestResultOverrideMiddleware:
mock_chat_client = MockChatClient()
class ChatAgentResponseOverrideMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Always call next() first to allow execution
await call_next(context)
await call_next()
# Then conditionally override based on content
if any("special" in msg.text for msg in context.messages if msg.text):
context.result = AgentResponse(
@@ -184,15 +178,13 @@ class TestResultOverrideMiddleware:
yield AgentResponseUpdate(contents=[Content.from_text(text=" response!")])
class ChatAgentStreamOverrideMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Check if we want to override BEFORE calling next to avoid creating unused streams
if any("custom stream" in msg.text for msg in context.messages if msg.text):
context.result = ResponseStream(custom_stream())
return # Don't call next() - we're overriding the entire result
# Normal case - let the agent handle it
await call_next(context)
await call_next()
# Create Agent with override middleware
middleware = ChatAgentStreamOverrideMiddleware()
@@ -223,12 +215,10 @@ class TestResultOverrideMiddleware:
"""Test that when agent middleware conditionally doesn't call next(), no execution happens."""
class ConditionalNoNextMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Only call next() if message contains "execute"
if any("execute" in msg.text for msg in context.messages if msg.text):
await call_next(context)
await call_next()
# Otherwise, don't call next() - no execution should happen
middleware = ConditionalNoNextMiddleware()
@@ -269,13 +259,13 @@ class TestResultOverrideMiddleware:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
# Only call next() if argument name contains "execute"
args = context.arguments
assert isinstance(args, FunctionTestArgs)
if "execute" in args.name:
await call_next(context)
await call_next()
# Otherwise, don't call next() - no execution should happen
middleware = ConditionalNoNextFunctionMiddleware()
@@ -318,14 +308,12 @@ class TestResultObservability:
observed_responses: list[AgentResponse] = []
class ObservabilityMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Context should be empty before next()
assert context.result is None
# Call next to execute
await call_next(context)
await call_next()
# Context should now contain the response for observability
assert context.result is not None
@@ -355,13 +343,13 @@ class TestResultObservability:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
# Context should be empty before next()
assert context.result is None
# Call next to execute
await call_next(context)
await call_next()
# Context should now contain the result for observability
assert context.result is not None
@@ -386,11 +374,9 @@ class TestResultObservability:
"""Test that middleware can override response after observing execution."""
class PostExecutionOverrideMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Call next to execute first
await call_next(context)
await call_next()
# Now observe and conditionally override
assert context.result is not None
@@ -423,10 +409,10 @@ class TestResultObservability:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
# Call next to execute first
await call_next(context)
await call_next()
# Now observe and conditionally override
assert context.result is not None
@@ -44,11 +44,9 @@ class TestChatAgentClassBasedMiddleware:
def __init__(self, name: str):
self.name = name
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append(f"{self.name}_before")
await call_next(context)
await call_next()
execution_order.append(f"{self.name}_after")
# Create Agent with middleware
@@ -76,9 +74,9 @@ class TestChatAgentClassBasedMiddleware:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
await call_next(context)
await call_next()
middleware = TrackingFunctionMiddleware()
Agent(client=client, middleware=[middleware])
@@ -96,10 +94,10 @@ class TestChatAgentClassBasedMiddleware:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append(f"{self.name}_before")
await call_next(context)
await call_next()
execution_order.append(f"{self.name}_after")
middleware = TrackingFunctionMiddleware("function_middleware")
@@ -122,13 +120,11 @@ class TestChatAgentFunctionBasedMiddleware:
execution_order: list[str] = []
class PreTerminationMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("middleware_before")
raise MiddlewareTermination
# Code after raise is unreachable
await call_next(context)
await call_next()
execution_order.append("middleware_after")
# Create Agent with terminating middleware
@@ -153,11 +149,9 @@ class TestChatAgentFunctionBasedMiddleware:
execution_order: list[str] = []
class PostTerminationMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("middleware_before")
await call_next(context)
await call_next()
execution_order.append("middleware_after")
context.terminate = True
@@ -193,12 +187,12 @@ class TestChatAgentFunctionBasedMiddleware:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("middleware_before")
context.terminate = True
# We call next() but since terminate=True, subsequent middleware and handler should not execute
await call_next(context)
await call_next()
execution_order.append("middleware_after")
Agent(client=client, middleware=[PreTerminationFunctionMiddleware()], tools=[])
@@ -211,10 +205,10 @@ class TestChatAgentFunctionBasedMiddleware:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("middleware_before")
await call_next(context)
await call_next()
execution_order.append("middleware_after")
context.terminate = True
@@ -224,11 +218,9 @@ class TestChatAgentFunctionBasedMiddleware:
"""Test function-based agent middleware with Agent."""
execution_order: list[str] = []
async def tracking_agent_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def tracking_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("agent_function_before")
await call_next(context)
await call_next()
execution_order.append("agent_function_after")
# Create Agent with function middleware
@@ -252,9 +244,9 @@ class TestChatAgentFunctionBasedMiddleware:
"""Test function-based function middleware with Agent."""
async def tracking_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
await call_next(context)
await call_next()
Agent(client=client, middleware=[tracking_function_middleware])
@@ -265,10 +257,10 @@ class TestChatAgentFunctionBasedMiddleware:
execution_order: list[str] = []
async def tracking_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("function_function_before")
await call_next(context)
await call_next()
execution_order.append("function_function_after")
agent = Agent(client=chat_client_base, middleware=[tracking_function_middleware])
@@ -290,12 +282,10 @@ class TestChatAgentStreamingMiddleware:
streaming_flags: list[bool] = []
class StreamingTrackingMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("middleware_before")
streaming_flags.append(context.stream)
await call_next(context)
await call_next()
execution_order.append("middleware_after")
# Create Agent with middleware
@@ -334,11 +324,9 @@ class TestChatAgentStreamingMiddleware:
streaming_flags: list[bool] = []
class FlagTrackingMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
streaming_flags.append(context.stream)
await call_next(context)
await call_next()
# Create Agent with middleware
middleware = FlagTrackingMiddleware()
@@ -368,11 +356,9 @@ class TestChatAgentMultipleMiddlewareOrdering:
def __init__(self, name: str):
self.name = name
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append(f"{self.name}_before")
await call_next(context)
await call_next()
execution_order.append(f"{self.name}_after")
# Create multiple middleware
@@ -400,35 +386,31 @@ class TestChatAgentMultipleMiddlewareOrdering:
execution_order: list[str] = []
class ClassAgentMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("class_agent_before")
await call_next(context)
await call_next()
execution_order.append("class_agent_after")
async def function_agent_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def function_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("function_agent_before")
await call_next(context)
await call_next()
execution_order.append("function_agent_after")
class ClassFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("class_function_before")
await call_next(context)
await call_next()
execution_order.append("class_function_after")
async def function_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("function_function_before")
await call_next(context)
await call_next()
execution_order.append("function_function_after")
agent = Agent(
@@ -447,25 +429,21 @@ class TestChatAgentMultipleMiddlewareOrdering:
execution_order: list[str] = []
class ClassAgentMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("class_agent_before")
await call_next(context)
await call_next()
execution_order.append("class_agent_after")
async def function_agent_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def function_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("function_agent_before")
await call_next(context)
await call_next()
execution_order.append("function_agent_after")
async def function_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("function_function_before")
await call_next(context)
await call_next()
execution_order.append("function_function_after")
agent = Agent(
@@ -521,10 +499,10 @@ class TestChatAgentFunctionMiddlewareWithTools:
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append(f"{self.name}_before")
await call_next(context)
await call_next()
execution_order.append(f"{self.name}_after")
# Set up mock to return a function call first, then a regular response
@@ -583,10 +561,10 @@ class TestChatAgentFunctionMiddlewareWithTools:
execution_order: list[str] = []
async def tracking_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("function_middleware_before")
await call_next(context)
await call_next()
execution_order.append("function_middleware_after")
# Set up mock to return a function call first, then a regular response
@@ -647,20 +625,20 @@ class TestChatAgentFunctionMiddlewareWithTools:
async def process(
self,
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("agent_middleware_before")
await call_next(context)
await call_next()
execution_order.append("agent_middleware_after")
class TrackingFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("function_middleware_before")
await call_next(context)
await call_next()
execution_order.append("function_middleware_after")
# Set up mock to return a function call first, then a regular response
@@ -728,7 +706,7 @@ class TestChatAgentFunctionMiddlewareWithTools:
@function_middleware
async def kwargs_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
nonlocal middleware_called
middleware_called = True
@@ -748,7 +726,7 @@ class TestChatAgentFunctionMiddlewareWithTools:
modified_kwargs["new_param"] = context.kwargs.get("new_param")
modified_kwargs["custom_param"] = context.kwargs.get("custom_param")
await call_next(context)
await call_next()
chat_client_base.run_responses = [
ChatResponse(
@@ -801,9 +779,9 @@ class TestMiddlewareDynamicRebuild:
self.name = name
self.execution_log = execution_log
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
self.execution_log.append(f"{self.name}_start")
await call_next(context)
await call_next()
self.execution_log.append(f"{self.name}_end")
async def test_middleware_dynamic_rebuild_non_streaming(self, client: "MockChatClient") -> None:
@@ -924,9 +902,9 @@ class TestRunLevelMiddleware:
self.name = name
self.execution_log = execution_log
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
self.execution_log.append(f"{self.name}_start")
await call_next(context)
await call_next()
self.execution_log.append(f"{self.name}_end")
async def test_run_level_middleware_isolation(self, client: "MockChatClient") -> None:
@@ -976,29 +954,25 @@ class TestRunLevelMiddleware:
def __init__(self, name: str):
self.name = name
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_log.append(f"{self.name}_start")
# Set metadata to pass information to run middleware
context.metadata[f"{self.name}_key"] = f"{self.name}_value"
await call_next(context)
await call_next()
execution_log.append(f"{self.name}_end")
class MetadataRunMiddleware(AgentMiddleware):
def __init__(self, name: str):
self.name = name
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_log.append(f"{self.name}_start")
# Read metadata set by agent middleware
for key, value in context.metadata.items():
metadata_log.append(f"{self.name}_reads_{key}:{value}")
# Set run-level metadata
context.metadata[f"{self.name}_key"] = f"{self.name}_value"
await call_next(context)
await call_next()
execution_log.append(f"{self.name}_end")
# Create agent with agent-level middleware
@@ -1049,12 +1023,10 @@ class TestRunLevelMiddleware:
def __init__(self, name: str):
self.name = name
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_log.append(f"{self.name}_start")
streaming_flags.append(context.stream)
await call_next(context)
await call_next()
execution_log.append(f"{self.name}_end")
# Create agent without agent-level middleware
@@ -1093,48 +1065,44 @@ class TestRunLevelMiddleware:
# Agent-level middleware
class AgentLevelAgentMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_log.append("agent_level_agent_start")
context.metadata["agent_level_agent"] = "processed"
await call_next(context)
await call_next()
execution_log.append("agent_level_agent_end")
class AgentLevelFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_log.append("agent_level_function_start")
context.metadata["agent_level_function"] = "processed"
await call_next(context)
await call_next()
execution_log.append("agent_level_function_end")
# Run-level middleware
class RunLevelAgentMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_log.append("run_level_agent_start")
# Verify agent-level middleware metadata is available
assert "agent_level_agent" in context.metadata
context.metadata["run_level_agent"] = "processed"
await call_next(context)
await call_next()
execution_log.append("run_level_agent_end")
class RunLevelFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_log.append("run_level_function_start")
# Verify agent-level function middleware metadata is available
assert "agent_level_function" in context.metadata
context.metadata["run_level_function"] = "processed"
await call_next(context)
await call_next()
execution_log.append("run_level_function_end")
# Create tool function for testing function middleware
@@ -1217,18 +1185,16 @@ class TestMiddlewareDecoratorLogic:
execution_order: list[str] = []
@agent_middleware
async def matching_agent_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def matching_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("decorator_type_match_agent")
await call_next(context)
await call_next()
@function_middleware
async def matching_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("decorator_type_match_function")
await call_next(context)
await call_next()
# Create tool function for testing function middleware
def custom_tool(message: str) -> str:
@@ -1282,7 +1248,7 @@ class TestMiddlewareDecoratorLogic:
context: FunctionInvocationContext, # Wrong type for @agent_middleware
call_next: Any,
) -> None:
await call_next(context)
await call_next()
agent = Agent(client=client, middleware=[mismatched_middleware])
await agent.run([Message(role="user", text="test")])
@@ -1294,12 +1260,12 @@ class TestMiddlewareDecoratorLogic:
@agent_middleware
async def decorator_only_agent(context: Any, call_next: Any) -> None: # No type annotation
execution_order.append("decorator_only_agent")
await call_next(context)
await call_next()
@function_middleware
async def decorator_only_function(context: Any, call_next: Any) -> None: # No type annotation
execution_order.append("decorator_only_function")
await call_next(context)
await call_next()
# Create tool function for testing function middleware
def custom_tool(message: str) -> str:
@@ -1346,16 +1312,16 @@ class TestMiddlewareDecoratorLogic:
execution_order: list[str] = []
# No decorator
async def type_only_agent(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def type_only_agent(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("type_only_agent")
await call_next(context)
await call_next()
# No decorator
async def type_only_function(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("type_only_function")
await call_next(context)
await call_next()
# Create tool function for testing function middleware
def custom_tool(message: str) -> str:
@@ -1399,7 +1365,7 @@ class TestMiddlewareDecoratorLogic:
"""Neither decorator nor parameter type specified - should throw exception."""
async def no_info_middleware(context: Any, call_next: Any) -> None: # No decorator, no type
await call_next(context)
await call_next()
# Should raise MiddlewareException
with pytest.raises(MiddlewareException, match="Cannot determine middleware type"):
@@ -1447,9 +1413,7 @@ class TestChatAgentThreadBehavior:
thread_states: list[dict[str, Any]] = []
class ThreadTrackingMiddleware(AgentMiddleware):
async def process(
self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture state before next() call
thread_messages = []
if context.thread and context.thread.message_store:
@@ -1464,7 +1428,7 @@ class TestChatAgentThreadBehavior:
}
thread_states.append(before_state)
await call_next(context)
await call_next()
# Capture state after next() call
thread_messages_after = []
@@ -1560,9 +1524,9 @@ class TestChatAgentChatMiddleware:
execution_order: list[str] = []
class TrackingChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("chat_middleware_before")
await call_next(context)
await call_next()
execution_order.append("chat_middleware_after")
# Create Agent with chat middleware
@@ -1588,11 +1552,9 @@ class TestChatAgentChatMiddleware:
"""Test function-based chat middleware with Agent."""
execution_order: list[str] = []
async def tracking_chat_middleware(
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
) -> None:
async def tracking_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("chat_middleware_before")
await call_next(context)
await call_next()
execution_order.append("chat_middleware_after")
# Create Agent with function-based chat middleware
@@ -1617,9 +1579,7 @@ class TestChatAgentChatMiddleware:
"""Test that chat middleware can modify messages before sending to model."""
@chat_middleware
async def message_modifier_middleware(
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
) -> None:
async def message_modifier_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Modify the first message by adding a prefix
if context.messages:
for idx, msg in enumerate(context.messages):
@@ -1628,7 +1588,7 @@ class TestChatAgentChatMiddleware:
original_text = msg.text or ""
context.messages[idx] = Message(role=msg.role, text=f"MODIFIED: {original_text}")
break
await call_next(context)
await call_next()
# Create Agent with message-modifying middleware
client = MockBaseChatClient()
@@ -1646,9 +1606,7 @@ class TestChatAgentChatMiddleware:
"""Test that chat middleware can override the response."""
@chat_middleware
async def response_override_middleware(
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
) -> None:
async def response_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Override the response without calling next()
context.result = ChatResponse(
messages=[Message(role="assistant", text="MiddlewareTypes overridden response")],
@@ -1675,15 +1633,15 @@ class TestChatAgentChatMiddleware:
execution_order: list[str] = []
@chat_middleware
async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("first_before")
await call_next(context)
await call_next()
execution_order.append("first_after")
@chat_middleware
async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("second_before")
await call_next(context)
await call_next()
execution_order.append("second_after")
# Create Agent with multiple chat middleware
@@ -1709,10 +1667,10 @@ class TestChatAgentChatMiddleware:
streaming_flags: list[bool] = []
class StreamingTrackingChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("streaming_chat_before")
streaming_flags.append(context.stream)
await call_next(context)
await call_next()
execution_order.append("streaming_chat_after")
# Create Agent with chat middleware
@@ -1749,13 +1707,13 @@ class TestChatAgentChatMiddleware:
execution_order: list[str] = []
class PreTerminationChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("middleware_before")
# Set a custom response since we're terminating
context.result = ChatResponse(messages=[Message(role="assistant", text="Terminated by middleware")])
raise MiddlewareTermination
# We call next() but since terminate=True, execution should stop
await call_next(context)
await call_next()
execution_order.append("middleware_after")
# Create Agent with terminating middleware
@@ -1777,9 +1735,9 @@ class TestChatAgentChatMiddleware:
execution_order: list[str] = []
class PostTerminationChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("middleware_before")
await call_next(context)
await call_next()
execution_order.append("middleware_after")
context.terminate = True
@@ -1804,21 +1762,21 @@ class TestChatAgentChatMiddleware:
"""Test Agent with combined middleware types."""
execution_order: list[str] = []
async def agent_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("agent_middleware_before")
await call_next(context)
await call_next()
execution_order.append("agent_middleware_after")
async def chat_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("chat_middleware_before")
await call_next(context)
await call_next()
execution_order.append("chat_middleware_after")
async def function_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("function_middleware_before")
await call_next(context)
await call_next()
execution_order.append("function_middleware_after")
# Create Agent with function middleware and tools
@@ -1842,9 +1800,7 @@ class TestChatAgentChatMiddleware:
modified_kwargs: dict[str, Any] = {}
@agent_middleware
async def kwargs_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
async def kwargs_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture the original kwargs
captured_kwargs.update(context.kwargs)
@@ -1856,7 +1812,7 @@ class TestChatAgentChatMiddleware:
# Store modified kwargs for verification
modified_kwargs.update(context.kwargs)
await call_next(context)
await call_next()
# Create Agent with agent middleware
client = MockBaseChatClient()
@@ -1895,10 +1851,10 @@ class TestChatAgentChatMiddleware:
# class TrackingMiddleware(AgentMiddleware):
# async def process(
# self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
# self, context: AgentContext, call_next: Callable[[], Awaitable[None]]
# ) -> None:
# execution_order.append("before")
# await call_next(context)
# await call_next()
# execution_order.append("after")
# @use_agent_middleware
@@ -32,10 +32,10 @@ class TestChatMiddleware:
async def process(
self,
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("chat_middleware_before")
await call_next(context)
await call_next()
execution_order.append("chat_middleware_after")
# Add middleware to chat client
@@ -58,11 +58,9 @@ class TestChatMiddleware:
execution_order: list[str] = []
@chat_middleware
async def logging_chat_middleware(
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
) -> None:
async def logging_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("function_middleware_before")
await call_next(context)
await call_next()
execution_order.append("function_middleware_after")
# Add middleware to chat client
@@ -84,14 +82,12 @@ class TestChatMiddleware:
"""Test that chat middleware can modify messages before sending to model."""
@chat_middleware
async def message_modifier_middleware(
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
) -> None:
async def message_modifier_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Modify the first message by adding a prefix
if context.messages and len(context.messages) > 0:
original_text = context.messages[0].text or ""
context.messages[0] = Message(role=context.messages[0].role, text=f"MODIFIED: {original_text}")
await call_next(context)
await call_next()
# Add middleware to chat client
chat_client_base.chat_middleware = [message_modifier_middleware]
@@ -110,9 +106,7 @@ class TestChatMiddleware:
"""Test that chat middleware can override the response."""
@chat_middleware
async def response_override_middleware(
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
) -> None:
async def response_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Override the response without calling next()
context.result = ChatResponse(
messages=[Message(role="assistant", text="MiddlewareTypes overridden response")],
@@ -138,15 +132,15 @@ class TestChatMiddleware:
execution_order: list[str] = []
@chat_middleware
async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("first_before")
await call_next(context)
await call_next()
execution_order.append("first_after")
@chat_middleware
async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("second_before")
await call_next(context)
await call_next()
execution_order.append("second_after")
# Add middleware to chat client (order should be preserved)
@@ -173,11 +167,9 @@ class TestChatMiddleware:
execution_order: list[str] = []
@chat_middleware
async def agent_level_chat_middleware(
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
) -> None:
async def agent_level_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("agent_chat_middleware_before")
await call_next(context)
await call_next()
execution_order.append("agent_chat_middleware_after")
client = MockBaseChatClient()
@@ -205,15 +197,15 @@ class TestChatMiddleware:
execution_order: list[str] = []
@chat_middleware
async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("first_before")
await call_next(context)
await call_next()
execution_order.append("first_after")
@chat_middleware
async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("second_before")
await call_next(context)
await call_next()
execution_order.append("second_after")
# Create Agent with multiple chat middleware
@@ -240,9 +232,7 @@ class TestChatMiddleware:
execution_order: list[str] = []
@chat_middleware
async def streaming_middleware(
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
) -> None:
async def streaming_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_order.append("streaming_before")
# Verify it's a streaming context
assert context.stream is True
@@ -254,7 +244,7 @@ class TestChatMiddleware:
return update
context.stream_transform_hooks.append(upper_case_update)
await call_next(context)
await call_next()
execution_order.append("streaming_after")
# Add middleware to chat client
@@ -278,11 +268,9 @@ class TestChatMiddleware:
execution_count = {"count": 0}
@chat_middleware
async def counting_middleware(
context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]
) -> None:
async def counting_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
execution_count["count"] += 1
await call_next(context)
await call_next()
# First call with run-level middleware
messages = [Message(role="user", text="first message")]
@@ -310,7 +298,7 @@ class TestChatMiddleware:
modified_kwargs: dict[str, Any] = {}
@chat_middleware
async def kwargs_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def kwargs_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture the original kwargs
captured_kwargs.update(context.kwargs)
@@ -322,7 +310,7 @@ class TestChatMiddleware:
# Store modified kwargs for verification
modified_kwargs.update(context.kwargs)
await call_next(context)
await call_next()
# Add middleware to chat client
chat_client_base.chat_middleware = [kwargs_middleware]
@@ -355,11 +343,11 @@ class TestChatMiddleware:
@function_middleware
async def test_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
nonlocal execution_order
execution_order.append(f"function_middleware_before_{context.function.name}")
await call_next(context)
await call_next()
execution_order.append(f"function_middleware_after_{context.function.name}")
# Define a simple tool function
@@ -421,10 +409,10 @@ class TestChatMiddleware:
@function_middleware
async def run_level_function_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
execution_order.append("run_level_function_middleware_before")
await call_next(context)
await call_next()
execution_order.append("run_level_function_middleware_after")
# Define a simple tool function
@@ -798,12 +798,8 @@ def test_chat_message_with_error_content() -> None:
result = client._prepare_message_for_openai(message, call_id_to_id)
# Message should be prepared with empty content list since ErrorContent returns {}
assert len(result) == 1
prepared_message = result[0]
assert prepared_message["role"] == "assistant"
# Content should be a list with empty dict since ErrorContent returns {}
assert prepared_message.get("content") == [{}]
# Message should be empty since ErrorContent is filtered out
assert len(result) == 0
def test_chat_message_with_usage_content() -> None:
@@ -823,12 +819,8 @@ def test_chat_message_with_usage_content() -> None:
result = client._prepare_message_for_openai(message, call_id_to_id)
# Message should be prepared with empty content list since UsageContent returns {}
assert len(result) == 1
prepared_message = result[0]
assert prepared_message["role"] == "assistant"
# Content should be a list with empty dict since UsageContent returns {}
assert prepared_message.get("content") == [{}]
# Message should be empty since UsageContent is filtered out
assert len(result) == 0
def test_hosted_file_content_preparation() -> None:
@@ -207,7 +207,7 @@ def test_serialize(ollama_unit_test_env: dict[str, str]) -> None:
def test_chat_middleware(ollama_unit_test_env: dict[str, str]) -> None:
@chat_middleware
async def sample_middleware(context, call_next):
await call_next(context)
await call_next()
ollama_chat_client = OllamaChatClient(middleware=[sample_middleware])
assert len(ollama_chat_client.middleware) == 1
@@ -129,11 +129,11 @@ class _AutoHandoffMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Intercept matching handoff tool calls and inject synthetic results."""
if context.function.name not in self._handoff_functions:
await call_next(context)
await call_next()
return
from agent_framework._middleware import MiddlewareTermination
@@ -65,7 +65,7 @@ class PurviewPolicyMiddleware(AgentMiddleware):
async def process(
self,
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None: # type: ignore[override]
resolved_user_id: str | None = None
try:
@@ -92,7 +92,7 @@ class PurviewPolicyMiddleware(AgentMiddleware):
if not self._settings.ignore_exceptions:
raise
await call_next(context)
await call_next()
try:
# Post (response) check only if we have a normal AgentResponse
@@ -162,7 +162,7 @@ class PurviewChatPolicyMiddleware(ChatMiddleware):
async def process(
self,
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None: # type: ignore[override]
resolved_user_id: str | None = None
try:
@@ -187,7 +187,7 @@ class PurviewChatPolicyMiddleware(ChatMiddleware):
if not self._settings.ignore_exceptions:
raise
await call_next(context)
await call_next()
try:
# Post (response) evaluation only if non-streaming and we have messages result shape
@@ -49,7 +49,7 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
next_called = False
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
nonlocal next_called
next_called = True
@@ -57,7 +57,7 @@ class TestPurviewChatPolicyMiddleware:
def __init__(self):
self.messages = [Message(role="assistant", text="Hi there")]
ctx.result = Result()
chat_context.result = Result()
await middleware.process(chat_context, mock_next)
assert next_called
@@ -67,7 +67,7 @@ class TestPurviewChatPolicyMiddleware:
async def test_blocks_prompt(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None:
with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):
async def mock_next(ctx: ChatContext) -> None: # should not run
async def mock_next() -> None: # should not run
raise AssertionError("next should not be called when prompt blocked")
with pytest.raises(MiddlewareTermination):
@@ -88,12 +88,12 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
class Result:
def __init__(self):
self.messages = [Message(role="assistant", text="Sensitive output")] # pragma: no cover
ctx.result = Result()
chat_context.result = Result()
await middleware.process(chat_context, mock_next)
assert call_state["count"] == 2
@@ -114,8 +114,8 @@ class TestPurviewChatPolicyMiddleware:
)
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
async def mock_next(ctx: ChatContext) -> None:
ctx.result = MagicMock()
async def mock_next() -> None:
streaming_context.result = MagicMock()
await middleware.process(streaming_context, mock_next)
assert mock_proc.call_count == 1
@@ -138,10 +138,10 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
result = MagicMock()
result.messages = [Message(role="assistant", text="Response")]
ctx.result = result
chat_context.result = result
await middleware.process(chat_context, mock_next)
@@ -162,10 +162,10 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
result = MagicMock()
result.messages = [Message(role="assistant", text="Response")]
ctx.result = result
chat_context.result = result
await middleware.process(chat_context, mock_next)
@@ -194,7 +194,7 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
raise AssertionError("next should not be called")
# Should raise the exception
@@ -224,10 +224,10 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
result = MagicMock()
result.messages = [Message(role="assistant", text="OK")]
ctx.result = result
context.result = result
with pytest.raises(PurviewPaymentRequiredError):
await middleware.process(context, mock_next)
@@ -249,7 +249,7 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
result = MagicMock()
result.messages = [Message(role="assistant", text="Response")]
context.result = result
@@ -265,9 +265,9 @@ class TestPurviewChatPolicyMiddleware:
"""Test middleware handles result that doesn't have messages attribute."""
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
# Set result to something without messages attribute
ctx.result = "Some string result"
chat_context.result = "Some string result"
await middleware.process(chat_context, mock_next)
@@ -289,7 +289,7 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
result = MagicMock()
result.messages = [Message(role="assistant", text="Response")]
context.result = result
@@ -313,7 +313,7 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=ValueError("boom")):
async def mock_next(_: ChatContext) -> None:
async def mock_next() -> None:
raise AssertionError("next should not be called")
with pytest.raises(ValueError, match="boom"):
@@ -342,10 +342,10 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
result = MagicMock()
result.messages = [Message(role="assistant", text="OK")]
ctx.result = result
context.result = result
with pytest.raises(ValueError, match="post"):
await middleware.process(context, mock_next)
@@ -361,10 +361,10 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
result = MagicMock()
result.messages = [Message(role="assistant", text="Hi")]
ctx.result = result
context.result = result
await middleware.process(context, mock_next)
@@ -382,10 +382,10 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
result = MagicMock()
result.messages = [Message(role="assistant", text="Hi")]
ctx.result = result
context.result = result
await middleware.process(context, mock_next)
@@ -401,10 +401,10 @@ class TestPurviewChatPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
async def mock_next(ctx: ChatContext) -> None:
async def mock_next() -> None:
result = MagicMock()
result.messages = [Message(role="assistant", text="Response")]
ctx.result = result
context.result = result
await middleware.process(context, mock_next)
@@ -55,10 +55,10 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
next_called = False
async def mock_next(ctx: AgentContext) -> None:
async def mock_next() -> None:
nonlocal next_called
next_called = True
ctx.result = AgentResponse(messages=[Message(role="assistant", text="I'm good, thanks!")])
context.result = AgentResponse(messages=[Message(role="assistant", text="I'm good, thanks!")])
await middleware.process(context, mock_next)
@@ -74,7 +74,7 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):
next_called = False
async def mock_next(ctx: AgentContext) -> None:
async def mock_next() -> None:
nonlocal next_called
next_called = True
@@ -101,8 +101,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(
async def mock_next() -> None:
context.result = AgentResponse(
messages=[Message(role="assistant", text="Here's some sensitive information")]
)
@@ -125,8 +125,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
async def mock_next(ctx: AgentContext) -> None:
ctx.result = "Some non-standard result"
async def mock_next() -> None:
context.result = "Some non-standard result"
await middleware.process(context, mock_next)
@@ -142,8 +142,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process:
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
async def mock_next() -> None:
context.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
await middleware.process(context, mock_next)
@@ -160,8 +160,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(messages=[Message(role="assistant", text="streaming")])
async def mock_next() -> None:
context.result = AgentResponse(messages=[Message(role="assistant", text="streaming")])
await middleware.process(context, mock_next)
@@ -181,7 +181,7 @@ class TestPurviewPolicyMiddleware:
side_effect=PurviewPaymentRequiredError("Payment required"),
):
async def mock_next(_: AgentContext) -> None:
async def mock_next() -> None:
raise AssertionError("next should not be called")
with pytest.raises(PurviewPaymentRequiredError):
@@ -206,8 +206,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(messages=[Message(role="assistant", text="OK")])
async def mock_next() -> None:
context.result = AgentResponse(messages=[Message(role="assistant", text="OK")])
with pytest.raises(PurviewPaymentRequiredError):
await middleware.process(context, mock_next)
@@ -231,8 +231,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(messages=[Message(role="assistant", text="OK")])
async def mock_next() -> None:
context.result = AgentResponse(messages=[Message(role="assistant", text="OK")])
with pytest.raises(ValueError, match="Post-check blew up"):
await middleware.process(context, mock_next)
@@ -250,8 +250,8 @@ class TestPurviewPolicyMiddleware:
middleware._processor, "process_messages", side_effect=Exception("Pre-check error")
) as mock_process:
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
async def mock_next() -> None:
context.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
await middleware.process(context, mock_next)
@@ -280,8 +280,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
async def mock_next() -> None:
context.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
await middleware.process(context, mock_next)
@@ -306,8 +306,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx):
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
async def mock_next():
context.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
# Should not raise, just log
await middleware.process(context, mock_next)
@@ -330,7 +330,7 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx):
async def mock_next():
pass
# Should raise the exception
@@ -346,8 +346,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
async def mock_next() -> None:
context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
await middleware.process(context, mock_next)
@@ -364,8 +364,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
async def mock_next() -> None:
context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
await middleware.process(context, mock_next)
@@ -383,8 +383,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
async def mock_next() -> None:
context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
await middleware.process(context, mock_next)
@@ -399,8 +399,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
async def mock_next() -> None:
context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")])
await middleware.process(context, mock_next)
@@ -416,8 +416,8 @@ class TestPurviewPolicyMiddleware:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
async def mock_next(ctx: AgentContext) -> None:
ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
async def mock_next() -> None:
context.result = AgentResponse(messages=[Message(role="assistant", text="Response")])
await middleware.process(context, mock_next)
+2 -1
View File
@@ -5,7 +5,8 @@
"**/autogen-migration/**",
"**/semantic-kernel-migration/**",
"**/demos/**",
"**/agent_with_foundry_tracing.py"
"**/agent_with_foundry_tracing.py",
"**/azure_responses_client_with_foundry.py"
],
"typeCheckingMode": "off",
"reportMissingImports": "error",
+1
View File
@@ -78,6 +78,7 @@ This directory contains samples demonstrating the capabilities of Microsoft Agen
| [`getting_started/agents/azure_openai/azure_responses_client_image_analysis.py`](./getting_started/agents/azure_openai/azure_responses_client_image_analysis.py) | Azure OpenAI Responses Client with Image Analysis Example |
| [`getting_started/agents/azure_openai/azure_responses_client_with_code_interpreter.py`](./getting_started/agents/azure_openai/azure_responses_client_with_code_interpreter.py) | Azure OpenAI Responses Client with Code Interpreter Example |
| [`getting_started/agents/azure_openai/azure_responses_client_with_explicit_settings.py`](./getting_started/agents/azure_openai/azure_responses_client_with_explicit_settings.py) | Azure OpenAI Responses Client with Explicit Settings Example |
| [`getting_started/agents/azure_openai/azure_responses_client_with_foundry.py`](./getting_started/agents/azure_openai/azure_responses_client_with_foundry.py) | Azure OpenAI Responses Client with Foundry Project Example |
| [`getting_started/agents/azure_openai/azure_responses_client_with_function_tools.py`](./getting_started/agents/azure_openai/azure_responses_client_with_function_tools.py) | Azure OpenAI Responses Client with Function Tools Example |
| [`getting_started/agents/azure_openai/azure_responses_client_with_hosted_mcp.py`](./getting_started/agents/azure_openai/azure_responses_client_with_hosted_mcp.py) | Azure OpenAI Responses Client with Hosted Model Context Protocol (MCP) Example |
| [`getting_started/agents/azure_openai/azure_responses_client_with_local_mcp.py`](./getting_started/agents/azure_openai/azure_responses_client_with_local_mcp.py) | Azure OpenAI Responses Client with local Model Context Protocol (MCP) Example |
+8 -8
View File
@@ -267,7 +267,7 @@ class TerminatingMiddleware(FunctionMiddleware):
if self.should_terminate(context):
context.result = "terminated by middleware"
raise MiddlewareTermination # Exit function invocation loop
await call_next(context)
await call_next()
```
## Arguments Added/Altered at Each Layer
@@ -347,7 +347,7 @@ class CachingMiddleware(FunctionMiddleware):
return # Upstream post-processing still runs
# Option B: Call call_next, then return normally
await call_next(context)
await call_next()
self.cache[context.function.name] = context.result
return # Normal completion
```
@@ -362,7 +362,7 @@ class BlockedFunctionMiddleware(FunctionMiddleware):
if context.function.name in self.blocked_functions:
context.result = "Function blocked by policy"
raise MiddlewareTermination("Blocked") # Skips ALL post-processing
await call_next(context)
await call_next()
```
### 3. Raise Any Other Exception
@@ -374,7 +374,7 @@ class ValidationMiddleware(FunctionMiddleware):
async def process(self, context: FunctionInvocationContext, call_next):
if not self.is_valid(context.arguments):
raise ValueError("Invalid arguments") # Bubbles up to user
await call_next(context)
await call_next()
```
## `return` vs `raise MiddlewareTermination`
@@ -385,7 +385,7 @@ The key difference is what happens to **upstream middleware's post-processing**:
class MiddlewareA(AgentMiddleware):
async def process(self, context, call_next):
print("A: before")
await call_next(context)
await call_next()
print("A: after") # Does this run?
class MiddlewareB(AgentMiddleware):
@@ -410,7 +410,7 @@ With middleware registered as `[MiddlewareA, MiddlewareB]`:
## Calling `call_next()` or Not
The decision to call `call_next(context)` determines whether downstream middleware and the actual operation execute:
The decision to call `call_next()` determines whether downstream middleware and the actual operation execute:
### Without calling `call_next()` - Skip downstream
@@ -430,7 +430,7 @@ async def process(self, context, call_next):
```python
async def process(self, context, call_next):
# Pre-processing
await call_next(context) # Execute downstream + actual operation
await call_next() # Execute downstream + actual operation
# Post-processing (context.result now contains real result)
return
```
@@ -450,7 +450,7 @@ async def process(self, context, call_next):
| `raise MiddlewareTermination` | Yes | ✅ | ✅ | ❌ No |
| `raise OtherException` | Either | Depends | Depends | ❌ No (exception propagates) |
> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next(context)` without an explicit `return` statement achieves this pattern.
> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next()` without an explicit `return` statement achieves this pattern.
## Streaming vs Non-Streaming
@@ -20,13 +20,13 @@ multiple specialized agents, each focusing on specific tasks.
async def logging_middleware(
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Middleware that logs tool invocations to show the delegation flow."""
print(f"[Calling tool: {context.function.name}]")
print(f"[Request: {context.arguments}]")
await call_next(context)
await call_next()
print(f"[Response: {context.result}]")
@@ -22,6 +22,7 @@ This folder contains examples demonstrating different ways to create and use age
| [`azure_responses_client_with_code_interpreter.py`](azure_responses_client_with_code_interpreter.py) | Shows how to use `AzureOpenAIResponsesClient.get_code_interpreter_tool()` with Azure agents to write and execute Python code. Includes helper methods for accessing code interpreter data from response chunks. |
| [`azure_responses_client_with_explicit_settings.py`](azure_responses_client_with_explicit_settings.py) | Shows how to initialize an agent with a specific responses client, configuring settings explicitly including endpoint and deployment name. |
| [`azure_responses_client_with_file_search.py`](azure_responses_client_with_file_search.py) | Demonstrates using `AzureOpenAIResponsesClient.get_file_search_tool()` with Azure OpenAI Responses Client for direct document-based question answering and information retrieval from vector stores. |
| [`azure_responses_client_with_foundry.py`](azure_responses_client_with_foundry.py) | Shows how to create an agent using an Azure AI Foundry project endpoint instead of a direct Azure OpenAI endpoint. Requires the `azure-ai-projects` package. |
| [`azure_responses_client_with_function_tools.py`](azure_responses_client_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). |
| [`azure_responses_client_with_hosted_mcp.py`](azure_responses_client_with_hosted_mcp.py) | Shows how to integrate Azure OpenAI Responses Client with hosted Model Context Protocol (MCP) servers using `AzureOpenAIResponsesClient.get_mcp_tool()` for extended functionality. |
| [`azure_responses_client_with_local_mcp.py`](azure_responses_client_with_local_mcp.py) | Shows how to integrate Azure OpenAI Responses Client with local Model Context Protocol (MCP) servers using MCPStreamableHTTPTool for extended functionality. |
@@ -35,6 +36,9 @@ Make sure to set the following environment variables before running the examples
- `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME`: The name of your Azure OpenAI chat model deployment
- `AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME`: The name of your Azure OpenAI Responses deployment
For the Foundry project sample (`azure_responses_client_with_foundry.py`), also set:
- `AZURE_AI_PROJECT_ENDPOINT`: Your Azure AI Foundry project endpoint
Optionally, you can set:
- `AZURE_OPENAI_API_VERSION`: The API version to use (default is `2024-02-15-preview`)
- `AZURE_OPENAI_API_KEY`: Your Azure OpenAI API key (if not using `AzureCliCredential`)
@@ -0,0 +1,113 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import os
from random import randint
from typing import Annotated
from agent_framework import tool
from agent_framework.azure import AzureOpenAIResponsesClient
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
"""
Azure OpenAI Responses Client with Foundry Project Example
This sample demonstrates how to create an AzureOpenAIResponsesClient using an
Azure AI Foundry project endpoint. Instead of providing an Azure OpenAI endpoint
directly, you provide a Foundry project endpoint and the client is created via
the Azure AI Foundry project SDK.
This requires:
- The `azure-ai-projects` package to be installed.
- The `AZURE_AI_PROJECT_ENDPOINT` environment variable set to your Foundry project endpoint.
- The `AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME` environment variable set to the model deployment name.
"""
load_dotenv() # Load environment variables from .env file if present
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/getting_started/tools/function_tool_with_approval.py and samples/getting_started/tools/function_tool_with_approval_and_threads.py.
@tool(approval_mode="never_require")
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    # The report is fabricated locally for demo purposes; no weather service is called.
    conditions = ["sunny", "cloudy", "rainy", "stormy"]
    # Draw the condition first and the temperature second so the sequence of
    # random draws stays the same as a single left-to-right f-string evaluation.
    condition = conditions[randint(0, len(conditions) - 1)]
    high = randint(10, 30)
    return f"The weather in {location} is {condition} with a high of {high}°C."
async def non_streaming_example() -> None:
    """Ask one weather question and print the complete reply once it is ready."""
    print("=== Non-streaming Response Example ===")

    # Step 1: build the client from a Foundry project endpoint rather than a
    # direct Azure OpenAI endpoint. Authentication goes through the Azure CLI
    # (run `az login` first) - swap in another credential type if preferred.
    cli_credential = AzureCliCredential()
    project_endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"]
    deployment_name = os.environ["AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME"]
    client = AzureOpenAIResponsesClient(
        project_endpoint=project_endpoint,
        deployment_name=deployment_name,
        credential=cli_credential,
    )
    weather_agent = client.as_agent(
        instructions="You are a helpful weather agent.",
        tools=get_weather,
    )

    # Step 2: send the question and print the finished result in one piece.
    query = "What's the weather like in Seattle?"
    print(f"User: {query}")
    response = await weather_agent.run(query)
    print(f"Result: {response}\n")
async def streaming_example() -> None:
    """Ask one weather question and print the reply chunk by chunk as it streams in."""
    print("=== Streaming Response Example ===")

    # Step 1: build the client from a Foundry project endpoint rather than a
    # direct Azure OpenAI endpoint. Authentication goes through the Azure CLI
    # (run `az login` first) - swap in another credential type if preferred.
    cli_credential = AzureCliCredential()
    project_endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"]
    deployment_name = os.environ["AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME"]
    client = AzureOpenAIResponsesClient(
        project_endpoint=project_endpoint,
        deployment_name=deployment_name,
        credential=cli_credential,
    )
    weather_agent = client.as_agent(
        instructions="You are a helpful weather agent.",
        tools=get_weather,
    )

    # Step 2: stream the answer, echoing each text fragment without a newline
    # so the fragments join into one continuous line on the console.
    query = "What's the weather like in Portland?"
    print(f"User: {query}")
    print("Agent: ", end="", flush=True)
    async for update in weather_agent.run(query, stream=True):
        if not update.text:
            # Skip updates that carry no text.
            continue
        print(update.text, end="", flush=True)
    print("\n")
async def main() -> None:
    """Print the sample banner, then run the non-streaming and streaming demos in order."""
    print("=== Azure OpenAI Responses Client with Foundry Project Example ===")

    # Run sequentially so the console output of the two demos does not interleave.
    for run_example in (non_streaming_example, streaming_example):
        await run_example()


# Script entry point: start the event loop only when executed directly.
if __name__ == "__main__":
    asyncio.run(main())
"""
Sample output:
=== Azure OpenAI Responses Client with Foundry Project Example ===
=== Non-streaming Response Example ===
User: What's the weather like in Seattle?
Result: The weather in Seattle is cloudy with a high of 18°C.
=== Streaming Response Example ===
User: What's the weather like in Portland?
Agent: The weather in Portland is sunny with a high of 25°C.
"""
@@ -29,7 +29,7 @@ response generation, showing both streaming and non-streaming responses.
@chat_middleware
async def security_and_override_middleware(
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Function-based middleware that implements security filtering and response override."""
print("[SecurityMiddleware] Processing input...")
@@ -60,7 +60,7 @@ async def security_and_override_middleware(
raise MiddlewareTermination(result=context.result)
# Continue to next middleware or AI execution
await call_next(context)
await call_next()
print("[SecurityMiddleware] Response generated.")
print(type(context.result))
@@ -19,13 +19,13 @@ multiple specialized agents, each focusing on specific tasks.
async def logging_middleware(
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Middleware that logs tool invocations to show the delegation flow."""
print(f"[Calling tool: {context.function.name}]")
print(f"[Request: {context.arguments}]")
await call_next(context)
await call_next()
print(f"[Response: {context.result}]")
@@ -38,7 +38,7 @@ def cleanup_resources():
@chat_middleware
async def security_filter_middleware(
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Chat middleware that blocks requests containing sensitive information."""
blocked_terms = ["password", "secret", "api_key", "token"]
@@ -80,13 +80,13 @@ async def security_filter_middleware(
raise MiddlewareTermination(result=context.result)
await call_next(context)
await call_next()
@function_middleware
async def atlantis_location_filter_middleware(
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Function middleware that blocks weather requests for Atlantis."""
# Check if location parameter is "atlantis"
@@ -98,7 +98,7 @@ async def atlantis_location_filter_middleware(
)
raise MiddlewareTermination(result=context.result)
await call_next(context)
await call_next()
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/getting_started/tools/function_tool_with_approval.py and samples/getting_started/tools/function_tool_with_approval_and_threads.py.
@@ -68,7 +68,7 @@ def get_weather(
class SecurityAgentMiddleware(AgentMiddleware):
"""Agent-level security middleware that validates all requests."""
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
print("[SecurityMiddleware] Checking security for all requests...")
# Check for security violations in the last user message
@@ -81,18 +81,18 @@ class SecurityAgentMiddleware(AgentMiddleware):
print("[SecurityMiddleware] Security check passed.")
context.metadata["security_validated"] = True
await call_next(context)
await call_next()
async def performance_monitor_middleware(
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Agent-level performance monitoring for all runs."""
print("[PerformanceMonitor] Starting performance monitoring...")
start_time = time.time()
await call_next(context)
await call_next()
end_time = time.time()
duration = end_time - start_time
@@ -104,7 +104,7 @@ async def performance_monitor_middleware(
class HighPriorityMiddleware(AgentMiddleware):
"""Run-level middleware for high priority requests."""
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
print("[HighPriority] Processing high priority request with expedited handling...")
# Read metadata set by agent-level middleware
@@ -115,13 +115,13 @@ class HighPriorityMiddleware(AgentMiddleware):
context.metadata["priority"] = "high"
context.metadata["expedited"] = True
await call_next(context)
await call_next()
print("[HighPriority] High priority processing completed")
async def debugging_middleware(
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Run-level debugging middleware for troubleshooting specific runs."""
print("[Debug] Debug mode enabled for this run")
@@ -134,7 +134,7 @@ async def debugging_middleware(
context.metadata["debug_enabled"] = True
await call_next(context)
await call_next()
print("[Debug] Debug information collected")
@@ -145,7 +145,7 @@ class CachingMiddleware(AgentMiddleware):
def __init__(self) -> None:
self.cache: dict[str, AgentResponse] = {}
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Create a simple cache key from the last message
last_message = context.messages[-1] if context.messages else None
cache_key: str = last_message.text if last_message and last_message.text else "no_message"
@@ -158,7 +158,7 @@ class CachingMiddleware(AgentMiddleware):
print(f"[Cache] Cache MISS for: '{cache_key[:30]}...'")
context.metadata["cache_key"] = cache_key
await call_next(context)
await call_next()
# Cache the result if we have one
if context.result:
@@ -168,14 +168,14 @@ class CachingMiddleware(AgentMiddleware):
async def function_logging_middleware(
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Function middleware that logs all function calls."""
function_name = context.function.name
args = context.arguments
print(f"[FunctionLog] Calling function: {function_name} with args: {args}")
await call_next(context)
await call_next()
print(f"[FunctionLog] Function {function_name} completed")
@@ -275,7 +275,7 @@ async def main() -> None:
query = "What's the secret weather password for Berlin?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'Request was blocked by security middleware'}")
print(f"Agent: {result.text if result and result.text else 'Request was blocked by security middleware'}")
print()
# Run 7: Normal query again (no run-level middleware interference)
@@ -57,7 +57,7 @@ class InputObserverMiddleware(ChatMiddleware):
async def process(
self,
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Observe and modify input messages before they are sent to AI."""
print("[InputObserverMiddleware] Observing input messages:")
@@ -91,7 +91,7 @@ class InputObserverMiddleware(ChatMiddleware):
context.messages[:] = modified_messages
# Continue to next middleware or AI execution
await call_next(context)
await call_next()
# Observe that processing is complete
print("[InputObserverMiddleware] Processing completed")
@@ -100,7 +100,7 @@ class InputObserverMiddleware(ChatMiddleware):
@chat_middleware
async def security_and_override_middleware(
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Function-based middleware that implements security filtering and response override."""
print("[SecurityMiddleware] Processing input...")
@@ -131,7 +131,7 @@ async def security_and_override_middleware(
raise MiddlewareTermination
# Continue to next middleware or AI execution
await call_next(context)
await call_next()
async def class_based_chat_middleware() -> None:
@@ -50,7 +50,7 @@ class SecurityAgentMiddleware(AgentMiddleware):
async def process(
self,
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
# Check for potential security violations in the query
# Look at the last user message
@@ -67,7 +67,7 @@ class SecurityAgentMiddleware(AgentMiddleware):
return
print("[SecurityAgentMiddleware] Security check passed.")
await call_next(context)
await call_next()
class LoggingFunctionMiddleware(FunctionMiddleware):
@@ -76,14 +76,14 @@ class LoggingFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
function_name = context.function.name
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
start_time = time.time()
await call_next(context)
await call_next()
end_time = time.time()
duration = end_time - start_time
@@ -53,7 +53,7 @@ def get_current_time() -> str:
async def simple_agent_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
"""Agent middleware that runs before and after agent execution."""
print("[Agent MiddlewareTypes] Before agent execution")
await call_next(context)
await call_next()
print("[Agent MiddlewareTypes] After agent execution")
@@ -61,7 +61,7 @@ async def simple_agent_middleware(context, call_next): # type: ignore - paramet
async def simple_function_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
"""Function middleware that runs before and after function calls."""
print(f"[Function MiddlewareTypes] Before calling: {context.function.name}") # type: ignore
await call_next(context)
await call_next()
print(f"[Function MiddlewareTypes] After calling: {context.function.name}") # type: ignore
@@ -35,13 +35,13 @@ def unstable_data_service(
async def exception_handling_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
function_name = context.function.name
try:
print(f"[ExceptionHandlingMiddleware] Executing function: {function_name}")
await call_next(context)
await call_next()
print(f"[ExceptionHandlingMiddleware] Function {function_name} completed successfully.")
except TimeoutError as e:
print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}")
@@ -43,7 +43,7 @@ def get_weather(
async def security_agent_middleware(
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Agent middleware that checks for security violations."""
# Check for potential security violations in the query
@@ -57,12 +57,12 @@ async def security_agent_middleware(
return
print("[SecurityAgentMiddleware] Security check passed.")
await call_next(context)
await call_next()
async def logging_function_middleware(
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Function middleware that logs function calls."""
function_name = context.function.name
@@ -70,7 +70,7 @@ async def logging_function_middleware(
start_time = time.time()
await call_next(context)
await call_next()
end_time = time.time()
duration = end_time - start_time
@@ -105,7 +105,7 @@ async def main() -> None:
query = "What's the secret weather password?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response'}\n")
print(f"Agent: {result.text if result and result.text else 'No response'}\n")
if __name__ == "__main__":
@@ -49,7 +49,7 @@ class PreTerminationMiddleware(AgentMiddleware):
async def process(
self,
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
# Check if the user message contains any blocked words
last_message = context.messages[-1] if context.messages else None
@@ -75,7 +75,7 @@ class PreTerminationMiddleware(AgentMiddleware):
# Terminate to prevent further processing
raise MiddlewareTermination(result=context.result)
await call_next(context)
await call_next()
class PostTerminationMiddleware(AgentMiddleware):
@@ -88,7 +88,7 @@ class PostTerminationMiddleware(AgentMiddleware):
async def process(
self,
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})")
@@ -101,7 +101,7 @@ class PostTerminationMiddleware(AgentMiddleware):
raise MiddlewareTermination
# Allow the agent to process normally
await call_next(context)
await call_next()
# Increment response count after processing
self.response_count += 1
@@ -158,14 +158,14 @@ async def post_termination_middleware() -> None:
query = "What about the weather in London?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response (terminated)'}")
print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}")
# Third run (should also be terminated)
print("\n3. Third run (should also be terminated):")
query = "And New York?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response (terminated)'}")
print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}")
async def main() -> None:
@@ -49,11 +49,11 @@ def get_weather(
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
async def weather_override_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def weather_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
"""Chat middleware that overrides weather results for both streaming and non-streaming cases."""
# Let the original agent execution complete first
await call_next(context)
await call_next()
# Check if there's a result to override (agent called weather function)
if context.result is not None:
@@ -84,9 +84,9 @@ async def weather_override_middleware(context: ChatContext, call_next: Callable[
context.result = ChatResponse(messages=[Message(role=Role.ASSISTANT, text=custom_message)])
async def validate_weather_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def validate_weather_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
"""Chat middleware that simulates result validation for both streaming and non-streaming cases."""
await call_next(context)
await call_next()
validation_note = "Validation: weather data verified."
@@ -104,9 +104,9 @@ async def validate_weather_middleware(context: ChatContext, call_next: Callable[
context.result.messages.append(Message(role=Role.ASSISTANT, text=validation_note))
async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
"""Agent middleware that validates chat middleware effects and cleans the result."""
await call_next(context)
await call_next()
if context.result is None:
return
@@ -54,7 +54,7 @@ class SessionContextContainer:
async def inject_context_middleware(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""MiddlewareTypes that extracts runtime context from kwargs and stores in container.
@@ -74,7 +74,7 @@ class SessionContextContainer:
print(f" - Session Metadata Keys: {list(self.session_metadata.keys())}")
# Continue to tool execution
await call_next(context)
await call_next()
# Create a container instance that will be shared via closure
@@ -278,19 +278,19 @@ async def pattern_2_hierarchical_with_kwargs_propagation() -> None:
@function_middleware
async def email_kwargs_tracker(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
email_agent_kwargs.update(context.kwargs)
print(f"[EmailAgent] Received runtime context: {list(context.kwargs.keys())}")
await call_next(context)
await call_next()
@function_middleware
async def sms_kwargs_tracker(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
sms_agent_kwargs.update(context.kwargs)
print(f"[SMSAgent] Received runtime context: {list(context.kwargs.keys())}")
await call_next(context)
await call_next()
client = OpenAIChatClient(model_id="gpt-4o-mini")
@@ -359,7 +359,7 @@ class AuthContextMiddleware:
self.validated_tokens: list[str] = []
async def validate_and_track(
self, context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
self, context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
"""Validate API token and track usage."""
api_token = context.kwargs.get("api_token")
@@ -375,7 +375,7 @@ class AuthContextMiddleware:
else:
print("[AuthMiddleware] No API token provided")
await call_next(context)
await call_next()
@tool(approval_mode="never_require")
@@ -57,7 +57,7 @@ class MiddlewareContainer:
async def call_counter_middleware(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""First middleware: increments call count in shared state."""
# Increment the shared call count
@@ -66,18 +66,18 @@ class MiddlewareContainer:
print(f"[CallCounter] This is function call #{self.call_count}")
# Call the next middleware/function
await call_next(context)
await call_next()
async def result_enhancer_middleware(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Second middleware: uses shared call count to enhance function results."""
print(f"[ResultEnhancer] Current total calls so far: {self.call_count}")
# Call the next middleware/function
await call_next(context)
await call_next()
# After function execution, enhance the result using shared state
if context.result:
@@ -46,7 +46,7 @@ def get_weather(
async def thread_tracking_middleware(
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""MiddlewareTypes that tracks and logs thread behavior across runs."""
thread_messages = []
@@ -57,7 +57,7 @@ async def thread_tracking_middleware(
print(f"[MiddlewareTypes pre-execution] Thread history messages: {len(thread_messages)}")
# Call call_next to execute the agent
await call_next(context)
await call_next()
# Check thread state after agent execution
updated_thread_messages = []
+2 -2
View File
@@ -209,7 +209,6 @@ dependencies = [
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "azure-ai-agents", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "azure-ai-projects", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
]
[package.metadata]
@@ -217,7 +216,6 @@ requires-dist = [
{ name = "agent-framework-core", editable = "packages/core" },
{ name = "aiohttp" },
{ name = "azure-ai-agents", specifier = "==1.2.0b5" },
{ name = "azure-ai-projects", specifier = ">=2.0.0b3" },
]
[[package]]
@@ -324,6 +322,7 @@ name = "agent-framework-core"
version = "1.0.0b260210"
source = { editable = "packages/core" }
dependencies = [
{ name = "azure-ai-projects", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "azure-identity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "mcp", extra = ["ws"], marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
@@ -378,6 +377,7 @@ requires-dist = [
{ name = "agent-framework-orchestrations", marker = "extra == 'all'", editable = "packages/orchestrations" },
{ name = "agent-framework-purview", marker = "extra == 'all'", editable = "packages/purview" },
{ name = "agent-framework-redis", marker = "extra == 'all'", editable = "packages/redis" },
{ name = "azure-ai-projects", specifier = ">=2.0.0b3" },
{ name = "azure-identity", specifier = ">=1,<2" },
{ name = "mcp", extras = ["ws"], specifier = ">=1.24.0,<2" },
{ name = "openai", specifier = ">=1.99.0" },