diff --git a/python/packages/main/agent_framework/_middleware.py b/python/packages/main/agent_framework/_middleware.py index 7150881f6b..ba157dcc28 100644 --- a/python/packages/main/agent_framework/_middleware.py +++ b/python/packages/main/agent_framework/_middleware.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterable, Awaitable, Callable from dataclasses import dataclass, field +from enum import Enum from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage @@ -15,12 +16,22 @@ if TYPE_CHECKING: TAgent = TypeVar("TAgent", bound="AgentProtocol") + +class MiddlewareType(Enum): + """Enum representing the type of middleware.""" + + AGENT = "agent" + FUNCTION = "function" + + __all__ = [ "AgentMiddleware", "AgentRunContext", "FunctionInvocationContext", "FunctionMiddleware", "Middleware", + "agent_middleware", + "function_middleware", "use_agent_middleware", ] @@ -38,6 +49,8 @@ class AgentRunContext: to see the actual execution result or can be set to override the execution result. For non-streaming: should be AgentRunResponse For streaming: should be AsyncIterable[AgentRunResponseUpdate] + terminate: A flag indicating whether to terminate execution after current middleware. + When set to True, execution will stop as soon as control returns to framework. """ agent: "AgentProtocol" @@ -45,6 +58,7 @@ class AgentRunContext: is_streaming: bool = False metadata: dict[str, Any] = field(default_factory=lambda: {}) result: AgentRunResponse | AsyncIterable[AgentRunResponseUpdate] | None = None + terminate: bool = False @dataclass @@ -57,12 +71,15 @@ class FunctionInvocationContext: metadata: Metadata dictionary for sharing data between function middleware. result: Function execution result. Can be observed after calling next() to see the actual execution result or can be set to override the execution result. 
+ terminate: A flag indicating whether to terminate execution after current middleware. + When set to True, execution will stop as soon as control returns to framework. """ function: "AIFunction[Any, Any]" arguments: "BaseModel" metadata: dict[str, Any] = field(default_factory=lambda: {}) result: Any = None + terminate: bool = False class AgentMiddleware(ABC): @@ -131,6 +148,53 @@ FunctionMiddlewareCallable = Callable[ Middleware: TypeAlias = AgentMiddleware | AgentMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable +# Middleware type markers for decorators +def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable: + """Decorator to mark a function as agent middleware. + + This decorator explicitly identifies a function as agent middleware, + which processes AgentRunContext objects. + + Args: + func: The middleware function to mark as agent middleware. + + Returns: + The same function with agent middleware marker. + + Example: + @agent_middleware + async def my_middleware(context: AgentRunContext, next): + # Process agent invocation + await next(context) + """ + # Add marker attribute to identify this as agent middleware + func._middleware_type: MiddlewareType = MiddlewareType.AGENT # type: ignore + return func + + +def function_middleware(func: FunctionMiddlewareCallable) -> FunctionMiddlewareCallable: + """Decorator to mark a function as function middleware. + + This decorator explicitly identifies a function as function middleware, + which processes FunctionInvocationContext objects. + + Args: + func: The middleware function to mark as function middleware. + + Returns: + The same function with function middleware marker. 
+ + Example: + @function_middleware + async def my_middleware(context: FunctionInvocationContext, next): + # Process function invocation + await next(context) + """ + # Add marker attribute to identify this as function middleware + func._middleware_type: MiddlewareType = MiddlewareType.FUNCTION # type: ignore + return func + + class AgentMiddlewareWrapper(AgentMiddleware): """Wrapper to convert pure functions into AgentMiddleware protocol objects.""" @@ -271,6 +335,10 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): if index >= len(self._middlewares): async def final_wrapper(c: AgentRunContext) -> None: + # If terminate was set, skip execution + if c.terminate: + return + # Execute actual handler and populate context for observability result = await final_handler(c) result_container["result"] = result @@ -282,6 +350,10 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): next_handler = create_next_handler(index + 1) async def current_handler(c: AgentRunContext) -> None: + # If terminate is set, don't continue the pipeline + if c.terminate: + return + await middleware.process(c, next_handler) # After middleware execution, check if response was overridden if c.result is not None and isinstance(c.result, AgentRunResponse): @@ -337,6 +409,10 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): if index >= len(self._middlewares): async def final_wrapper(c: AgentRunContext) -> None: # noqa: RUF029 + # If terminate was set, skip execution + if c.terminate: + return + # Execute actual handler and populate context for observability result = final_handler(c) result_container["result_stream"] = result @@ -349,6 +425,9 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): async def current_handler(c: AgentRunContext) -> None: await middleware.process(c, next_handler) + # If terminate is set, don't continue the pipeline + if c.terminate: + return return current_handler @@ -424,8 +503,8 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline): # Custom 
final handler that handles pre-existing results async def function_final_handler(c: FunctionInvocationContext) -> Any: - # If result was set before calling next(), skip execution - if c.result is not None: + # If terminate was set, skip execution and return the result (which might be None) + if c.terminate: return c.result # Execute actual handler and populate context for observability return await final_handler(c) @@ -446,6 +525,10 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: This decorator adds middleware functionality to any agent class. It wraps the run() and run_stream() methods to provide middleware execution. + The middleware execution can be terminated at any point by setting the + context.terminate property to True. Once set, the pipeline will stop executing + further middleware as soon as control returns to the pipeline. + Args: agent_class: The agent class to add middleware support to. @@ -458,12 +541,101 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: original_run = agent_class.run # type: ignore[attr-defined] original_run_stream = agent_class.run_stream # type: ignore[attr-defined] - def _initialize_middleware_pipelines(self: Any, middlewares: Middleware | list[Middleware] | None) -> None: - """Initialize agent and function middleware pipelines from the provided middleware list.""" - if not middlewares: - return + def _determine_middleware_type(middleware: Any) -> MiddlewareType: + """Determine middleware type using decorator and/or parameter type annotation. - middleware_list: list[Middleware] = middlewares if isinstance(middlewares, list) else [middlewares] # type: ignore + Args: + middleware: The middleware function to analyze. + + Returns: + MiddlewareType.AGENT or MiddlewareType.FUNCTION indicating the middleware type. + + Raises: + ValueError: When middleware type cannot be determined or there's a mismatch. 
+ """ + # Check for decorator marker + decorator_type: MiddlewareType | None = getattr(middleware, "_middleware_type", None) + + # Check for parameter type annotation + param_type: MiddlewareType | None = None + try: + sig = inspect.signature(middleware) + params = list(sig.parameters.values()) + + # Must have at least 2 parameters (context and next) + if len(params) >= 2: + first_param = params[0] + if hasattr(first_param.annotation, "__name__"): + annotation_name = first_param.annotation.__name__ + if annotation_name == "AgentRunContext": + param_type = MiddlewareType.AGENT + elif annotation_name == "FunctionInvocationContext": + param_type = MiddlewareType.FUNCTION + else: + # Not enough parameters - can't be valid middleware + raise ValueError( + f"Middleware function must have at least 2 parameters (context, next), " + f"but {middleware.__name__} has {len(params)}" + ) + except Exception as e: + if isinstance(e, ValueError): + raise # Re-raise our custom errors + # Signature inspection failed - continue with other checks + pass + + if decorator_type and param_type: + # Both decorator and parameter type specified - they must match + if decorator_type != param_type: + raise ValueError( + f"Middleware type mismatch: decorator indicates '{decorator_type.value}' " + f"but parameter type indicates '{param_type.value}' for function {middleware.__name__}" + ) + return decorator_type + + if decorator_type: + # Just decorator specified - rely on decorator + return decorator_type + + if param_type: + # Just parameter type specified - rely on types + return param_type + + # Neither decorator nor parameter type specified - throw exception + raise ValueError( + f"Cannot determine middleware type for function {middleware.__name__}. " + f"Please either use @agent_middleware/@function_middleware decorators " + f"or specify parameter types (AgentRunContext or FunctionInvocationContext)." 
+ ) + + def _build_middleware_pipelines( + agent_level_middlewares: Middleware | list[Middleware] | None, + run_level_middlewares: Middleware | list[Middleware] | None = None, + ) -> tuple[AgentMiddlewarePipeline, FunctionMiddlewarePipeline]: + """Build fresh agent and function middleware pipelines from the provided middleware lists. + + Args: + agent_level_middlewares: Agent-level middleware (executed first) + run_level_middlewares: Run-level middleware (executed after agent middleware) + """ + # Merge middleware lists: agent middleware first, then run middleware + combined_middlewares: list[Middleware] = [] + + if agent_level_middlewares: + if isinstance(agent_level_middlewares, list): + combined_middlewares.extend(agent_level_middlewares) # type: ignore[arg-type] + else: + combined_middlewares.append(agent_level_middlewares) + + if run_level_middlewares: + if isinstance(run_level_middlewares, list): + combined_middlewares.extend(run_level_middlewares) # type: ignore[arg-type] + else: + combined_middlewares.append(run_level_middlewares) + + if not combined_middlewares: + return AgentMiddlewarePipeline(), FunctionMiddlewarePipeline() + + middleware_list = combined_middlewares # Separate agent and function middleware using isinstance checks agent_middlewares: list[AgentMiddleware | AgentMiddlewareCallable] = [] @@ -475,75 +647,42 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: elif isinstance(middleware, FunctionMiddleware): function_middlewares.append(middleware) elif callable(middleware): # type: ignore[arg-type] - # Check function signature to determine type - try: - sig = inspect.signature(middleware) - params = list(sig.parameters.values()) - if len(params) >= 1: - first_param = params[0] - # Check if first parameter is AgentRunContext or FunctionInvocationContext - if ( - hasattr(first_param.annotation, "__name__") - and first_param.annotation.__name__ == "AgentRunContext" - ): - agent_middlewares.append(middleware) # type: ignore - 
elif ( - hasattr(first_param.annotation, "__name__") - and first_param.annotation.__name__ == "FunctionInvocationContext" - ): - function_middlewares.append(middleware) # type: ignore - else: - # Default to agent middleware if uncertain - agent_middlewares.append(middleware) # type: ignore - else: - agent_middlewares.append(middleware) # type: ignore - except Exception: - # If signature inspection fails, assume it's an agent middleware + # Determine middleware type using decorator and/or parameter type annotation + middleware_type = _determine_middleware_type(middleware) + if middleware_type == MiddlewareType.AGENT: agent_middlewares.append(middleware) # type: ignore + elif middleware_type == MiddlewareType.FUNCTION: + function_middlewares.append(middleware) # type: ignore + else: + # This should not happen if _determine_middleware_type is implemented correctly + raise ValueError(f"Unknown middleware type: {middleware_type}") else: # Fallback agent_middlewares.append(middleware) # type: ignore - self._agent_middleware_pipeline = AgentMiddlewarePipeline(agent_middlewares) - self._function_middleware_pipeline = FunctionMiddlewarePipeline(function_middlewares) + return AgentMiddlewarePipeline(agent_middlewares), FunctionMiddlewarePipeline(function_middlewares) async def middleware_enabled_run( self: Any, messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: Any = None, + middleware: Middleware | list[Middleware] | None = None, **kwargs: Any, ) -> AgentRunResponse: """Middleware-enabled run method.""" - # Initialize middleware pipelines if not already done - if ( - hasattr(self, "middleware") - and self.middleware - and not ( - hasattr(self, "_agent_middleware_pipeline") - and hasattr(self, "_function_middleware_pipeline") - and ( - self._agent_middleware_pipeline.has_middlewares - or self._function_middleware_pipeline.has_middlewares - ) - ) - ): - _initialize_middleware_pipelines(self, self.middleware) - - # Ensure pipelines exist even 
if empty - if not hasattr(self, "_agent_middleware_pipeline"): - self._agent_middleware_pipeline = AgentMiddlewarePipeline() - if not hasattr(self, "_function_middleware_pipeline"): - self._function_middleware_pipeline = FunctionMiddlewarePipeline() + # Build fresh middleware pipelines from current middleware collection and run-level middleware + agent_middleware = getattr(self, "middleware", None) + agent_pipeline, function_pipeline = _build_middleware_pipelines(agent_middleware, middleware) # Add function middleware pipeline to kwargs if available - if self._function_middleware_pipeline.has_middlewares: - kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline + if function_pipeline.has_middlewares: + kwargs["_function_middleware_pipeline"] = function_pipeline normalized_messages = self._normalize_messages(messages) # Execute with middleware if available - if self._agent_middleware_pipeline.has_middlewares: + if agent_pipeline.has_middlewares: context = AgentRunContext( agent=self, # type: ignore[arg-type] messages=normalized_messages, @@ -553,7 +692,7 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: async def _execute_handler(ctx: AgentRunContext) -> AgentRunResponse: return await original_run(self, ctx.messages, thread=thread, **kwargs) # type: ignore - result = await self._agent_middleware_pipeline.execute( + result = await agent_pipeline.execute( self, # type: ignore[arg-type] normalized_messages, context, @@ -570,38 +709,22 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: Any = None, + middleware: Middleware | list[Middleware] | None = None, **kwargs: Any, ) -> AsyncIterable[AgentRunResponseUpdate]: """Middleware-enabled run_stream method.""" - # Initialize middleware pipelines if not already done - if ( - hasattr(self, "middleware") - and self.middleware - and not ( - hasattr(self, 
"_agent_middleware_pipeline") - and hasattr(self, "_function_middleware_pipeline") - and ( - self._agent_middleware_pipeline.has_middlewares - or self._function_middleware_pipeline.has_middlewares - ) - ) - ): - _initialize_middleware_pipelines(self, self.middleware) - - # Ensure pipelines exist even if empty - if not hasattr(self, "_agent_middleware_pipeline"): - self._agent_middleware_pipeline = AgentMiddlewarePipeline() - if not hasattr(self, "_function_middleware_pipeline"): - self._function_middleware_pipeline = FunctionMiddlewarePipeline() + # Build fresh middleware pipelines from current middleware collection and run-level middleware + agent_middleware = getattr(self, "middleware", None) + agent_pipeline, function_pipeline = _build_middleware_pipelines(agent_middleware, middleware) # Add function middleware pipeline to kwargs if available - if self._function_middleware_pipeline.has_middlewares: - kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline + if function_pipeline.has_middlewares: + kwargs["_function_middleware_pipeline"] = function_pipeline normalized_messages = self._normalize_messages(messages) # Execute with middleware if available - if self._agent_middleware_pipeline.has_middlewares: + if agent_pipeline.has_middlewares: context = AgentRunContext( agent=self, # type: ignore[arg-type] messages=normalized_messages, @@ -613,7 +736,7 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: yield update async def _stream_generator() -> AsyncIterable[AgentRunResponseUpdate]: - async for update in self._agent_middleware_pipeline.execute_stream( + async for update in agent_pipeline.execute_stream( self, # type: ignore[arg-type] normalized_messages, context, diff --git a/python/packages/main/tests/main/test_middleware.py b/python/packages/main/tests/main/test_middleware.py index f0f8b536c1..a530c4c5e8 100644 --- a/python/packages/main/tests/main/test_middleware.py +++ 
b/python/packages/main/tests/main/test_middleware.py @@ -77,6 +77,16 @@ class TestFunctionInvocationContext: class TestAgentMiddlewarePipeline: """Test cases for AgentMiddlewarePipeline.""" + class PreNextTerminateMiddleware(AgentMiddleware): + async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + context.terminate = True + await next(context) + + class PostNextTerminateMiddleware(AgentMiddleware): + async def process(self, context: AgentRunContext, next: Any) -> None: + await next(context) + context.terminate = True + def test_init_empty(self) -> None: """Test AgentMiddlewarePipeline initialization with no middlewares.""" pipeline = AgentMiddlewarePipeline() @@ -194,10 +204,143 @@ class TestAgentMiddlewarePipeline: assert updates[1].text == "chunk2" assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"] + async def test_execute_with_pre_next_termination(self, mock_agent: AgentProtocol) -> None: + """Test pipeline execution with termination before next().""" + middleware = self.PreNextTerminateMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + execution_order: list[str] = [] + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + # Handler should not be executed when terminated before next() + execution_order.append("handler") + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + response = await pipeline.execute(mock_agent, messages, context, final_handler) + assert response is not None + assert context.terminate + # Handler should not be called when terminated before next() + assert execution_order == [] + assert not response.messages + + async def test_execute_with_post_next_termination(self, mock_agent: AgentProtocol) -> None: + """Test pipeline execution with termination after 
next().""" + middleware = self.PostNextTerminateMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + execution_order: list[str] = [] + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + execution_order.append("handler") + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + response = await pipeline.execute(mock_agent, messages, context, final_handler) + assert response is not None + assert len(response.messages) == 1 + assert response.messages[0].text == "response" + assert context.terminate + assert execution_order == ["handler"] + + async def test_execute_stream_with_pre_next_termination(self, mock_agent: AgentProtocol) -> None: + """Test pipeline streaming execution with termination before next().""" + middleware = self.PreNextTerminateMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + execution_order: list[str] = [] + + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + # Handler should not be executed when terminated before next() + execution_order.append("handler_start") + yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")]) + yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")]) + execution_order.append("handler_end") + + updates: list[AgentRunResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): + updates.append(update) + + assert context.terminate + # Handler should not be called when terminated before next() + assert execution_order == [] + assert not updates + + async def test_execute_stream_with_post_next_termination(self, mock_agent: AgentProtocol) -> None: + """Test pipeline streaming execution with 
termination after next().""" + middleware = self.PostNextTerminateMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + execution_order: list[str] = [] + + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + execution_order.append("handler_start") + yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")]) + yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")]) + execution_order.append("handler_end") + + updates: list[AgentRunResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): + updates.append(update) + + assert len(updates) == 2 + assert updates[0].text == "chunk1" + assert updates[1].text == "chunk2" + assert context.terminate + assert execution_order == ["handler_start", "handler_end"] + class TestFunctionMiddlewarePipeline: """Test cases for FunctionMiddlewarePipeline.""" + class PreNextTerminateFunctionMiddleware(FunctionMiddleware): + async def process(self, context: FunctionInvocationContext, next: Any) -> None: + context.terminate = True + await next(context) + + class PostNextTerminateFunctionMiddleware(FunctionMiddleware): + async def process(self, context: FunctionInvocationContext, next: Any) -> None: + await next(context) + context.terminate = True + + async def test_execute_with_pre_next_termination(self, mock_function: AIFunction[Any, Any]) -> None: + """Test pipeline execution with termination before next().""" + middleware = self.PreNextTerminateFunctionMiddleware() + pipeline = FunctionMiddlewarePipeline([middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + execution_order: list[str] = [] + + async def final_handler(ctx: FunctionInvocationContext) -> str: + # Handler should not be executed when terminated 
before next() + execution_order.append("handler") + return "test result" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + assert result is None + assert context.terminate + # Handler should not be called when terminated before next() + assert execution_order == [] + + async def test_execute_with_post_next_termination(self, mock_function: AIFunction[Any, Any]) -> None: + """Test pipeline execution with termination after next().""" + middleware = self.PostNextTerminateFunctionMiddleware() + pipeline = FunctionMiddlewarePipeline([middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + execution_order: list[str] = [] + + async def final_handler(ctx: FunctionInvocationContext) -> str: + execution_order.append("handler") + return "test result" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + assert result == "test result" + assert context.terminate + assert execution_order == ["handler"] + def test_init_empty(self) -> None: """Test FunctionMiddlewarePipeline initialization with no middlewares.""" pipeline = FunctionMiddlewarePipeline() @@ -884,44 +1027,6 @@ class TestMiddlewareExecutionControl: assert result.messages == [] # Empty response assert not handler_called - async def test_function_middleware_pre_execution_override_with_next( - self, mock_function: AIFunction[Any, Any] - ) -> None: - """Test that function middleware can override result before calling next() - this skips handler execution.""" - - class FunctionTestArgs(BaseModel): - name: str = Field(description="Test name parameter") - - class PreOverrideFunctionMiddleware(FunctionMiddleware): - async def process( - self, - context: FunctionInvocationContext, - next: Callable[[FunctionInvocationContext], Awaitable[None]], - ) -> None: - # Set override first - context.result = "pre-override result" - # Then call next() to continue middleware pipeline - 
await next(context) - - middleware = PreOverrideFunctionMiddleware() - pipeline = FunctionMiddlewarePipeline([middleware]) - arguments = FunctionTestArgs(name="test") - context = FunctionInvocationContext(function=mock_function, arguments=arguments) - - handler_called = False - - async def final_handler(ctx: FunctionInvocationContext) -> str: - nonlocal handler_called - handler_called = True - # This should not be called when result is pre-set - return "original result" - - result = await pipeline.execute(mock_function, arguments, context, final_handler) - - # Verify pre-override worked and handler was NOT called (because result was already set) - assert result == "pre-override result" - assert not handler_called - @pytest.fixture def mock_agent() -> AgentProtocol: diff --git a/python/packages/main/tests/main/test_middleware_with_agent.py b/python/packages/main/tests/main/test_middleware_with_agent.py index 7a2c030285..6ef89e5c97 100644 --- a/python/packages/main/tests/main/test_middleware_with_agent.py +++ b/python/packages/main/tests/main/test_middleware_with_agent.py @@ -1,6 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. 
from collections.abc import Awaitable, Callable +from typing import Any + +import pytest from agent_framework import ( AgentRunResponseUpdate, @@ -12,12 +15,15 @@ from agent_framework import ( FunctionResultContent, Role, TextContent, + agent_middleware, + function_middleware, ) from agent_framework._middleware import ( AgentMiddleware, AgentRunContext, FunctionInvocationContext, FunctionMiddleware, + MiddlewareType, ) from .conftest import MockChatClient @@ -98,6 +104,170 @@ class TestChatAgentClassBasedMiddleware: class TestChatAgentFunctionBasedMiddleware: """Test cases for function-based middleware integration with ChatAgent.""" + async def test_agent_middleware_with_pre_termination(self, chat_client: "MockChatClient") -> None: + """Test that agent middleware can terminate execution before calling next().""" + execution_order: list[str] = [] + + class PreTerminationMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("middleware_before") + context.terminate = True + # We call next() but since terminate=True, subsequent middleware and handler should not execute + await next(context) + execution_order.append("middleware_after") + + # Create ChatAgent with terminating middleware + middleware = PreTerminationMiddleware() + agent = ChatAgent(chat_client=chat_client, middleware=[middleware]) + + # Execute the agent with multiple messages + messages = [ + ChatMessage(role=Role.USER, text="message1"), + ChatMessage(role=Role.USER, text="message2"), # This should not be processed due to termination + ] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert not response.messages # No messages should be in response due to pre-termination + assert execution_order == ["middleware_before", "middleware_after"] # Middleware still completes + assert chat_client.call_count == 0 # No calls should be made due to 
termination + + async def test_agent_middleware_with_post_termination(self, chat_client: "MockChatClient") -> None: + """Test that agent middleware can terminate execution after calling next().""" + execution_order: list[str] = [] + + class PostTerminationMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("middleware_before") + await next(context) + execution_order.append("middleware_after") + context.terminate = True + + # Create ChatAgent with terminating middleware + middleware = PostTerminationMiddleware() + agent = ChatAgent(chat_client=chat_client, middleware=[middleware]) + + # Execute the agent with multiple messages + messages = [ + ChatMessage(role=Role.USER, text="message1"), + ChatMessage(role=Role.USER, text="message2"), + ] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) == 1 + assert response.messages[0].role == Role.ASSISTANT + assert "test response" in response.messages[0].text + + # Verify middleware execution order + assert execution_order == ["middleware_before", "middleware_after"] + assert chat_client.call_count == 1 + + async def test_function_middleware_with_pre_termination(self, chat_client: "MockChatClient") -> None: + """Test that function middleware can terminate execution before calling next().""" + execution_order: list[str] = [] + + class PreTerminationFunctionMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_order.append("middleware_before") + context.terminate = True + # We call next() but since terminate=True, subsequent middleware and handler should not execute + await next(context) + execution_order.append("middleware_after") + + # Create a message to start the conversation + messages = 
[ChatMessage(role=Role.USER, text="test message")] + + # Set up chat client to return a function call + chat_client.responses = [ + ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent(call_id="test_call", name="test_function", arguments={"text": "test"}) + ], + ) + ] + ) + ] + + # Create the test function with the expected signature + def test_function(text: str) -> str: + execution_order.append("function_called") + return "test_result" + + # Create ChatAgent with function middleware and test function + middleware = PreTerminationFunctionMiddleware() + agent = ChatAgent(chat_client=chat_client, middleware=[middleware], tools=[test_function]) + + # Execute the agent + await agent.run(messages) + + # Verify that function was not called and only middleware executed + assert execution_order == ["middleware_before", "middleware_after"] + assert "function_called" not in execution_order + assert execution_order == ["middleware_before", "middleware_after"] + + async def test_function_middleware_with_post_termination(self, chat_client: "MockChatClient") -> None: + """Test that function middleware can terminate execution after calling next().""" + execution_order: list[str] = [] + + class PostTerminationFunctionMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_order.append("middleware_before") + await next(context) + execution_order.append("middleware_after") + context.terminate = True + + # Create a message to start the conversation + messages = [ChatMessage(role=Role.USER, text="test message")] + + # Set up chat client to return a function call + chat_client.responses = [ + ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent(call_id="test_call", name="test_function", arguments={"text": "test"}) + ], + ) + ] + ) + ] + + # Create the test 
function with the expected signature + def test_function(text: str) -> str: + execution_order.append("function_called") + return "test_result" + + # Create ChatAgent with function middleware and test function + middleware = PostTerminationFunctionMiddleware() + agent = ChatAgent(chat_client=chat_client, middleware=[middleware], tools=[test_function]) + + # Execute the agent + response = await agent.run(messages) + + # Verify that function was called and middleware executed + assert response is not None + assert "function_called" in execution_order + assert execution_order == ["middleware_before", "function_called", "middleware_after"] + async def test_function_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None: """Test function-based agent middleware with ChatAgent.""" execution_order: list[str] = [] @@ -542,3 +712,631 @@ class TestChatAgentFunctionMiddlewareWithTools: assert len(function_results) == 1 assert function_calls[0].name == "sample_tool_function" assert function_results[0].call_id == function_calls[0].call_id + + +class TestMiddlewareDynamicRebuild: + """Test cases for dynamic middleware pipeline rebuilding with ChatAgent.""" + + class TrackingAgentMiddleware(AgentMiddleware): + """Test middleware that tracks execution.""" + + def __init__(self, name: str, execution_log: list[str]): + self.name = name + self.execution_log = execution_log + + async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + self.execution_log.append(f"{self.name}_start") + await next(context) + self.execution_log.append(f"{self.name}_end") + + async def test_middleware_dynamic_rebuild_non_streaming(self, chat_client: "MockChatClient") -> None: + """Test that middleware pipeline is rebuilt when agent.middleware collection is modified for non-streaming.""" + execution_log: list[str] = [] + + # Create agent with initial middleware + middleware1 = self.TrackingAgentMiddleware("middleware1", 
execution_log) + agent = ChatAgent(chat_client=chat_client, middleware=[middleware1]) + + # First execution - should use middleware1 + await agent.run("Test message 1") + assert "middleware1_start" in execution_log + assert "middleware1_end" in execution_log + + # Clear execution log + execution_log.clear() + + # Modify the middleware collection by adding another middleware + middleware2 = self.TrackingAgentMiddleware("middleware2", execution_log) + agent.middleware = [middleware1, middleware2] + + # Second execution - should use both middleware1 and middleware2 + await agent.run("Test message 2") + assert "middleware1_start" in execution_log + assert "middleware1_end" in execution_log + assert "middleware2_start" in execution_log + assert "middleware2_end" in execution_log + + # Clear execution log + execution_log.clear() + + # Modify the middleware collection by replacing with just middleware2 + agent.middleware = [middleware2] + + # Third execution - should use only middleware2 + await agent.run("Test message 3") + assert "middleware1_start" not in execution_log + assert "middleware1_end" not in execution_log + assert "middleware2_start" in execution_log + assert "middleware2_end" in execution_log + + # Clear execution log + execution_log.clear() + + # Remove all middleware + agent.middleware = [] + + # Fourth execution - should use no middleware + await agent.run("Test message 4") + assert len(execution_log) == 0 + + async def test_middleware_dynamic_rebuild_streaming(self, chat_client: "MockChatClient") -> None: + """Test that middleware pipeline is rebuilt for streaming when agent.middleware collection is modified.""" + execution_log: list[str] = [] + + # Create agent with initial middleware + middleware1 = self.TrackingAgentMiddleware("stream_middleware1", execution_log) + agent = ChatAgent(chat_client=chat_client, middleware=[middleware1]) + + # First streaming execution + updates: list[AgentRunResponseUpdate] = [] + async for update in 
agent.run_stream("Test stream message 1"): + updates.append(update) + + assert "stream_middleware1_start" in execution_log + assert "stream_middleware1_end" in execution_log + + # Clear execution log + execution_log.clear() + + # Modify the middleware collection + middleware2 = self.TrackingAgentMiddleware("stream_middleware2", execution_log) + agent.middleware = [middleware2] + + # Second streaming execution - should use only middleware2 + updates = [] + async for update in agent.run_stream("Test stream message 2"): + updates.append(update) + + assert "stream_middleware1_start" not in execution_log + assert "stream_middleware1_end" not in execution_log + assert "stream_middleware2_start" in execution_log + assert "stream_middleware2_end" in execution_log + + async def test_middleware_order_change_detection(self, chat_client: "MockChatClient") -> None: + """Test that changing the order of middleware is detected and applied.""" + execution_log: list[str] = [] + + middleware1 = self.TrackingAgentMiddleware("first", execution_log) + middleware2 = self.TrackingAgentMiddleware("second", execution_log) + + # Create agent with middleware in order [first, second] + agent = ChatAgent(chat_client=chat_client, middleware=[middleware1, middleware2]) + + # First execution + await agent.run("Test message 1") + assert execution_log == ["first_start", "second_start", "second_end", "first_end"] + + # Clear execution log + execution_log.clear() + + # Change order to [second, first] + agent.middleware = [middleware2, middleware1] + + # Second execution - should reflect new order + await agent.run("Test message 2") + assert execution_log == ["second_start", "first_start", "first_end", "second_end"] + + +class TestRunLevelMiddleware: + """Test cases for run-level middleware functionality.""" + + class TrackingAgentMiddleware(AgentMiddleware): + """Test middleware that tracks execution.""" + + def __init__(self, name: str, execution_log: list[str]): + self.name = name + 
self.execution_log = execution_log + + async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + self.execution_log.append(f"{self.name}_start") + await next(context) + self.execution_log.append(f"{self.name}_end") + + async def test_run_level_middleware_isolation(self, chat_client: "MockChatClient") -> None: + """Test that run-level middleware is isolated between multiple runs.""" + execution_log: list[str] = [] + + # Create agent without any agent-level middleware + agent = ChatAgent(chat_client=chat_client) + + # Create run-level middleware + run_middleware1 = self.TrackingAgentMiddleware("run1", execution_log) + run_middleware2 = self.TrackingAgentMiddleware("run2", execution_log) + + # First run with run_middleware1 + await agent.run("Test message 1", middleware=[run_middleware1]) + assert execution_log == ["run1_start", "run1_end"] + + # Clear execution log + execution_log.clear() + + # Second run with run_middleware2 - should not see run_middleware1 + await agent.run("Test message 2", middleware=[run_middleware2]) + assert execution_log == ["run2_start", "run2_end"] + assert "run1_start" not in execution_log + assert "run1_end" not in execution_log + + # Clear execution log + execution_log.clear() + + # Third run with no middleware - should not see any middleware execution + await agent.run("Test message 3") + assert execution_log == [] + + # Clear execution log + execution_log.clear() + + # Fourth run with both run middlewares - should see both + await agent.run("Test message 4", middleware=[run_middleware1, run_middleware2]) + assert execution_log == ["run1_start", "run2_start", "run2_end", "run1_end"] + + async def test_agent_plus_run_middleware_execution_order(self, chat_client: "MockChatClient") -> None: + """Test that agent middleware executes first, followed by run middleware.""" + execution_log: list[str] = [] + metadata_log: list[str] = [] + + class MetadataAgentMiddleware(AgentMiddleware): + 
def __init__(self, name: str): + self.name = name + + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_log.append(f"{self.name}_start") + # Set metadata to pass information to run middleware + context.metadata[f"{self.name}_key"] = f"{self.name}_value" + await next(context) + execution_log.append(f"{self.name}_end") + + class MetadataRunMiddleware(AgentMiddleware): + def __init__(self, name: str): + self.name = name + + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_log.append(f"{self.name}_start") + # Read metadata set by agent middleware + for key, value in context.metadata.items(): + metadata_log.append(f"{self.name}_reads_{key}:{value}") + # Set run-level metadata + context.metadata[f"{self.name}_key"] = f"{self.name}_value" + await next(context) + execution_log.append(f"{self.name}_end") + + # Create agent with agent-level middleware + agent_middleware = MetadataAgentMiddleware("agent") + agent = ChatAgent(chat_client=chat_client, middleware=[agent_middleware]) + + # Create run-level middleware + run_middleware = MetadataRunMiddleware("run") + + # Execute with both agent and run middleware + await agent.run("Test message", middleware=[run_middleware]) + + # Verify execution order: agent middleware wraps run middleware + expected_order = ["agent_start", "run_start", "run_end", "agent_end"] + assert execution_log == expected_order + + # Verify that run middleware can read agent middleware metadata + assert "run_reads_agent_key:agent_value" in metadata_log + + async def test_run_level_middleware_non_streaming(self, chat_client: "MockChatClient") -> None: + """Test run-level middleware with non-streaming execution.""" + execution_log: list[str] = [] + + # Create agent without agent-level middleware + agent = ChatAgent(chat_client=chat_client) + + # Create run-level middleware + run_middleware = 
self.TrackingAgentMiddleware("run_nonstream", execution_log) + + # Execute non-streaming with run middleware + response = await agent.run("Test non-streaming", middleware=[run_middleware]) + + # Verify response is correct + assert response is not None + assert len(response.messages) > 0 + assert response.messages[0].role == Role.ASSISTANT + assert "test response" in response.messages[0].text + + # Verify middleware was executed + assert execution_log == ["run_nonstream_start", "run_nonstream_end"] + + async def test_run_level_middleware_streaming(self, chat_client: "MockChatClient") -> None: + """Test run-level middleware with streaming execution.""" + execution_log: list[str] = [] + streaming_flags: list[bool] = [] + + class StreamingTrackingMiddleware(AgentMiddleware): + def __init__(self, name: str): + self.name = name + + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_log.append(f"{self.name}_start") + streaming_flags.append(context.is_streaming) + await next(context) + execution_log.append(f"{self.name}_end") + + # Create agent without agent-level middleware + agent = ChatAgent(chat_client=chat_client) + + # Set up mock streaming responses + chat_client.streaming_responses = [ + [ + ChatResponseUpdate(contents=[TextContent(text="Stream")], role=Role.ASSISTANT), + ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT), + ] + ] + + # Create run-level middleware + run_middleware = StreamingTrackingMiddleware("run_stream") + + # Execute streaming with run middleware + updates: list[AgentRunResponseUpdate] = [] + async for update in agent.run_stream("Test streaming", middleware=[run_middleware]): + updates.append(update) + + # Verify streaming response + assert len(updates) == 2 + assert updates[0].text == "Stream" + assert updates[1].text == " response" + + # Verify middleware was executed with correct streaming flag + assert execution_log == 
["run_stream_start", "run_stream_end"] + assert streaming_flags == [True] # Context should indicate streaming + + async def test_agent_and_run_level_both_agent_and_function_middleware(self, chat_client: "MockChatClient") -> None: + """Test complete scenario with agent and function middleware at both agent-level and run-level.""" + execution_log: list[str] = [] + + # Agent-level middleware + class AgentLevelAgentMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_log.append("agent_level_agent_start") + context.metadata["agent_level_agent"] = "processed" + await next(context) + execution_log.append("agent_level_agent_end") + + class AgentLevelFunctionMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_log.append("agent_level_function_start") + context.metadata["agent_level_function"] = "processed" + await next(context) + execution_log.append("agent_level_function_end") + + # Run-level middleware + class RunLevelAgentMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_log.append("run_level_agent_start") + # Verify agent-level middleware metadata is available + assert "agent_level_agent" in context.metadata + context.metadata["run_level_agent"] = "processed" + await next(context) + execution_log.append("run_level_agent_end") + + class RunLevelFunctionMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_log.append("run_level_function_start") + # Verify agent-level function middleware metadata is available + assert "agent_level_function" in context.metadata + 
context.metadata["run_level_function"] = "processed" + await next(context) + execution_log.append("run_level_function_end") + + # Create tool function for testing function middleware + def custom_tool(message: str) -> str: + execution_log.append("tool_executed") + return f"Tool response: {message}" + + # Set up mock to return a function call first, then a regular response + function_call_response = ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent( + call_id="test_call", + name="custom_tool", + arguments='{"message": "test"}', + ) + ], + ) + ] + ) + final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")]) + chat_client.responses = [function_call_response, final_response] + + # Create agent with agent-level middleware + agent = ChatAgent( + chat_client=chat_client, + middleware=[AgentLevelAgentMiddleware(), AgentLevelFunctionMiddleware()], + tools=[custom_tool], + ) + + # Execute with run-level middleware + response = await agent.run( + "Test message", + middleware=[RunLevelAgentMiddleware(), RunLevelFunctionMiddleware()], + ) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert chat_client.call_count == 2 # Function call + final response + + expected_order = [ + "agent_level_agent_start", + "run_level_agent_start", + "agent_level_function_start", + "run_level_function_start", + "tool_executed", + "run_level_function_end", + "agent_level_function_end", + "run_level_agent_end", + "agent_level_agent_end", + ] + assert execution_log == expected_order + + # Verify function call and result are in the response + all_contents = [content for message in response.messages for content in message.contents] + function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)] + function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)] + + assert len(function_calls) == 1 + assert len(function_results) == 1 + 
assert function_calls[0].name == "custom_tool" + assert function_results[0].call_id == function_calls[0].call_id + assert function_results[0].result is not None + assert "Tool response: test" in str(function_results[0].result) + + +class TestMiddlewareDecoratorLogic: + """Test the middleware decorator and type annotation logic.""" + + async def test_decorator_and_type_match(self, chat_client: MockChatClient) -> None: + """Both decorator and parameter type specified and match.""" + + execution_order: list[str] = [] + + @agent_middleware + async def matching_agent_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("decorator_type_match_agent") + await next(context) + + @function_middleware + async def matching_function_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + execution_order.append("decorator_type_match_function") + await next(context) + + # Create tool function for testing function middleware + def custom_tool(message: str) -> str: + execution_order.append("tool_executed") + return f"Tool response: {message}" + + # Set up mock to return a function call first, then a regular response + function_call_response = ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent( + call_id="test_call", + name="custom_tool", + arguments='{"message": "test"}', + ) + ], + ) + ] + ) + final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")]) + chat_client.responses = [function_call_response, final_response] + + # Should work without errors + agent = ChatAgent( + chat_client=chat_client, + middleware=[matching_agent_middleware, matching_function_middleware], + tools=[custom_tool], + ) + + response = await agent.run([ChatMessage(role=Role.USER, text="test")]) + + assert response is not None + assert "decorator_type_match_agent" in 
execution_order + assert "decorator_type_match_function" in execution_order + + async def test_decorator_and_type_mismatch(self, chat_client: MockChatClient) -> None: + """Both decorator and parameter type specified but don't match.""" + + # This will cause a type error at decoration time, so we need to test differently + # Should raise ValueError due to mismatch during agent creation + with pytest.raises(ValueError, match="Middleware type mismatch"): + + @agent_middleware # type: ignore[arg-type] + async def mismatched_middleware( + context: FunctionInvocationContext, # Wrong type for @agent_middleware + next: Any, + ) -> None: + await next(context) + + agent = ChatAgent(chat_client=chat_client, middleware=[mismatched_middleware]) + await agent.run([ChatMessage(role=Role.USER, text="test")]) + + async def test_only_decorator_specified(self, chat_client: Any) -> None: + """Only decorator specified - rely on decorator.""" + execution_order: list[str] = [] + + @agent_middleware + async def decorator_only_agent(context: Any, next: Any) -> None: # No type annotation + execution_order.append("decorator_only_agent") + await next(context) + + @function_middleware + async def decorator_only_function(context: Any, next: Any) -> None: # No type annotation + execution_order.append("decorator_only_function") + await next(context) + + # Create tool function for testing function middleware + def custom_tool(message: str) -> str: + execution_order.append("tool_executed") + return f"Tool response: {message}" + + # Set up mock to return a function call first, then a regular response + function_call_response = ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent( + call_id="test_call", + name="custom_tool", + arguments='{"message": "test"}', + ) + ], + ) + ] + ) + final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")]) + chat_client.responses = [function_call_response, final_response] + + # 
Should work - relies on decorator + agent = ChatAgent( + chat_client=chat_client, middleware=[decorator_only_agent, decorator_only_function], tools=[custom_tool] + ) + + response = await agent.run([ChatMessage(role=Role.USER, text="test")]) + + assert response is not None + assert "decorator_only_agent" in execution_order + assert "decorator_only_function" in execution_order + + async def test_only_type_specified(self, chat_client: Any) -> None: + """Only parameter type specified - rely on types.""" + execution_order: list[str] = [] + + # No decorator + async def type_only_agent(context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + execution_order.append("type_only_agent") + await next(context) + + # No decorator + async def type_only_function( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + execution_order.append("type_only_function") + await next(context) + + # Create tool function for testing function middleware + def custom_tool(message: str) -> str: + execution_order.append("tool_executed") + return f"Tool response: {message}" + + # Set up mock to return a function call first, then a regular response + function_call_response = ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent( + call_id="test_call", + name="custom_tool", + arguments='{"message": "test"}', + ) + ], + ) + ] + ) + final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")]) + chat_client.responses = [function_call_response, final_response] + + # Should work - relies on type annotations + agent = ChatAgent( + chat_client=chat_client, middleware=[type_only_agent, type_only_function], tools=[custom_tool] + ) + + response = await agent.run([ChatMessage(role=Role.USER, text="test")]) + + assert response is not None + assert "type_only_agent" in execution_order + assert "type_only_function" in execution_order + + 
async def test_neither_decorator_nor_type(self, chat_client: Any) -> None: + """Neither decorator nor parameter type specified - should throw exception.""" + + async def no_info_middleware(context: Any, next: Any) -> None: # No decorator, no type + await next(context) + + # Should raise ValueError + with pytest.raises(ValueError, match="Cannot determine middleware type"): + agent = ChatAgent(chat_client=chat_client, middleware=[no_info_middleware]) + await agent.run([ChatMessage(role=Role.USER, text="test")]) + + async def test_insufficient_parameters_error(self, chat_client: Any) -> None: + """Test that middleware with insufficient parameters raises an error.""" + from agent_framework import ChatAgent, agent_middleware + + # Should raise ValueError about insufficient parameters + with pytest.raises(ValueError, match="must have at least 2 parameters"): + + @agent_middleware # type: ignore[arg-type] + async def insufficient_params_middleware(context: Any) -> None: # Missing 'next' parameter + pass + + agent = ChatAgent(chat_client=chat_client, middleware=[insufficient_params_middleware]) + await agent.run([ChatMessage(role=Role.USER, text="test")]) + + async def test_decorator_markers_preserved(self) -> None: + """Test that decorator markers are properly set on functions.""" + + @agent_middleware + async def test_agent_middleware(context: Any, next: Any) -> None: + pass + + @function_middleware + async def test_function_middleware(context: Any, next: Any) -> None: + pass + + # Check that decorator markers were set + assert hasattr(test_agent_middleware, "_middleware_type") + assert test_agent_middleware._middleware_type == MiddlewareType.AGENT # type: ignore[attr-defined] + + assert hasattr(test_function_middleware, "_middleware_type") + assert test_function_middleware._middleware_type == MiddlewareType.FUNCTION # type: ignore[attr-defined] diff --git a/python/samples/getting_started/middleware/agent_and_run_level_middleware.py 
b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py new file mode 100644 index 0000000000..0b962d8af8 --- /dev/null +++ b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py @@ -0,0 +1,271 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import time +from collections.abc import Awaitable, Callable +from random import randint +from typing import Annotated + +from agent_framework import ( + AgentMiddleware, + AgentRunContext, + AgentRunResponse, + FunctionInvocationContext, +) +from agent_framework.foundry import FoundryChatClient +from azure.identity.aio import AzureCliCredential +from pydantic import Field + +""" +Agent-Level and Run-Level Middleware Example + +This sample demonstrates the difference between agent-level and run-level middleware: + +- Agent-level middleware: Applied to ALL runs of the agent (persistent across runs) +- Run-level middleware: Applied to specific runs only (isolated per run) + +The example shows: +1. Agent-level security middleware that validates all requests +2. Agent-level performance monitoring across all runs +3. Run-level context middleware for specific use cases (high priority, debugging) +4. Run-level caching middleware for expensive operations + +Execution order: Agent middleware (outermost) -> Run middleware (innermost) -> Agent execution +""" + + +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." 
+ + +# Agent-level middleware (applied to ALL runs) +class SecurityAgentMiddleware(AgentMiddleware): + """Agent-level security middleware that validates all requests.""" + + async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + print("[SecurityMiddleware] Checking security for all requests...") + + # Check for security violations in the last user message + last_message = context.messages[-1] if context.messages else None + if last_message and last_message.text: + query = last_message.text.lower() + if any(word in query for word in ["password", "secret", "credentials"]): + print("[SecurityMiddleware] Security violation detected! Blocking request.") + return # Don't call next() to prevent execution + + print("[SecurityMiddleware] Security check passed.") + context.metadata["security_validated"] = True + await next(context) + + +async def performance_monitor_middleware( + context: AgentRunContext, + next: Callable[[AgentRunContext], Awaitable[None]], +) -> None: + """Agent-level performance monitoring for all runs.""" + print("[PerformanceMonitor] Starting performance monitoring...") + start_time = time.time() + + await next(context) + + end_time = time.time() + duration = end_time - start_time + print(f"[PerformanceMonitor] Total execution time: {duration:.3f}s") + context.metadata["execution_time"] = duration + + +# Run-level middleware (applied to specific runs only) +class HighPriorityMiddleware(AgentMiddleware): + """Run-level middleware for high priority requests.""" + + async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + print("[HighPriority] Processing high priority request with expedited handling...") + + # Read metadata set by agent-level middleware + if context.metadata.get("security_validated"): + print("[HighPriority] Security validation confirmed from agent middleware") + + # Set high priority flag + context.metadata["priority"] = 
"high" + context.metadata["expedited"] = True + + await next(context) + print("[HighPriority] High priority processing completed") + + +async def debugging_middleware( + context: AgentRunContext, + next: Callable[[AgentRunContext], Awaitable[None]], +) -> None: + """Run-level debugging middleware for troubleshooting specific runs.""" + print("[Debug] Debug mode enabled for this run") + print(f"[Debug] Messages count: {len(context.messages)}") + print(f"[Debug] Is streaming: {context.is_streaming}") + + # Log existing metadata from agent middleware + if context.metadata: + print(f"[Debug] Existing metadata: {context.metadata}") + + context.metadata["debug_enabled"] = True + + await next(context) + + print("[Debug] Debug information collected") + + +class CachingMiddleware(AgentMiddleware): + """Run-level caching middleware for expensive operations.""" + + def __init__(self) -> None: + self.cache: dict[str, AgentRunResponse] = {} + + async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + # Create a simple cache key from the last message + last_message = context.messages[-1] if context.messages else None + cache_key: str = last_message.text if last_message and last_message.text else "no_message" + + if cache_key in self.cache: + print(f"[Cache] Cache HIT for: '{cache_key[:30]}...'") + context.result = self.cache[cache_key] # type: ignore + return # Don't call next(), return cached result + + print(f"[Cache] Cache MISS for: '{cache_key[:30]}...'") + context.metadata["cache_key"] = cache_key + + await next(context) + + # Cache the result if we have one + if context.result: + self.cache[cache_key] = context.result # type: ignore + print("[Cache] Result cached for future use") + + +async def function_logging_middleware( + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], +) -> None: + """Function middleware that logs all function calls.""" + function_name = 
async def main() -> None:
    """Run one weather agent seven times, mixing agent-level and run-level middleware.

    The agent is created with three agent-level middleware entries that apply
    to every run. Individual runs then pass extra middleware through the
    ``middleware`` argument of ``agent.run``; runs without that argument use
    only the agent-level entries.
    """

    def banner(title: str) -> None:
        # Frame each run heading between two 60-character rules.
        rule = "=" * 60
        print(rule)
        print(title)
        print(rule)

    def report(result, fallback: str = "No response") -> None:
        # Print the response text, or the fallback when the text is empty,
        # followed by a blank separator line.
        print(f"Agent: {result.text or fallback}")
        print()

    print("=== Agent-Level and Run-Level Middleware Example ===\n")

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with AzureCliCredential() as credential:
        async with FoundryChatClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            # Agent-level middleware: applied to ALL runs
            middleware=[
                SecurityAgentMiddleware(),
                performance_monitor_middleware,
                function_logging_middleware,
            ],
        ) as agent:
            print("Agent created with agent-level middleware:")
            print(" - SecurityMiddleware (blocks sensitive requests)")
            print(" - PerformanceMonitor (tracks execution time)")
            print(" - FunctionLogging (logs all function calls)")
            print()

            # Run 1: no run-level middleware is supplied.
            banner("RUN 1: Normal query (agent-level middleware only)")
            query = "What's the weather like in Paris?"
            print(f"User: {query}")
            report(await agent.run(query))

            # Run 2: a single run-level middleware instance is passed directly.
            banner("RUN 2: High priority request (agent + run-level middleware)")
            query = "What's the weather in Tokyo? This is urgent!"
            print(f"User: {query}")
            report(await agent.run(query, middleware=HighPriorityMiddleware()))

            # Run 3: a run-level middleware function is passed inside a list.
            banner("RUN 3: Debug mode (agent + run-level debugging)")
            query = "What's the weather in London?"
            print(f"User: {query}")
            report(await agent.run(query, middleware=[debugging_middleware]))

            # Run 4: several run-level middleware entries at once.
            banner("RUN 4: Multiple run-level middleware (caching + debug)")
            caching = CachingMiddleware()
            query = "What's the weather in New York?"
            print(f"User: {query}")
            report(await agent.run(query, middleware=[caching, debugging_middleware]))

            # Run 5: the query and the caching instance from Run 4 are reused on purpose.
            banner("RUN 5: Test cache hit (same query as Run 4)")
            print(f"User: {query}")
            report(await agent.run(query, middleware=[caching]))

            # Run 6: the query is meant to be rejected by the agent-level security middleware.
            banner("RUN 6: Security test (should be blocked by agent middleware)")
            query = "What's the secret weather password for Berlin?"
            print(f"User: {query}")
            report(await agent.run(query), fallback="Request was blocked by security middleware")

            # Run 7: back to agent-level middleware only.
            banner("RUN 7: Normal query again (agent-level middleware only)")
            query = "What's the weather in Sydney?"
            print(f"User: {query}")
            report(await agent.run(query))
def get_current_time() -> str:
    """Get the current time."""
    # datetime.now() without a tz argument yields the local wall-clock time;
    # only the HH:MM:SS portion is reported.
    now = datetime.datetime.now()
    return "Current time is " + now.strftime("%H:%M:%S")
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    # Fabricated data for the sample: the condition is drawn first and the
    # high temperature second, so the random draws happen in a fixed order.
    conditions = ("sunny", "cloudy", "rainy", "stormy")
    sky = conditions[randint(0, 3)]
    high = randint(10, 30)
    return f"The weather in {location} is {sky} with a high of {high}°C."
class PreTerminationMiddleware(AgentMiddleware):
    """Agent middleware that flags a run for termination when a blocked word is found.

    The most recent message's text is lower-cased and checked for each blocked
    word as a substring. On the first match a canned assistant reply is stored
    in ``context.result`` and ``context.terminate`` is set to ``True``.
    """

    def __init__(self, blocked_words: list[str]):
        # Lower-case once here so the substring check in process() is case-insensitive.
        self.blocked_words = [entry.lower() for entry in blocked_words]

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        """Inspect the latest message, flag termination on a match, then hand off to ``next``.

        Args:
            context: The run context whose ``messages``, ``result`` and
                ``terminate`` fields are read or written here.
            next: Continuation that passes control further down the pipeline.
        """
        latest = context.messages[-1] if context.messages else None
        if latest and latest.text:
            lowered = latest.text.lower()
            # Only the first matching word (in configured order) is reported.
            offending = [word for word in self.blocked_words if word in lowered]
            if offending:
                hit = offending[0]
                print(f"[PreTerminationMiddleware] Blocked word '{hit}' detected. Terminating request.")

                # Provide the reply the caller receives in place of a model response.
                refusal = ChatMessage(
                    role=Role.ASSISTANT,
                    text=(
                        f"Sorry, I cannot process requests containing '{hit}'. "
                        "Please rephrase your question."
                    ),
                )
                context.result = AgentRunResponse(messages=[refusal])

                # Set terminate flag to prevent further processing
                context.terminate = True

        # NOTE(review): next() is awaited even after terminate is set; presumably the
        # framework checks the flag once control returns to it - confirm.
        await next(context)
async def pre_termination_middleware() -> None:
    """Demonstrate pre-termination middleware that blocks requests with certain words."""
    print("\n--- Example 1: Pre-termination Middleware ---")

    # Each scenario is (heading, user query); the second query contains the blocked word "bad".
    scenarios = (
        ("\n1. Normal query:", "What's the weather like in Seattle?"),
        ("\n2. Query with blocked word:", "What's the bad weather in New York?"),
    )

    async with AzureCliCredential() as credential:
        async with FoundryChatClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=PreTerminationMiddleware(blocked_words=["bad", "inappropriate"]),
        ) as agent:
            for heading, query in scenarios:
                print(heading)
                print(f"User: {query}")
                result = await agent.run(query)
                print(f"Agent: {result.text}")
async def main() -> None:
    """Example demonstrating middleware termination functionality."""
    print("=== Middleware Termination Example ===")
    # The demonstrations run strictly one after another: pre-termination first, then post-termination.
    for demo in (pre_termination_middleware, post_termination_middleware):
        await demo()