From 1fdc4be88d3e9168afc875c23c47df9fd6ae1bb3 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Wed, 11 Feb 2026 02:47:41 -0800 Subject: [PATCH 1/3] Removed context parameter from call_next (#3829) --- python/packages/core/AGENTS.md | 2 +- .../core/agent_framework/_middleware.py | 84 +++--- .../core/test_as_tool_kwargs_propagation.py | 42 +-- .../core/test_function_invocation_logic.py | 10 +- .../core/tests/core/test_middleware.py | 228 +++++++-------- .../core/test_middleware_context_result.py | 58 ++-- .../tests/core/test_middleware_with_agent.py | 262 ++++++++---------- .../tests/core/test_middleware_with_chat.py | 66 ++--- .../ollama/tests/test_ollama_chat_client.py | 2 +- .../_handoff.py | 4 +- .../agent_framework_purview/_middleware.py | 8 +- .../tests/purview/test_chat_middleware.py | 54 ++-- .../purview/tests/purview/test_middleware.py | 66 ++--- python/samples/concepts/tools/README.md | 16 +- .../azure_ai/azure_ai_with_agent_as_tool.py | 4 +- .../openai/openai_responses_client_basic.py | 4 +- ...nai_responses_client_with_agent_as_tool.py | 4 +- .../devui/weather_agent_azure/agent.py | 8 +- .../agent_and_run_level_middleware.py | 26 +- .../middleware/chat_middleware.py | 8 +- .../middleware/class_based_middleware.py | 8 +- .../middleware/decorator_middleware.py | 4 +- .../exception_handling_with_middleware.py | 4 +- .../middleware/function_based_middleware.py | 10 +- .../middleware/middleware_termination.py | 12 +- .../override_result_with_middleware.py | 12 +- .../middleware/runtime_context_delegation.py | 16 +- .../middleware/shared_state_middleware.py | 8 +- .../middleware/thread_behavior_middleware.py | 4 +- 29 files changed, 451 insertions(+), 583 deletions(-) diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index 823b601b76..3958957596 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -119,7 +119,7 @@ from agent_framework import Agent, AgentMiddleware, AgentContext class LoggingMiddleware(AgentMiddleware): async def process(self, context: AgentContext, call_next) -> None: print(f"Input: {context.messages}") - await call_next(context) + await call_next() print(f"Output: {context.result}") agent = Agent(..., middleware=[LoggingMiddleware()]) diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index ac6630a03f..e595be76e3 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -145,7 +145,7 @@ class AgentContext: context.metadata["start_time"] = time.time() # Continue execution - await call_next(context) + await call_next() # Access result after execution print(f"Result: {context.result}") @@ -229,7 +229,7 @@ class FunctionInvocationContext: raise MiddlewareTermination("Validation failed") # Continue execution - await call_next(context) + await call_next() """ def __init__( @@ -293,7 +293,7 @@ class ChatContext: context.metadata["input_tokens"] = self.count_tokens(context.messages) # Continue execution - await call_next(context) + await call_next() # Access result and count output tokens if context.result: @@ -365,7 +365,7 @@ class AgentMiddleware(ABC): async def process(self, context: AgentContext, call_next): for attempt in range(self.max_retries): - await call_next(context) + await call_next() if context.result and not context.result.is_error: break print(f"Retry {attempt + 1}/{self.max_retries}") @@ -379,7 +379,7 @@ class AgentMiddleware(ABC): async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Process an agent invocation. @@ -431,7 +431,7 @@ class FunctionMiddleware(ABC): raise MiddlewareTermination() # Execute function - await call_next(context) + await call_next() # Cache result if context.result: @@ -446,7 +446,7 @@ class FunctionMiddleware(ABC): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Process a function invocation. @@ -493,7 +493,7 @@ class ChatMiddleware(ABC): context.messages.insert(0, Message(role="system", text=self.system_prompt)) # Continue execution - await call_next(context) + await call_next() # Use with an agent @@ -508,7 +508,7 @@ class ChatMiddleware(ABC): async def process( self, context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Process a chat client request. @@ -531,15 +531,13 @@ class ChatMiddleware(ABC): # Pure function type definitions for convenience -AgentMiddlewareCallable = Callable[[AgentContext, Callable[[AgentContext], Awaitable[None]]], Awaitable[None]] +AgentMiddlewareCallable = Callable[[AgentContext, Callable[[], Awaitable[None]]], Awaitable[None]] AgentMiddlewareTypes: TypeAlias = AgentMiddleware | AgentMiddlewareCallable -FunctionMiddlewareCallable = Callable[ - [FunctionInvocationContext, Callable[[FunctionInvocationContext], Awaitable[None]]], Awaitable[None] -] +FunctionMiddlewareCallable = Callable[[FunctionInvocationContext, Callable[[], Awaitable[None]]], Awaitable[None]] FunctionMiddlewareTypes: TypeAlias = FunctionMiddleware | FunctionMiddlewareCallable -ChatMiddlewareCallable = Callable[[ChatContext, Callable[[ChatContext], Awaitable[None]]], Awaitable[None]] +ChatMiddlewareCallable = Callable[[ChatContext, Callable[[], Awaitable[None]]], Awaitable[None]] ChatMiddlewareTypes: TypeAlias = ChatMiddleware | ChatMiddlewareCallable ChatAndFunctionMiddlewareTypes: TypeAlias = ( @@ -578,7 +576,7 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable: @agent_middleware async def logging_middleware(context: AgentContext, call_next): print(f"Before: {context.agent.name}") - await call_next(context) + await call_next() print(f"After: {context.result}") @@ -611,7 +609,7 @@ def function_middleware(func: FunctionMiddlewareCallable) -> FunctionMiddlewareC @function_middleware async def logging_middleware(context: FunctionInvocationContext, call_next): print(f"Calling: {context.function.name}") - await call_next(context) + await call_next() print(f"Result: {context.result}") @@ -644,7 +642,7 @@ def chat_middleware(func: ChatMiddlewareCallable) -> ChatMiddlewareCallable: @chat_middleware async def logging_middleware(context: ChatContext, call_next): print(f"Messages: {len(context.messages)}") - await call_next(context) + await call_next() print(f"Response: {context.result}") @@ -666,10 +664,10 @@ class MiddlewareWrapper(Generic[ContextT]): ContextT: The type of context object this middleware operates on. """ - def __init__(self, func: Callable[[ContextT, Callable[[ContextT], Awaitable[None]]], Awaitable[None]]) -> None: + def __init__(self, func: Callable[[ContextT, Callable[[], Awaitable[None]]], Awaitable[None]]) -> None: self.func = func - async def process(self, context: ContextT, call_next: Callable[[ContextT], Awaitable[None]]) -> None: + async def process(self, context: ContextT, call_next: Callable[[], Awaitable[None]]) -> None: await self.func(context, call_next) @@ -772,25 +770,25 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): context.result = await context.result return context.result - def create_next_handler(index: int) -> Callable[[AgentContext], Awaitable[None]]: + def create_next_handler(index: int) -> Callable[[], Awaitable[None]]: if index >= len(self._middleware): - async def final_wrapper(c: AgentContext) -> None: - c.result = final_handler(c) # type: ignore[assignment] - if inspect.isawaitable(c.result): - c.result = await c.result + async def final_wrapper() -> None: + context.result = final_handler(context) # type: ignore[assignment] + if inspect.isawaitable(context.result): + context.result = await context.result return final_wrapper - async def current_handler(c: AgentContext) -> None: + async def current_handler() -> None: # MiddlewareTermination bubbles up to execute() to skip post-processing - await self._middleware[index].process(c, create_next_handler(index + 1)) + await self._middleware[index].process(context, create_next_handler(index + 1)) return current_handler first_handler = create_next_handler(0) with contextlib.suppress(MiddlewareTermination): - await first_handler(context) + await first_handler() if context.result and isinstance(context.result, ResponseStream): for hook in context.stream_transform_hooks: @@ -847,25 +845,25 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline): if not self._middleware: return await final_handler(context) - def create_next_handler(index: int) -> Callable[[FunctionInvocationContext], Awaitable[None]]: + def create_next_handler(index: int) -> Callable[[], Awaitable[None]]: if index >= len(self._middleware): - async def final_wrapper(c: FunctionInvocationContext) -> None: - c.result = final_handler(c) - if inspect.isawaitable(c.result): - c.result = await c.result + async def final_wrapper() -> None: + context.result = final_handler(context) + if inspect.isawaitable(context.result): + context.result = await context.result return final_wrapper - async def current_handler(c: FunctionInvocationContext) -> None: + async def current_handler() -> None: # MiddlewareTermination bubbles up to execute() to skip post-processing - await self._middleware[index].process(c, create_next_handler(index + 1)) + await self._middleware[index].process(context, create_next_handler(index + 1)) return current_handler first_handler = create_next_handler(0) # Don't suppress MiddlewareTermination - let it propagate to signal loop termination - await first_handler(context) + await first_handler() return context.result @@ -922,25 +920,25 @@ class ChatMiddlewarePipeline(BaseMiddlewarePipeline): raise ValueError("Streaming agent middleware requires a ResponseStream result.") return context.result - def create_next_handler(index: int) -> Callable[[ChatContext], Awaitable[None]]: + def create_next_handler(index: int) -> Callable[[], Awaitable[None]]: if index >= len(self._middleware): - async def final_wrapper(c: ChatContext) -> None: - c.result = final_handler(c) # type: ignore[assignment] - if inspect.isawaitable(c.result): - c.result = await c.result + async def final_wrapper() -> None: + context.result = final_handler(context) # type: ignore[assignment] + if inspect.isawaitable(context.result): + context.result = await context.result return final_wrapper - async def current_handler(c: ChatContext) -> None: + async def current_handler() -> None: # MiddlewareTermination bubbles up to execute() to skip post-processing - await self._middleware[index].process(c, create_next_handler(index + 1)) + await self._middleware[index].process(context, create_next_handler(index + 1)) return current_handler first_handler = create_next_handler(0) with contextlib.suppress(MiddlewareTermination): - await first_handler(context) + await first_handler() if context.result and isinstance(context.result, ResponseStream): for hook in context.stream_transform_hooks: diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index b34164b86b..da8e907c40 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -19,12 +19,10 @@ class TestAsToolKwargsPropagation: captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Capture kwargs passed to the sub-agent captured_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock response client.responses = [ @@ -62,11 +60,9 @@ class TestAsToolKwargsPropagation: captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock response client.responses = [ @@ -99,12 +95,10 @@ class TestAsToolKwargsPropagation: captured_kwargs_list: list[dict[str, Any]] = [] @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Capture kwargs at each level captured_kwargs_list.append(dict(context.kwargs)) - await call_next(context) + await call_next() # Setup mock responses to trigger nested tool invocation: B calls tool C, then completes. client.responses = [ @@ -162,11 +156,9 @@ class TestAsToolKwargsPropagation: captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock streaming responses from agent_framework import ChatResponseUpdate @@ -224,11 +216,9 @@ class TestAsToolKwargsPropagation: captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock response client.responses = [ @@ -266,16 +256,14 @@ class TestAsToolKwargsPropagation: call_count = 0 @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: nonlocal call_count call_count += 1 if call_count == 1: first_call_kwargs.update(context.kwargs) elif call_count == 2: second_call_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock responses for both calls client.responses = [ @@ -318,11 +306,9 @@ class TestAsToolKwargsPropagation: captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Setup mock response client.responses = [ diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index dcc28958f5..e135e2fee6 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -2298,9 +2298,7 @@ async def test_streaming_error_recovery_resets_counter(chat_client_base: Support class TerminateLoopMiddleware(FunctionMiddleware): """Middleware that raises MiddlewareTermination to exit the function calling loop.""" - async def process( - self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]] - ) -> None: + async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None: # Set result to a simple value - the framework will wrap it in FunctionResultContent context.result = "terminated by middleware" raise MiddlewareTermination @@ -2355,14 +2353,12 @@ async def test_terminate_loop_single_function_call(chat_client_base: SupportsCha class SelectiveTerminateMiddleware(FunctionMiddleware): """Only terminates for terminating_function.""" - async def process( - self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]] - ) -> None: + async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None: if context.function.name == "terminating_function": # Set result to a simple value - the framework will wrap it in FunctionResultContent context.result = "terminated by middleware" raise MiddlewareTermination - await next_handler(context) + await next_handler() async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client_base: SupportsChatGetResponse): diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index 41c15b2c70..e5bd23751f 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -135,12 +135,12 @@ class TestAgentMiddlewarePipeline: """Test cases for AgentMiddlewarePipeline.""" class PreNextTerminateMiddleware(AgentMiddleware): - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: raise MiddlewareTermination class PostNextTerminateMiddleware(AgentMiddleware): async def process(self, context: AgentContext, call_next: Any) -> None: - await call_next(context) + await call_next() raise MiddlewareTermination def test_init_empty(self) -> None: @@ -157,8 +157,8 @@ class TestAgentMiddlewarePipeline: def test_init_with_function_middleware(self) -> None: """Test AgentMiddlewarePipeline initialization with function-based middleware.""" - async def test_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: - await call_next(context) + async def test_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() pipeline = AgentMiddlewarePipeline(test_middleware) assert pipeline.has_middlewares @@ -185,11 +185,9 @@ class TestAgentMiddlewarePipeline: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = OrderTrackingMiddleware("test") @@ -238,11 +236,9 @@ class TestAgentMiddlewarePipeline: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = StreamOrderTrackingMiddleware("test") @@ -367,12 +363,10 @@ class TestAgentMiddlewarePipeline: captured_thread = None class ThreadCapturingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: nonlocal captured_thread captured_thread = context.thread - await call_next(context) + await call_next() middleware = ThreadCapturingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) @@ -394,12 +388,10 @@ class TestAgentMiddlewarePipeline: captured_thread = "not_none" # Use string to distinguish from None class ThreadCapturingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: nonlocal captured_thread captured_thread = context.thread - await call_next(context) + await call_next() middleware = ThreadCapturingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) @@ -425,7 +417,7 @@ class TestFunctionMiddlewarePipeline: class PostNextTerminateFunctionMiddleware(FunctionMiddleware): async def process(self, context: FunctionInvocationContext, call_next: Any) -> None: - await call_next(context) + await call_next() raise MiddlewareTermination async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool[Any, Any]) -> None: @@ -482,10 +474,8 @@ class TestFunctionMiddlewarePipeline: def test_init_with_function_middleware(self) -> None: """Test FunctionMiddlewarePipeline initialization with function-based middleware.""" - async def test_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] - ) -> None: - await call_next(context) + async def test_middleware(context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() pipeline = FunctionMiddlewarePipeline(test_middleware) assert pipeline.has_middlewares @@ -515,10 +505,10 @@ class TestFunctionMiddlewarePipeline: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = OrderTrackingFunctionMiddleware("test") @@ -541,12 +531,12 @@ class TestChatMiddlewarePipeline: """Test cases for ChatMiddlewarePipeline.""" class PreNextTerminateChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: raise MiddlewareTermination class PostNextTerminateChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: - await call_next(context) + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() raise MiddlewareTermination def test_init_empty(self) -> None: @@ -563,8 +553,8 @@ class TestChatMiddlewarePipeline: def test_init_with_function_middleware(self) -> None: """Test ChatMiddlewarePipeline initialization with function-based middleware.""" - async def test_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: - await call_next(context) + async def test_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() pipeline = ChatMiddlewarePipeline(test_middleware) assert pipeline.has_middlewares @@ -592,9 +582,9 @@ class TestChatMiddlewarePipeline: def __init__(self, name: str): self.name = name - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = OrderTrackingChatMiddleware("test") @@ -644,9 +634,9 @@ class TestChatMiddlewarePipeline: def __init__(self, name: str): self.name = name - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = StreamOrderTrackingChatMiddleware("test") @@ -774,12 +764,10 @@ class TestClassBasedMiddleware: metadata_updates: list[str] = [] class MetadataAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: context.metadata["before"] = True metadata_updates.append("before") - await call_next(context) + await call_next() context.metadata["after"] = True metadata_updates.append("after") @@ -807,11 +795,11 @@ class TestClassBasedMiddleware: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: context.metadata["before"] = True metadata_updates.append("before") - await call_next(context) + await call_next() context.metadata["after"] = True metadata_updates.append("after") @@ -839,12 +827,10 @@ class TestFunctionBasedMiddleware: """Test function-based agent middleware.""" execution_order: list[str] = [] - async def test_agent_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def test_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_before") context.metadata["function_middleware"] = True - await call_next(context) + await call_next() execution_order.append("function_after") pipeline = AgentMiddlewarePipeline(test_agent_middleware) @@ -866,11 +852,11 @@ class TestFunctionBasedMiddleware: execution_order: list[str] = [] async def test_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_before") context.metadata["function_middleware"] = True - await call_next(context) + await call_next() execution_order.append("function_after") pipeline = FunctionMiddlewarePipeline(test_function_middleware) @@ -896,18 +882,14 @@ class TestMixedMiddleware: execution_order: list[str] = [] class ClassMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("class_before") - await call_next(context) + await call_next() execution_order.append("class_after") - async def function_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def function_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_before") - await call_next(context) + await call_next() execution_order.append("function_after") pipeline = AgentMiddlewarePipeline(ClassMiddleware(), function_middleware) @@ -931,17 +913,17 @@ class TestMixedMiddleware: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("class_before") - await call_next(context) + await call_next() execution_order.append("class_after") async def function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_before") - await call_next(context) + await call_next() execution_order.append("function_after") pipeline = FunctionMiddlewarePipeline(ClassMiddleware(), function_middleware) @@ -962,16 +944,14 @@ class TestMixedMiddleware: execution_order: list[str] = [] class ClassChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("class_before") - await call_next(context) + await call_next() execution_order.append("class_after") - async def function_chat_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def function_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_before") - await call_next(context) + await call_next() execution_order.append("function_after") pipeline = ChatMiddlewarePipeline(ClassChatMiddleware(), function_chat_middleware) @@ -997,27 +977,21 @@ class TestMultipleMiddlewareOrdering: execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") class SecondMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") class ThirdMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("third_before") - await call_next(context) + await call_next() execution_order.append("third_after") middleware = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()] @@ -1051,20 +1025,20 @@ class TestMultipleMiddlewareOrdering: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") class SecondMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") middleware = [FirstMiddleware(), SecondMiddleware()] @@ -1087,21 +1061,21 @@ class TestMultipleMiddlewareOrdering: execution_order: list[str] = [] class FirstChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") class SecondChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") class ThirdChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("third_before") - await call_next(context) + await call_next() execution_order.append("third_after") middleware = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()] @@ -1136,9 +1110,7 @@ class TestContextContentValidation: """Test that agent context contains expected data.""" class ContextValidationMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Verify context has all expected attributes assert hasattr(context, "agent") assert hasattr(context, "messages") @@ -1156,7 +1128,7 @@ class TestContextContentValidation: # Add custom metadata context.metadata["validated"] = True - await call_next(context) + await call_next() middleware = ContextValidationMiddleware() pipeline = AgentMiddlewarePipeline(middleware) @@ -1178,7 +1150,7 @@ class TestContextContentValidation: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Verify context has all expected attributes assert hasattr(context, "function") @@ -1194,7 +1166,7 @@ class TestContextContentValidation: # Add custom metadata context.metadata["validated"] = True - await call_next(context) + await call_next() middleware = ContextValidationMiddleware() pipeline = FunctionMiddlewarePipeline(middleware) @@ -1213,7 +1185,7 @@ class TestContextContentValidation: """Test that chat context contains expected data.""" class ChatContextValidationMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Verify context has all expected attributes assert hasattr(context, "client") assert hasattr(context, "messages") @@ -1235,7 +1207,7 @@ class TestContextContentValidation: # Add custom metadata context.metadata["validated"] = True - await call_next(context) + await call_next() middleware = ChatContextValidationMiddleware() pipeline = ChatMiddlewarePipeline(middleware) @@ -1260,11 +1232,9 @@ class TestStreamingScenarios: streaming_flags: list[bool] = [] class StreamingFlagMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: streaming_flags.append(context.stream) - await call_next(context) + await call_next() middleware = StreamingFlagMiddleware() pipeline = AgentMiddlewarePipeline(middleware) @@ -1302,11 +1272,9 @@ class TestStreamingScenarios: chunks_processed: list[str] = [] class StreamProcessingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: chunks_processed.append("before_stream") - await call_next(context) + await call_next() chunks_processed.append("after_stream") middleware = StreamProcessingMiddleware() @@ -1345,9 +1313,9 @@ class TestStreamingScenarios: streaming_flags: list[bool] = [] class ChatStreamingFlagMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: streaming_flags.append(context.stream) - await call_next(context) + await call_next() middleware = ChatStreamingFlagMiddleware() pipeline = ChatMiddlewarePipeline(middleware) @@ -1386,9 +1354,9 @@ class TestStreamingScenarios: chunks_processed: list[str] = [] class ChatStreamProcessingMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: chunks_processed.append("before_stream") - await call_next(context) + await call_next() chunks_processed.append("after_stream") middleware = ChatStreamProcessingMiddleware() @@ -1436,24 +1404,22 @@ class FunctionTestArgs(BaseModel): class TestAgentMiddleware(AgentMiddleware): """Test implementation of AgentMiddleware.""" - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: - await call_next(context) + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() class TestFunctionMiddleware(FunctionMiddleware): """Test implementation of FunctionMiddleware.""" - async def process( - self, context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] - ) -> None: - await call_next(context) + async def process(self, context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() class TestChatMiddleware(ChatMiddleware): """Test implementation of ChatMiddleware.""" - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: - await call_next(context) + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() class MockFunctionArgs(BaseModel): @@ -1469,9 +1435,7 @@ class TestMiddlewareExecutionControl: """Test that when agent middleware doesn't call next(), no execution happens.""" class NoNextMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass @@ -1498,9 +1462,7 @@ class TestMiddlewareExecutionControl: """Test that when agent middleware doesn't call next(), no streaming execution happens.""" class NoNextStreamingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass @@ -1537,7 +1499,7 @@ class TestMiddlewareExecutionControl: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Don't call next() - this should prevent any execution pass @@ -1566,18 +1528,14 @@ class TestMiddlewareExecutionControl: execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first") # Don't call next() - this should stop the pipeline class SecondMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second") - await call_next(context) + await call_next() pipeline = AgentMiddlewarePipeline(FirstMiddleware(), SecondMiddleware()) messages = [Message(role="user", text="test")] @@ -1601,7 +1559,7 @@ class TestMiddlewareExecutionControl: """Test that when chat middleware doesn't call next(), no execution happens.""" class NoNextChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass @@ -1629,7 +1587,7 @@ class TestMiddlewareExecutionControl: """Test that when chat middleware doesn't call next(), no streaming execution happens.""" class NoNextStreamingChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass @@ -1670,14 +1628,14 @@ class TestMiddlewareExecutionControl: execution_order: list[str] = [] class FirstChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first") # Don't call next() - this should stop the pipeline class SecondChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second") - await call_next(context) + await call_next() pipeline = ChatMiddlewarePipeline(FirstChatMiddleware(), SecondChatMiddleware()) messages = [Message(role="user", text="test")] diff --git a/python/packages/core/tests/core/test_middleware_context_result.py b/python/packages/core/tests/core/test_middleware_context_result.py index d17e99a85e..c5744fdca5 100644 --- a/python/packages/core/tests/core/test_middleware_context_result.py +++ b/python/packages/core/tests/core/test_middleware_context_result.py @@ -43,11 +43,9 @@ class TestResultOverrideMiddleware: override_response = AgentResponse(messages=[Message(role="assistant", text="overridden response")]) class ResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Execute the pipeline first, then override the response - await call_next(context) + await call_next() context.result = override_response middleware = ResponseOverrideMiddleware() @@ -79,11 +77,9 @@ class TestResultOverrideMiddleware: yield AgentResponseUpdate(contents=[Content.from_text(text=" stream")]) class StreamResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Execute the pipeline first, then override the response stream - await call_next(context) + await call_next() context.result = ResponseStream(override_stream()) middleware = StreamResponseOverrideMiddleware() @@ -115,10 +111,10 @@ class TestResultOverrideMiddleware: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Execute the pipeline first, then override the result - await call_next(context) + await call_next() context.result = override_result middleware = ResultOverrideMiddleware() @@ -145,11 +141,9 @@ class TestResultOverrideMiddleware: mock_chat_client = MockChatClient() class ChatAgentResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Always call next() first to allow execution - await call_next(context) + await call_next() # Then conditionally override based on content if any("special" in msg.text for msg in context.messages if msg.text): context.result = AgentResponse( @@ -184,15 +178,13 @@ class TestResultOverrideMiddleware: yield AgentResponseUpdate(contents=[Content.from_text(text=" response!")]) class ChatAgentStreamOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Check if we want to override BEFORE calling next to avoid creating unused streams if any("custom stream" in msg.text for msg in context.messages if msg.text): context.result = ResponseStream(custom_stream()) return # Don't call next() - we're overriding the entire result # Normal case - let the agent handle it - await call_next(context) + await call_next() # Create Agent with override middleware middleware = ChatAgentStreamOverrideMiddleware() @@ -223,12 +215,10 @@ class TestResultOverrideMiddleware: """Test that when agent middleware conditionally doesn't call next(), no execution happens.""" class ConditionalNoNextMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Only call next() if message contains "execute" if any("execute" in msg.text for msg in context.messages if msg.text): - await call_next(context) + await call_next() # Otherwise, don't call next() - no execution should happen middleware = ConditionalNoNextMiddleware() @@ -269,13 +259,13 @@ class TestResultOverrideMiddleware: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Only call next() if argument name contains "execute" args = context.arguments assert isinstance(args, FunctionTestArgs) if "execute" in args.name: - await call_next(context) + await call_next() # Otherwise, don't call next() - no execution should happen middleware = ConditionalNoNextFunctionMiddleware() @@ -318,14 +308,12 @@ class TestResultObservability: observed_responses: list[AgentResponse] = [] class ObservabilityMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Context should be empty before next() assert context.result is None # Call next to execute - await call_next(context) + await call_next() # Context should now contain the response for observability assert context.result is not None @@ -355,13 +343,13 @@ class TestResultObservability: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Context should be empty before next() assert context.result is None # Call next to execute - await call_next(context) + await call_next() # Context should now contain the result for observability assert context.result is not None @@ -386,11 +374,9 @@ class TestResultObservability: """Test that middleware can override response after observing execution.""" class PostExecutionOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Call next to execute first - await call_next(context) + await call_next() # Now observe and conditionally override assert context.result is not None @@ -423,10 +409,10 @@ class TestResultObservability: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Call next to execute first - await call_next(context) + await call_next() # Now observe and conditionally override assert context.result is not None diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 17f0faf4f0..597ca12dbd 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -44,11 +44,9 @@ class TestChatAgentClassBasedMiddleware: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") # Create Agent with middleware @@ -76,9 +74,9 @@ class TestChatAgentClassBasedMiddleware: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: - await call_next(context) + await call_next() middleware = TrackingFunctionMiddleware() Agent(client=client, middleware=[middleware]) @@ -96,10 +94,10 @@ class TestChatAgentClassBasedMiddleware: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") middleware = TrackingFunctionMiddleware("function_middleware") @@ -122,13 +120,11 @@ class TestChatAgentFunctionBasedMiddleware: execution_order: list[str] = [] class PreTerminationMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("middleware_before") raise MiddlewareTermination # Code after raise is unreachable - await call_next(context) + await call_next() execution_order.append("middleware_after") # Create Agent with terminating middleware @@ -153,11 +149,9 @@ class TestChatAgentFunctionBasedMiddleware: execution_order: list[str] = [] class PostTerminationMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("middleware_before") - await call_next(context) + await call_next() execution_order.append("middleware_after") context.terminate = True @@ -193,12 +187,12 @@ class TestChatAgentFunctionBasedMiddleware: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("middleware_before") context.terminate = True # We call next() but since terminate=True, subsequent middleware and handler should not execute - await call_next(context) + await call_next() execution_order.append("middleware_after") Agent(client=client, middleware=[PreTerminationFunctionMiddleware()], tools=[]) @@ -211,10 +205,10 @@ class TestChatAgentFunctionBasedMiddleware: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("middleware_before") - await call_next(context) + await call_next() execution_order.append("middleware_after") context.terminate = True @@ -224,11 +218,9 @@ class TestChatAgentFunctionBasedMiddleware: """Test function-based agent middleware with Agent.""" execution_order: list[str] = [] - async def tracking_agent_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def tracking_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("agent_function_before") - await call_next(context) + await call_next() execution_order.append("agent_function_after") # Create Agent with function middleware @@ -252,9 +244,9 @@ class TestChatAgentFunctionBasedMiddleware: """Test function-based function middleware with Agent.""" async def tracking_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: - await call_next(context) + await call_next() Agent(client=client, middleware=[tracking_function_middleware]) @@ -265,10 +257,10 @@ class TestChatAgentFunctionBasedMiddleware: execution_order: list[str] = [] async def tracking_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_function_before") - await call_next(context) + await call_next() execution_order.append("function_function_after") agent = Agent(client=chat_client_base, middleware=[tracking_function_middleware]) @@ -290,12 +282,10 @@ class TestChatAgentStreamingMiddleware: streaming_flags: list[bool] = [] class StreamingTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("middleware_before") streaming_flags.append(context.stream) - await call_next(context) + await call_next() execution_order.append("middleware_after") # Create Agent with middleware @@ -334,11 +324,9 @@ class TestChatAgentStreamingMiddleware: streaming_flags: list[bool] = [] class FlagTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: streaming_flags.append(context.stream) - await call_next(context) + await call_next() # Create Agent with middleware middleware = FlagTrackingMiddleware() @@ -368,11 +356,9 @@ class TestChatAgentMultipleMiddlewareOrdering: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") # Create multiple middleware @@ -400,35 +386,31 @@ class TestChatAgentMultipleMiddlewareOrdering: execution_order: list[str] = [] class ClassAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("class_agent_before") - await call_next(context) + await call_next() execution_order.append("class_agent_after") - async def function_agent_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def function_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_agent_before") - await call_next(context) + await call_next() execution_order.append("function_agent_after") class ClassFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("class_function_before") - await call_next(context) + await call_next() execution_order.append("class_function_after") async def function_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_function_before") - await call_next(context) + await call_next() execution_order.append("function_function_after") agent = Agent( @@ -447,25 +429,21 @@ class TestChatAgentMultipleMiddlewareOrdering: execution_order: list[str] = [] class ClassAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("class_agent_before") - await call_next(context) + await call_next() execution_order.append("class_agent_after") - async def function_agent_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def function_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_agent_before") - await call_next(context) + await call_next() execution_order.append("function_agent_after") async def function_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_function_before") - await call_next(context) + await call_next() execution_order.append("function_function_after") agent = Agent( @@ -521,10 +499,10 @@ class TestChatAgentFunctionMiddlewareWithTools: async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append(f"{self.name}_before") - await call_next(context) + await call_next() execution_order.append(f"{self.name}_after") # Set up mock to return a function call first, then a regular response @@ -583,10 +561,10 @@ class TestChatAgentFunctionMiddlewareWithTools: execution_order: list[str] = [] async def tracking_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_middleware_before") - await call_next(context) + await call_next() execution_order.append("function_middleware_after") # Set up mock to return a function call first, then a regular response @@ -647,20 +625,20 @@ class TestChatAgentFunctionMiddlewareWithTools: async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("agent_middleware_before") - await call_next(context) + await call_next() execution_order.append("agent_middleware_after") class TrackingFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("function_middleware_before") - await call_next(context) + await call_next() execution_order.append("function_middleware_after") # Set up mock to return a function call first, then a regular response @@ -728,7 +706,7 @@ class TestChatAgentFunctionMiddlewareWithTools: @function_middleware async def kwargs_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: nonlocal middleware_called middleware_called = True @@ -748,7 +726,7 @@ class TestChatAgentFunctionMiddlewareWithTools: modified_kwargs["new_param"] = context.kwargs.get("new_param") modified_kwargs["custom_param"] = context.kwargs.get("custom_param") - await call_next(context) + await call_next() chat_client_base.run_responses = [ ChatResponse( @@ -801,9 +779,9 @@ class TestMiddlewareDynamicRebuild: self.name = name self.execution_log = execution_log - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: self.execution_log.append(f"{self.name}_start") - await call_next(context) + await call_next() self.execution_log.append(f"{self.name}_end") async def test_middleware_dynamic_rebuild_non_streaming(self, client: "MockChatClient") -> None: @@ -924,9 +902,9 @@ class TestRunLevelMiddleware: self.name = name self.execution_log = execution_log - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: self.execution_log.append(f"{self.name}_start") - await call_next(context) + await call_next() self.execution_log.append(f"{self.name}_end") async def test_run_level_middleware_isolation(self, client: "MockChatClient") -> None: @@ -976,29 +954,25 @@ class TestRunLevelMiddleware: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") # Set metadata to pass information to run middleware context.metadata[f"{self.name}_key"] = f"{self.name}_value" - await call_next(context) + await call_next() execution_log.append(f"{self.name}_end") class MetadataRunMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") # Read metadata set by agent middleware for key, value in context.metadata.items(): metadata_log.append(f"{self.name}_reads_{key}:{value}") # Set run-level metadata context.metadata[f"{self.name}_key"] = f"{self.name}_value" - await call_next(context) + await call_next() execution_log.append(f"{self.name}_end") # Create agent with agent-level middleware @@ -1049,12 +1023,10 @@ class TestRunLevelMiddleware: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") streaming_flags.append(context.stream) - await call_next(context) + await call_next() execution_log.append(f"{self.name}_end") # Create agent without agent-level middleware @@ -1093,48 +1065,44 @@ class TestRunLevelMiddleware: # Agent-level middleware class AgentLevelAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_log.append("agent_level_agent_start") context.metadata["agent_level_agent"] = "processed" - await call_next(context) + await call_next() execution_log.append("agent_level_agent_end") class AgentLevelFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_log.append("agent_level_function_start") context.metadata["agent_level_function"] = "processed" - await call_next(context) + await call_next() execution_log.append("agent_level_function_end") # Run-level middleware class RunLevelAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_log.append("run_level_agent_start") # Verify agent-level middleware metadata is available assert "agent_level_agent" in context.metadata context.metadata["run_level_agent"] = "processed" - await call_next(context) + await call_next() execution_log.append("run_level_agent_end") class RunLevelFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_log.append("run_level_function_start") # Verify agent-level function middleware metadata is available assert "agent_level_function" in context.metadata context.metadata["run_level_function"] = "processed" - await call_next(context) + await call_next() execution_log.append("run_level_function_end") # Create tool function for testing function middleware @@ -1217,18 +1185,16 @@ class TestMiddlewareDecoratorLogic: execution_order: list[str] = [] @agent_middleware - async def matching_agent_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def matching_agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("decorator_type_match_agent") - await call_next(context) + await call_next() @function_middleware async def matching_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("decorator_type_match_function") - await call_next(context) + await call_next() # Create tool function for testing function middleware def custom_tool(message: str) -> str: @@ -1282,7 +1248,7 @@ class TestMiddlewareDecoratorLogic: context: FunctionInvocationContext, # Wrong type for @agent_middleware call_next: Any, ) -> None: - await call_next(context) + await call_next() agent = Agent(client=client, middleware=[mismatched_middleware]) await agent.run([Message(role="user", text="test")]) @@ -1294,12 +1260,12 @@ class TestMiddlewareDecoratorLogic: @agent_middleware async def decorator_only_agent(context: Any, call_next: Any) -> None: # No type annotation execution_order.append("decorator_only_agent") - await call_next(context) + await call_next() @function_middleware async def decorator_only_function(context: Any, call_next: Any) -> None: # No type annotation execution_order.append("decorator_only_function") - await call_next(context) + await call_next() # Create tool function for testing function middleware def custom_tool(message: str) -> str: @@ -1346,16 +1312,16 @@ class TestMiddlewareDecoratorLogic: execution_order: list[str] = [] # No decorator - async def type_only_agent(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def type_only_agent(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("type_only_agent") - await call_next(context) + await call_next() # No decorator async def type_only_function( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("type_only_function") - await call_next(context) + await call_next() # Create tool function for testing function middleware def custom_tool(message: str) -> str: @@ -1399,7 +1365,7 @@ class TestMiddlewareDecoratorLogic: """Neither decorator nor parameter type specified - should throw exception.""" async def no_info_middleware(context: Any, call_next: Any) -> None: # No decorator, no type - await call_next(context) + await call_next() # Should raise MiddlewareException with pytest.raises(MiddlewareException, match="Cannot determine middleware type"): @@ -1447,9 +1413,7 @@ class TestChatAgentThreadBehavior: thread_states: list[dict[str, Any]] = [] class ThreadTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Capture state before next() call thread_messages = [] if context.thread and context.thread.message_store: @@ -1464,7 +1428,7 @@ class TestChatAgentThreadBehavior: } thread_states.append(before_state) - await call_next(context) + await call_next() # Capture state after next() call thread_messages_after = [] @@ -1560,9 +1524,9 @@ class TestChatAgentChatMiddleware: execution_order: list[str] = [] class TrackingChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("chat_middleware_before") - await call_next(context) + await call_next() execution_order.append("chat_middleware_after") # Create Agent with chat middleware @@ -1588,11 +1552,9 @@ class TestChatAgentChatMiddleware: """Test function-based chat middleware with Agent.""" execution_order: list[str] = [] - async def tracking_chat_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def tracking_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("chat_middleware_before") - await call_next(context) + await call_next() execution_order.append("chat_middleware_after") # Create Agent with function-based chat middleware @@ -1617,9 +1579,7 @@ class TestChatAgentChatMiddleware: """Test that chat middleware can modify messages before sending to model.""" @chat_middleware - async def message_modifier_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def message_modifier_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Modify the first message by adding a prefix if context.messages: for idx, msg in enumerate(context.messages): @@ -1628,7 +1588,7 @@ class TestChatAgentChatMiddleware: original_text = msg.text or "" context.messages[idx] = Message(role=msg.role, text=f"MODIFIED: {original_text}") break - await call_next(context) + await call_next() # Create Agent with message-modifying middleware client = MockBaseChatClient() @@ -1646,9 +1606,7 @@ class TestChatAgentChatMiddleware: """Test that chat middleware can override the response.""" @chat_middleware - async def response_override_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def response_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Override the response without calling next() context.result = ChatResponse( messages=[Message(role="assistant", text="MiddlewareTypes overridden response")], @@ -1675,15 +1633,15 @@ class TestChatAgentChatMiddleware: execution_order: list[str] = [] @chat_middleware - async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") @chat_middleware - async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") # Create Agent with multiple chat middleware @@ -1709,10 +1667,10 @@ class TestChatAgentChatMiddleware: streaming_flags: list[bool] = [] class StreamingTrackingChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("streaming_chat_before") streaming_flags.append(context.stream) - await call_next(context) + await call_next() execution_order.append("streaming_chat_after") # Create Agent with chat middleware @@ -1749,13 +1707,13 @@ class TestChatAgentChatMiddleware: execution_order: list[str] = [] class PreTerminationChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("middleware_before") # Set a custom response since we're terminating context.result = ChatResponse(messages=[Message(role="assistant", text="Terminated by middleware")]) raise MiddlewareTermination # We call next() but since terminate=True, execution should stop - await call_next(context) + await call_next() execution_order.append("middleware_after") # Create Agent with terminating middleware @@ -1777,9 +1735,9 @@ class TestChatAgentChatMiddleware: execution_order: list[str] = [] class PostTerminationChatMiddleware(ChatMiddleware): - async def process(self, context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("middleware_before") - await call_next(context) + await call_next() execution_order.append("middleware_after") context.terminate = True @@ -1804,21 +1762,21 @@ class TestChatAgentChatMiddleware: """Test Agent with combined middleware types.""" execution_order: list[str] = [] - async def agent_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def agent_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("agent_middleware_before") - await call_next(context) + await call_next() execution_order.append("agent_middleware_after") - async def chat_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("chat_middleware_before") - await call_next(context) + await call_next() execution_order.append("chat_middleware_after") async def function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("function_middleware_before") - await call_next(context) + await call_next() execution_order.append("function_middleware_after") # Create Agent with function middleware and tools @@ -1842,9 +1800,7 @@ class TestChatAgentChatMiddleware: modified_kwargs: dict[str, Any] = {} @agent_middleware - async def kwargs_middleware( - context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] - ) -> None: + async def kwargs_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Capture the original kwargs captured_kwargs.update(context.kwargs) @@ -1856,7 +1812,7 @@ class TestChatAgentChatMiddleware: # Store modified kwargs for verification modified_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Create Agent with agent middleware client = MockBaseChatClient() @@ -1895,10 +1851,10 @@ class TestChatAgentChatMiddleware: # class TrackingMiddleware(AgentMiddleware): # async def process( -# self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]] +# self, context: AgentContext, call_next: Callable[[], Awaitable[None]] # ) -> None: # execution_order.append("before") -# await call_next(context) +# await call_next() # execution_order.append("after") # @use_agent_middleware diff --git a/python/packages/core/tests/core/test_middleware_with_chat.py b/python/packages/core/tests/core/test_middleware_with_chat.py index 3c9d0246c7..62a168ccb0 100644 --- a/python/packages/core/tests/core/test_middleware_with_chat.py +++ b/python/packages/core/tests/core/test_middleware_with_chat.py @@ -32,10 +32,10 @@ class TestChatMiddleware: async def process( self, context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("chat_middleware_before") - await call_next(context) + await call_next() execution_order.append("chat_middleware_after") # Add middleware to chat client @@ -58,11 +58,9 @@ class TestChatMiddleware: execution_order: list[str] = [] @chat_middleware - async def logging_chat_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def logging_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("function_middleware_before") - await call_next(context) + await call_next() execution_order.append("function_middleware_after") # Add middleware to chat client @@ -84,14 +82,12 @@ class TestChatMiddleware: """Test that chat middleware can modify messages before sending to model.""" @chat_middleware - async def message_modifier_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def message_modifier_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Modify the first message by adding a prefix if context.messages and len(context.messages) > 0: original_text = context.messages[0].text or "" context.messages[0] = Message(role=context.messages[0].role, text=f"MODIFIED: {original_text}") - await call_next(context) + await call_next() # Add middleware to chat client chat_client_base.chat_middleware = [message_modifier_middleware] @@ -110,9 +106,7 @@ class TestChatMiddleware: """Test that chat middleware can override the response.""" @chat_middleware - async def response_override_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def response_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Override the response without calling next() context.result = ChatResponse( messages=[Message(role="assistant", text="MiddlewareTypes overridden response")], @@ -138,15 +132,15 @@ class TestChatMiddleware: execution_order: list[str] = [] @chat_middleware - async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") @chat_middleware - async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") # Add middleware to chat client (order should be preserved) @@ -173,11 +167,9 @@ class TestChatMiddleware: execution_order: list[str] = [] @chat_middleware - async def agent_level_chat_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def agent_level_chat_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("agent_chat_middleware_before") - await call_next(context) + await call_next() execution_order.append("agent_chat_middleware_after") client = MockBaseChatClient() @@ -205,15 +197,15 @@ class TestChatMiddleware: execution_order: list[str] = [] @chat_middleware - async def first_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("first_before") - await call_next(context) + await call_next() execution_order.append("first_after") @chat_middleware - async def second_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("second_before") - await call_next(context) + await call_next() execution_order.append("second_after") # Create Agent with multiple chat middleware @@ -240,9 +232,7 @@ class TestChatMiddleware: execution_order: list[str] = [] @chat_middleware - async def streaming_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def streaming_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_order.append("streaming_before") # Verify it's a streaming context assert context.stream is True @@ -254,7 +244,7 @@ class TestChatMiddleware: return update context.stream_transform_hooks.append(upper_case_update) - await call_next(context) + await call_next() execution_order.append("streaming_after") # Add middleware to chat client @@ -278,11 +268,9 @@ class TestChatMiddleware: execution_count = {"count": 0} @chat_middleware - async def counting_middleware( - context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]] - ) -> None: + async def counting_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: execution_count["count"] += 1 - await call_next(context) + await call_next() # First call with run-level middleware messages = [Message(role="user", text="first message")] @@ -310,7 +298,7 @@ class TestChatMiddleware: modified_kwargs: dict[str, Any] = {} @chat_middleware - async def kwargs_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: + async def kwargs_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: # Capture the original kwargs captured_kwargs.update(context.kwargs) @@ -322,7 +310,7 @@ class TestChatMiddleware: # Store modified kwargs for verification modified_kwargs.update(context.kwargs) - await call_next(context) + await call_next() # Add middleware to chat client chat_client_base.chat_middleware = [kwargs_middleware] @@ -355,11 +343,11 @@ class TestChatMiddleware: @function_middleware async def test_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: nonlocal execution_order execution_order.append(f"function_middleware_before_{context.function.name}") - await call_next(context) + await call_next() execution_order.append(f"function_middleware_after_{context.function.name}") # Define a simple tool function @@ -421,10 +409,10 @@ class TestChatMiddleware: @function_middleware async def run_level_function_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: execution_order.append("run_level_function_middleware_before") - await call_next(context) + await call_next() execution_order.append("run_level_function_middleware_after") # Define a simple tool function diff --git a/python/packages/ollama/tests/test_ollama_chat_client.py b/python/packages/ollama/tests/test_ollama_chat_client.py index d65836b2bc..8d179b982d 100644 --- a/python/packages/ollama/tests/test_ollama_chat_client.py +++ b/python/packages/ollama/tests/test_ollama_chat_client.py @@ -207,7 +207,7 @@ def test_serialize(ollama_unit_test_env: dict[str, str]) -> None: def test_chat_middleware(ollama_unit_test_env: dict[str, str]) -> None: @chat_middleware async def sample_middleware(context, call_next): - await call_next(context) + await call_next() ollama_chat_client = OllamaChatClient(middleware=[sample_middleware]) assert len(ollama_chat_client.middleware) == 1 diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index 37f499d763..e574528395 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -129,11 +129,11 @@ class _AutoHandoffMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Intercept matching handoff tool calls and inject synthetic results.""" if context.function.name not in self._handoff_functions: - await call_next(context) + await call_next() return from agent_framework._middleware import MiddlewareTermination diff --git a/python/packages/purview/agent_framework_purview/_middleware.py b/python/packages/purview/agent_framework_purview/_middleware.py index 10e0443b0b..2da8de84ee 100644 --- a/python/packages/purview/agent_framework_purview/_middleware.py +++ b/python/packages/purview/agent_framework_purview/_middleware.py @@ -65,7 +65,7 @@ class PurviewPolicyMiddleware(AgentMiddleware): async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # type: ignore[override] resolved_user_id: str | None = None try: @@ -92,7 +92,7 @@ class PurviewPolicyMiddleware(AgentMiddleware): if not self._settings.ignore_exceptions: raise - await call_next(context) + await call_next() try: # Post (response) check only if we have a normal AgentResponse @@ -162,7 +162,7 @@ class PurviewChatPolicyMiddleware(ChatMiddleware): async def process( self, context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # type: ignore[override] resolved_user_id: str | None = None try: @@ -187,7 +187,7 @@ class PurviewChatPolicyMiddleware(ChatMiddleware): if not self._settings.ignore_exceptions: raise - await call_next(context) + await call_next() try: # Post (response) evaluation only if non-streaming and we have messages result shape diff --git a/python/packages/purview/tests/purview/test_chat_middleware.py b/python/packages/purview/tests/purview/test_chat_middleware.py index 677e3e277b..bc9be01e1f 100644 --- a/python/packages/purview/tests/purview/test_chat_middleware.py +++ b/python/packages/purview/tests/purview/test_chat_middleware.py @@ -49,7 +49,7 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: next_called = False - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: nonlocal next_called next_called = True @@ -57,7 +57,7 @@ class TestPurviewChatPolicyMiddleware: def __init__(self): self.messages = [Message(role="assistant", text="Hi there")] - ctx.result = Result() + chat_context.result = Result() await middleware.process(chat_context, mock_next) assert next_called @@ -67,7 +67,7 @@ class TestPurviewChatPolicyMiddleware: async def test_blocks_prompt(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None: with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")): - async def mock_next(ctx: ChatContext) -> None: # should not run + async def mock_next() -> None: # should not run raise AssertionError("next should not be called when prompt blocked") with pytest.raises(MiddlewareTermination): @@ -88,12 +88,12 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: class Result: def __init__(self): self.messages = [Message(role="assistant", text="Sensitive output")] # pragma: no cover - ctx.result = Result() + chat_context.result = Result() await middleware.process(chat_context, mock_next) assert call_state["count"] == 2 @@ -114,8 +114,8 @@ class TestPurviewChatPolicyMiddleware: ) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: ChatContext) -> None: - ctx.result = MagicMock() + async def mock_next() -> None: + streaming_context.result = MagicMock() await middleware.process(streaming_context, mock_next) assert mock_proc.call_count == 1 @@ -138,10 +138,10 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Response")] - ctx.result = result + chat_context.result = result await middleware.process(chat_context, mock_next) @@ -162,10 +162,10 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Response")] - ctx.result = result + chat_context.result = result await middleware.process(chat_context, mock_next) @@ -194,7 +194,7 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: raise AssertionError("next should not be called") # Should raise the exception @@ -224,10 +224,10 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="OK")] - ctx.result = result + context.result = result with pytest.raises(PurviewPaymentRequiredError): await middleware.process(context, mock_next) @@ -249,7 +249,7 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Response")] context.result = result @@ -265,9 +265,9 @@ class TestPurviewChatPolicyMiddleware: """Test middleware handles result that doesn't have messages attribute.""" with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: # Set result to something without messages attribute - ctx.result = "Some string result" + chat_context.result = "Some string result" await middleware.process(chat_context, mock_next) @@ -289,7 +289,7 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Response")] context.result = result @@ -313,7 +313,7 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=ValueError("boom")): - async def mock_next(_: ChatContext) -> None: + async def mock_next() -> None: raise AssertionError("next should not be called") with pytest.raises(ValueError, match="boom"): @@ -342,10 +342,10 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="OK")] - ctx.result = result + context.result = result with pytest.raises(ValueError, match="post"): await middleware.process(context, mock_next) @@ -361,10 +361,10 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Hi")] - ctx.result = result + context.result = result await middleware.process(context, mock_next) @@ -382,10 +382,10 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Hi")] - ctx.result = result + context.result = result await middleware.process(context, mock_next) @@ -401,10 +401,10 @@ class TestPurviewChatPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: ChatContext) -> None: + async def mock_next() -> None: result = MagicMock() result.messages = [Message(role="assistant", text="Response")] - ctx.result = result + context.result = result await middleware.process(context, mock_next) diff --git a/python/packages/purview/tests/purview/test_middleware.py b/python/packages/purview/tests/purview/test_middleware.py index ff77331155..98dafab1e1 100644 --- a/python/packages/purview/tests/purview/test_middleware.py +++ b/python/packages/purview/tests/purview/test_middleware.py @@ -55,10 +55,10 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): next_called = False - async def mock_next(ctx: AgentContext) -> None: + async def mock_next() -> None: nonlocal next_called next_called = True - ctx.result = AgentResponse(messages=[Message(role="assistant", text="I'm good, thanks!")]) + context.result = AgentResponse(messages=[Message(role="assistant", text="I'm good, thanks!")]) await middleware.process(context, mock_next) @@ -74,7 +74,7 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")): next_called = False - async def mock_next(ctx: AgentContext) -> None: + async def mock_next() -> None: nonlocal next_called next_called = True @@ -101,8 +101,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse( + async def mock_next() -> None: + context.result = AgentResponse( messages=[Message(role="assistant", text="Here's some sensitive information")] ) @@ -125,8 +125,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): - async def mock_next(ctx: AgentContext) -> None: - ctx.result = "Some non-standard result" + async def mock_next() -> None: + context.result = "Some non-standard result" await middleware.process(context, mock_next) @@ -142,8 +142,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -160,8 +160,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="streaming")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="streaming")]) await middleware.process(context, mock_next) @@ -181,7 +181,7 @@ class TestPurviewPolicyMiddleware: side_effect=PurviewPaymentRequiredError("Payment required"), ): - async def mock_next(_: AgentContext) -> None: + async def mock_next() -> None: raise AssertionError("next should not be called") with pytest.raises(PurviewPaymentRequiredError): @@ -206,8 +206,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="OK")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="OK")]) with pytest.raises(PurviewPaymentRequiredError): await middleware.process(context, mock_next) @@ -231,8 +231,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="OK")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="OK")]) with pytest.raises(ValueError, match="Post-check blew up"): await middleware.process(context, mock_next) @@ -250,8 +250,8 @@ class TestPurviewPolicyMiddleware: middleware._processor, "process_messages", side_effect=Exception("Pre-check error") ) as mock_process: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -280,8 +280,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -306,8 +306,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx): - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) + async def mock_next(): + context.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) # Should not raise, just log await middleware.process(context, mock_next) @@ -330,7 +330,7 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx): + async def mock_next(): pass # Should raise the exception @@ -346,8 +346,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) await middleware.process(context, mock_next) @@ -364,8 +364,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) await middleware.process(context, mock_next) @@ -383,8 +383,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) await middleware.process(context, mock_next) @@ -399,8 +399,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Hi")]) await middleware.process(context, mock_next) @@ -416,8 +416,8 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentContext) -> None: - ctx.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) + async def mock_next() -> None: + context.result = AgentResponse(messages=[Message(role="assistant", text="Response")]) await middleware.process(context, mock_next) diff --git a/python/samples/concepts/tools/README.md b/python/samples/concepts/tools/README.md index 91c481842d..2af617cf4c 100644 --- a/python/samples/concepts/tools/README.md +++ b/python/samples/concepts/tools/README.md @@ -267,7 +267,7 @@ class TerminatingMiddleware(FunctionMiddleware): if self.should_terminate(context): context.result = "terminated by middleware" raise MiddlewareTermination # Exit function invocation loop - await call_next(context) + await call_next() ``` ## Arguments Added/Altered at Each Layer @@ -347,7 +347,7 @@ class CachingMiddleware(FunctionMiddleware): return # Upstream post-processing still runs # Option B: Call call_next, then return normally - await call_next(context) + await call_next() self.cache[context.function.name] = context.result return # Normal completion ``` @@ -362,7 +362,7 @@ class BlockedFunctionMiddleware(FunctionMiddleware): if context.function.name in self.blocked_functions: context.result = "Function blocked by policy" raise MiddlewareTermination("Blocked") # Skips ALL post-processing - await call_next(context) + await call_next() ``` ### 3. Raise Any Other Exception @@ -374,7 +374,7 @@ class ValidationMiddleware(FunctionMiddleware): async def process(self, context: FunctionInvocationContext, call_next): if not self.is_valid(context.arguments): raise ValueError("Invalid arguments") # Bubbles up to user - await call_next(context) + await call_next() ``` ## `return` vs `raise MiddlewareTermination` @@ -385,7 +385,7 @@ The key difference is what happens to **upstream middleware's post-processing**: class MiddlewareA(AgentMiddleware): async def process(self, context, call_next): print("A: before") - await call_next(context) + await call_next() print("A: after") # Does this run? class MiddlewareB(AgentMiddleware): @@ -410,7 +410,7 @@ With middleware registered as `[MiddlewareA, MiddlewareB]`: ## Calling `call_next()` or Not -The decision to call `call_next(context)` determines whether downstream middleware and the actual operation execute: +The decision to call `call_next()` determines whether downstream middleware and the actual operation execute: ### Without calling `call_next()` - Skip downstream @@ -430,7 +430,7 @@ async def process(self, context, call_next): ```python async def process(self, context, call_next): # Pre-processing - await call_next(context) # Execute downstream + actual operation + await call_next() # Execute downstream + actual operation # Post-processing (context.result now contains real result) return ``` @@ -450,7 +450,7 @@ async def process(self, context, call_next): | `raise MiddlewareTermination` | Yes | ✅ | ✅ | ❌ No | | `raise OtherException` | Either | Depends | Depends | ❌ No (exception propagates) | -> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next(context)` without an explicit `return` statement achieves this pattern. +> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next()` without an explicit `return` statement achieves this pattern. ## Streaming vs Non-Streaming diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_as_tool.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_as_tool.py index f03fc4beb1..2d873f2930 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_as_tool.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_agent_as_tool.py @@ -20,13 +20,13 @@ multiple specialized agents, each focusing on specific tasks. async def logging_middleware( context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """MiddlewareTypes that logs tool invocations to show the delegation flow.""" print(f"[Calling tool: {context.function.name}]") print(f"[Request: {context.arguments}]") - await call_next(context) + await call_next() print(f"[Response: {context.result}]") diff --git a/python/samples/getting_started/agents/openai/openai_responses_client_basic.py b/python/samples/getting_started/agents/openai/openai_responses_client_basic.py index e3ca638783..c1b94cc35a 100644 --- a/python/samples/getting_started/agents/openai/openai_responses_client_basic.py +++ b/python/samples/getting_started/agents/openai/openai_responses_client_basic.py @@ -29,7 +29,7 @@ response generation, showing both streaming and non-streaming responses. @chat_middleware async def security_and_override_middleware( context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Function-based middleware that implements security filtering and response override.""" print("[SecurityMiddleware] Processing input...") @@ -60,7 +60,7 @@ async def security_and_override_middleware( raise MiddlewareTermination(result=context.result) # Continue to next middleware or AI execution - await call_next(context) + await call_next() print("[SecurityMiddleware] Response generated.") print(type(context.result)) diff --git a/python/samples/getting_started/agents/openai/openai_responses_client_with_agent_as_tool.py b/python/samples/getting_started/agents/openai/openai_responses_client_with_agent_as_tool.py index d37d5a9b4a..774231d0d6 100644 --- a/python/samples/getting_started/agents/openai/openai_responses_client_with_agent_as_tool.py +++ b/python/samples/getting_started/agents/openai/openai_responses_client_with_agent_as_tool.py @@ -19,13 +19,13 @@ multiple specialized agents, each focusing on specific tasks. async def logging_middleware( context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """MiddlewareTypes that logs tool invocations to show the delegation flow.""" print(f"[Calling tool: {context.function.name}]") print(f"[Request: {context.arguments}]") - await call_next(context) + await call_next() print(f"[Response: {context.result}]") diff --git a/python/samples/getting_started/devui/weather_agent_azure/agent.py b/python/samples/getting_started/devui/weather_agent_azure/agent.py index dca5b69bbc..a754d32ead 100644 --- a/python/samples/getting_started/devui/weather_agent_azure/agent.py +++ b/python/samples/getting_started/devui/weather_agent_azure/agent.py @@ -38,7 +38,7 @@ def cleanup_resources(): @chat_middleware async def security_filter_middleware( context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Chat middleware that blocks requests containing sensitive information.""" blocked_terms = ["password", "secret", "api_key", "token"] @@ -80,13 +80,13 @@ async def security_filter_middleware( raise MiddlewareTermination(result=context.result) - await call_next(context) + await call_next() @function_middleware async def atlantis_location_filter_middleware( context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Function middleware that blocks weather requests for Atlantis.""" # Check if location parameter is "atlantis" @@ -98,7 +98,7 @@ async def atlantis_location_filter_middleware( ) raise MiddlewareTermination(result=context.result) - await call_next(context) + await call_next() # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/getting_started/tools/function_tool_with_approval.py and samples/getting_started/tools/function_tool_with_approval_and_threads.py. diff --git a/python/samples/getting_started/middleware/agent_and_run_level_middleware.py b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py index 70408472ad..1f80c7742f 100644 --- a/python/samples/getting_started/middleware/agent_and_run_level_middleware.py +++ b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py @@ -68,7 +68,7 @@ def get_weather( class SecurityAgentMiddleware(AgentMiddleware): """Agent-level security middleware that validates all requests.""" - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: print("[SecurityMiddleware] Checking security for all requests...") # Check for security violations in the last user message @@ -81,18 +81,18 @@ class SecurityAgentMiddleware(AgentMiddleware): print("[SecurityMiddleware] Security check passed.") context.metadata["security_validated"] = True - await call_next(context) + await call_next() async def performance_monitor_middleware( context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Agent-level performance monitoring for all runs.""" print("[PerformanceMonitor] Starting performance monitoring...") start_time = time.time() - await call_next(context) + await call_next() end_time = time.time() duration = end_time - start_time @@ -104,7 +104,7 @@ async def performance_monitor_middleware( class HighPriorityMiddleware(AgentMiddleware): """Run-level middleware for high priority requests.""" - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: print("[HighPriority] Processing high priority request with expedited handling...") # Read metadata set by agent-level middleware @@ -115,13 +115,13 @@ class HighPriorityMiddleware(AgentMiddleware): context.metadata["priority"] = "high" context.metadata["expedited"] = True - await call_next(context) + await call_next() print("[HighPriority] High priority processing completed") async def debugging_middleware( context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Run-level debugging middleware for troubleshooting specific runs.""" print("[Debug] Debug mode enabled for this run") @@ -134,7 +134,7 @@ async def debugging_middleware( context.metadata["debug_enabled"] = True - await call_next(context) + await call_next() print("[Debug] Debug information collected") @@ -145,7 +145,7 @@ class CachingMiddleware(AgentMiddleware): def __init__(self) -> None: self.cache: dict[str, AgentResponse] = {} - async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: # Create a simple cache key from the last message last_message = context.messages[-1] if context.messages else None cache_key: str = last_message.text if last_message and last_message.text else "no_message" @@ -158,7 +158,7 @@ class CachingMiddleware(AgentMiddleware): print(f"[Cache] Cache MISS for: '{cache_key[:30]}...'") context.metadata["cache_key"] = cache_key - await call_next(context) + await call_next() # Cache the result if we have one if context.result: @@ -168,14 +168,14 @@ class CachingMiddleware(AgentMiddleware): async def function_logging_middleware( context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Function middleware that logs all function calls.""" function_name = context.function.name args = context.arguments print(f"[FunctionLog] Calling function: {function_name} with args: {args}") - await call_next(context) + await call_next() print(f"[FunctionLog] Function {function_name} completed") @@ -275,7 +275,7 @@ async def main() -> None: query = "What's the secret weather password for Berlin?" print(f"User: {query}") result = await agent.run(query) - print(f"Agent: {result.text if result.text else 'Request was blocked by security middleware'}") + print(f"Agent: {result.text if result and result.text else 'Request was blocked by security middleware'}") print() # Run 7: Normal query again (no run-level middleware interference) diff --git a/python/samples/getting_started/middleware/chat_middleware.py b/python/samples/getting_started/middleware/chat_middleware.py index 424db96457..f0c9ef153e 100644 --- a/python/samples/getting_started/middleware/chat_middleware.py +++ b/python/samples/getting_started/middleware/chat_middleware.py @@ -57,7 +57,7 @@ class InputObserverMiddleware(ChatMiddleware): async def process( self, context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Observe and modify input messages before they are sent to AI.""" print("[InputObserverMiddleware] Observing input messages:") @@ -91,7 +91,7 @@ class InputObserverMiddleware(ChatMiddleware): context.messages[:] = modified_messages # Continue to next middleware or AI execution - await call_next(context) + await call_next() # Observe that processing is complete print("[InputObserverMiddleware] Processing completed") @@ -100,7 +100,7 @@ class InputObserverMiddleware(ChatMiddleware): @chat_middleware async def security_and_override_middleware( context: ChatContext, - call_next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Function-based middleware that implements security filtering and response override.""" print("[SecurityMiddleware] Processing input...") @@ -131,7 +131,7 @@ async def security_and_override_middleware( raise MiddlewareTermination # Continue to next middleware or AI execution - await call_next(context) + await call_next() async def class_based_chat_middleware() -> None: diff --git a/python/samples/getting_started/middleware/class_based_middleware.py b/python/samples/getting_started/middleware/class_based_middleware.py index 208dddc96d..e3cb884c69 100644 --- a/python/samples/getting_started/middleware/class_based_middleware.py +++ b/python/samples/getting_started/middleware/class_based_middleware.py @@ -50,7 +50,7 @@ class SecurityAgentMiddleware(AgentMiddleware): async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Check for potential security violations in the query # Look at the last user message @@ -67,7 +67,7 @@ class SecurityAgentMiddleware(AgentMiddleware): return print("[SecurityAgentMiddleware] Security check passed.") - await call_next(context) + await call_next() class LoggingFunctionMiddleware(FunctionMiddleware): @@ -76,14 +76,14 @@ class LoggingFunctionMiddleware(FunctionMiddleware): async def process( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: function_name = context.function.name print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.") start_time = time.time() - await call_next(context) + await call_next() end_time = time.time() duration = end_time - start_time diff --git a/python/samples/getting_started/middleware/decorator_middleware.py b/python/samples/getting_started/middleware/decorator_middleware.py index 3f5e57e48e..e432473a30 100644 --- a/python/samples/getting_started/middleware/decorator_middleware.py +++ b/python/samples/getting_started/middleware/decorator_middleware.py @@ -53,7 +53,7 @@ def get_current_time() -> str: async def simple_agent_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality """Agent middleware that runs before and after agent execution.""" print("[Agent MiddlewareTypes] Before agent execution") - await call_next(context) + await call_next() print("[Agent MiddlewareTypes] After agent execution") @@ -61,7 +61,7 @@ async def simple_agent_middleware(context, call_next): # type: ignore - paramet async def simple_function_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality """Function middleware that runs before and after function calls.""" print(f"[Function MiddlewareTypes] Before calling: {context.function.name}") # type: ignore - await call_next(context) + await call_next() print(f"[Function MiddlewareTypes] After calling: {context.function.name}") # type: ignore diff --git a/python/samples/getting_started/middleware/exception_handling_with_middleware.py b/python/samples/getting_started/middleware/exception_handling_with_middleware.py index b929af4c94..1f7ed59542 100644 --- a/python/samples/getting_started/middleware/exception_handling_with_middleware.py +++ b/python/samples/getting_started/middleware/exception_handling_with_middleware.py @@ -35,13 +35,13 @@ def unstable_data_service( async def exception_handling_middleware( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: function_name = context.function.name try: print(f"[ExceptionHandlingMiddleware] Executing function: {function_name}") - await call_next(context) + await call_next() print(f"[ExceptionHandlingMiddleware] Function {function_name} completed successfully.") except TimeoutError as e: print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}") diff --git a/python/samples/getting_started/middleware/function_based_middleware.py b/python/samples/getting_started/middleware/function_based_middleware.py index d9b9062003..38272a4cd1 100644 --- a/python/samples/getting_started/middleware/function_based_middleware.py +++ b/python/samples/getting_started/middleware/function_based_middleware.py @@ -43,7 +43,7 @@ def get_weather( async def security_agent_middleware( context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Agent middleware that checks for security violations.""" # Check for potential security violations in the query @@ -57,12 +57,12 @@ async def security_agent_middleware( return print("[SecurityAgentMiddleware] Security check passed.") - await call_next(context) + await call_next() async def logging_function_middleware( context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Function middleware that logs function calls.""" function_name = context.function.name @@ -70,7 +70,7 @@ async def logging_function_middleware( start_time = time.time() - await call_next(context) + await call_next() end_time = time.time() duration = end_time - start_time @@ -105,7 +105,7 @@ async def main() -> None: query = "What's the secret weather password?" print(f"User: {query}") result = await agent.run(query) - print(f"Agent: {result.text if result.text else 'No response'}\n") + print(f"Agent: {result.text if result and result.text else 'No response'}\n") if __name__ == "__main__": diff --git a/python/samples/getting_started/middleware/middleware_termination.py b/python/samples/getting_started/middleware/middleware_termination.py index 9f48e662c5..ce2db3e376 100644 --- a/python/samples/getting_started/middleware/middleware_termination.py +++ b/python/samples/getting_started/middleware/middleware_termination.py @@ -49,7 +49,7 @@ class PreTerminationMiddleware(AgentMiddleware): async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: # Check if the user message contains any blocked words last_message = context.messages[-1] if context.messages else None @@ -75,7 +75,7 @@ class PreTerminationMiddleware(AgentMiddleware): # Terminate to prevent further processing raise MiddlewareTermination(result=context.result) - await call_next(context) + await call_next() class PostTerminationMiddleware(AgentMiddleware): @@ -88,7 +88,7 @@ class PostTerminationMiddleware(AgentMiddleware): async def process( self, context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})") @@ -101,7 +101,7 @@ class PostTerminationMiddleware(AgentMiddleware): raise MiddlewareTermination # Allow the agent to process normally - await call_next(context) + await call_next() # Increment response count after processing self.response_count += 1 @@ -158,14 +158,14 @@ async def post_termination_middleware() -> None: query = "What about the weather in London?" print(f"User: {query}") result = await agent.run(query) - print(f"Agent: {result.text if result.text else 'No response (terminated)'}") + print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}") # Third run (should also be terminated) print("\n3. Third run (should also be terminated):") query = "And New York?" print(f"User: {query}") result = await agent.run(query) - print(f"Agent: {result.text if result.text else 'No response (terminated)'}") + print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}") async def main() -> None: diff --git a/python/samples/getting_started/middleware/override_result_with_middleware.py b/python/samples/getting_started/middleware/override_result_with_middleware.py index 2239136c3c..d05ec1b4f3 100644 --- a/python/samples/getting_started/middleware/override_result_with_middleware.py +++ b/python/samples/getting_started/middleware/override_result_with_middleware.py @@ -49,11 +49,11 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." -async def weather_override_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: +async def weather_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: """Chat middleware that overrides weather results for both streaming and non-streaming cases.""" # Let the original agent execution complete first - await call_next(context) + await call_next() # Check if there's a result to override (agent called weather function) if context.result is not None: @@ -84,9 +84,9 @@ async def weather_override_middleware(context: ChatContext, call_next: Callable[ context.result = ChatResponse(messages=[Message(role=Role.ASSISTANT, text=custom_message)]) -async def validate_weather_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None: +async def validate_weather_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: """Chat middleware that simulates result validation for both streaming and non-streaming cases.""" - await call_next(context) + await call_next() validation_note = "Validation: weather data verified." @@ -104,9 +104,9 @@ async def validate_weather_middleware(context: ChatContext, call_next: Callable[ context.result.messages.append(Message(role=Role.ASSISTANT, text=validation_note)) -async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None: +async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: """Agent middleware that validates chat middleware effects and cleans the result.""" - await call_next(context) + await call_next() if context.result is None: return diff --git a/python/samples/getting_started/middleware/runtime_context_delegation.py b/python/samples/getting_started/middleware/runtime_context_delegation.py index 700b1da6f5..d839960da7 100644 --- a/python/samples/getting_started/middleware/runtime_context_delegation.py +++ b/python/samples/getting_started/middleware/runtime_context_delegation.py @@ -54,7 +54,7 @@ class SessionContextContainer: async def inject_context_middleware( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """MiddlewareTypes that extracts runtime context from kwargs and stores in container. @@ -74,7 +74,7 @@ class SessionContextContainer: print(f" - Session Metadata Keys: {list(self.session_metadata.keys())}") # Continue to tool execution - await call_next(context) + await call_next() # Create a container instance that will be shared via closure @@ -278,19 +278,19 @@ async def pattern_2_hierarchical_with_kwargs_propagation() -> None: @function_middleware async def email_kwargs_tracker( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: email_agent_kwargs.update(context.kwargs) print(f"[EmailAgent] Received runtime context: {list(context.kwargs.keys())}") - await call_next(context) + await call_next() @function_middleware async def sms_kwargs_tracker( - context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: sms_agent_kwargs.update(context.kwargs) print(f"[SMSAgent] Received runtime context: {list(context.kwargs.keys())}") - await call_next(context) + await call_next() client = OpenAIChatClient(model_id="gpt-4o-mini") @@ -359,7 +359,7 @@ class AuthContextMiddleware: self.validated_tokens: list[str] = [] async def validate_and_track( - self, context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]] + self, context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] ) -> None: """Validate API token and track usage.""" api_token = context.kwargs.get("api_token") @@ -375,7 +375,7 @@ class AuthContextMiddleware: else: print("[AuthMiddleware] No API token provided") - await call_next(context) + await call_next() @tool(approval_mode="never_require") diff --git a/python/samples/getting_started/middleware/shared_state_middleware.py b/python/samples/getting_started/middleware/shared_state_middleware.py index a377d7dfd3..a3aae59ccd 100644 --- a/python/samples/getting_started/middleware/shared_state_middleware.py +++ b/python/samples/getting_started/middleware/shared_state_middleware.py @@ -57,7 +57,7 @@ class MiddlewareContainer: async def call_counter_middleware( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """First middleware: increments call count in shared state.""" # Increment the shared call count @@ -66,18 +66,18 @@ class MiddlewareContainer: print(f"[CallCounter] This is function call #{self.call_count}") # Call the next middleware/function - await call_next(context) + await call_next() async def result_enhancer_middleware( self, context: FunctionInvocationContext, - call_next: Callable[[FunctionInvocationContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """Second middleware: uses shared call count to enhance function results.""" print(f"[ResultEnhancer] Current total calls so far: {self.call_count}") # Call the next middleware/function - await call_next(context) + await call_next() # After function execution, enhance the result using shared state if context.result: diff --git a/python/samples/getting_started/middleware/thread_behavior_middleware.py b/python/samples/getting_started/middleware/thread_behavior_middleware.py index e3306eef7b..680fd01d50 100644 --- a/python/samples/getting_started/middleware/thread_behavior_middleware.py +++ b/python/samples/getting_started/middleware/thread_behavior_middleware.py @@ -46,7 +46,7 @@ def get_weather( async def thread_tracking_middleware( context: AgentContext, - call_next: Callable[[AgentContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: """MiddlewareTypes that tracks and logs thread behavior across runs.""" thread_messages = [] @@ -57,7 +57,7 @@ async def thread_tracking_middleware( print(f"[MiddlewareTypes pre-execution] Thread history messages: {len(thread_messages)}") # Call call_next to execute the agent - await call_next(context) + await call_next() # Check thread state after agent execution updated_thread_messages = [] From 235c5780595183df85220c279936b1716230d082 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Wed, 11 Feb 2026 10:04:17 -0500 Subject: [PATCH 2/3] ci: Add the Workflows SourceGenerators project into the release filter (#3815) --- dotnet/agent-framework-release.slnf | 1 + 1 file changed, 1 insertion(+) diff --git a/dotnet/agent-framework-release.slnf b/dotnet/agent-framework-release.slnf index 98e882ee29..ebd33c0767 100644 --- a/dotnet/agent-framework-release.slnf +++ b/dotnet/agent-framework-release.slnf @@ -25,6 +25,7 @@ "src\\Microsoft.Agents.AI.Purview\\Microsoft.Agents.AI.Purview.csproj", "src\\Microsoft.Agents.AI.Workflows.Declarative.AzureAI\\Microsoft.Agents.AI.Workflows.Declarative.AzureAI.csproj", "src\\Microsoft.Agents.AI.Workflows.Declarative\\Microsoft.Agents.AI.Workflows.Declarative.csproj", + "src\\Microsoft.Agents.AI.Workflows.Generators\\Microsoft.Agents.AI.Workflows.Generators.csproj", "src\\Microsoft.Agents.AI.Workflows\\Microsoft.Agents.AI.Workflows.csproj", "src\\Microsoft.Agents.AI\\Microsoft.Agents.AI.csproj" ] From a427af91a9b12e0c2781a57597a36670b2f7380b Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 11 Feb 2026 16:46:25 +0100 Subject: [PATCH 3/3] Python: Allow AzureOpenAIResponsesClient creation with Foundry project endpoint (#3814) * Initial plan * feat: extend AzureOpenAIResponsesClient to support Foundry project endpoints Add project_client and project_endpoint parameters to allow creating the client via an Azure AI Foundry project. When provided, the client uses AIProjectClient.get_openai_client() to obtain the OpenAI client. The azure-ai-projects package is imported lazily and only required when using the project endpoint path. Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> * fix: address code review - remove duplicate MagicMock imports in tests Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> * fix: add type field to Responses API input items and add Foundry sample - Add 'type: message' to input items in _prepare_message_for_openai to comply with the Responses API schema requirement - Filter out empty dicts from unsupported content types to prevent sending items with invalid empty type values - Add azure_responses_client_with_foundry.py sample demonstrating AzureOpenAIResponsesClient with project_endpoint - Update README and pyrightconfig.samples.json accordingly * updates to response format and setup * fix: patch AIProjectClient at correct module path in test Patch agent_framework.azure._responses_client.AIProjectClient instead of azure.ai.projects.aio.AIProjectClient since the import is at module level. * docs: add Foundry sample to READMEs and document AZURE_AI_PROJECT_ENDPOINT env var --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> Co-authored-by: eavanvalkenburg --- python/packages/azure-ai/pyproject.toml | 1 - .../azure/_responses_client.py | 98 ++++++++++++++- .../core/agent_framework/azure/_shared.py | 3 +- .../openai/_responses_client.py | 19 ++- python/packages/core/pyproject.toml | 1 + .../azure/test_azure_responses_client.py | 114 ++++++++++++++++++ .../openai/test_openai_responses_client.py | 16 +-- python/pyrightconfig.samples.json | 3 +- python/samples/README.md | 1 + .../agents/azure_openai/README.md | 4 + .../azure_responses_client_with_foundry.py | 113 +++++++++++++++++ python/uv.lock | 4 +- 12 files changed, 349 insertions(+), 28 deletions(-) create mode 100644 python/samples/getting_started/agents/azure_openai/azure_responses_client_with_foundry.py diff --git a/python/packages/azure-ai/pyproject.toml b/python/packages/azure-ai/pyproject.toml index 86e0d3342f..0ad16c4362 100644 --- a/python/packages/azure-ai/pyproject.toml +++ b/python/packages/azure-ai/pyproject.toml @@ -24,7 +24,6 @@ classifiers = [ ] dependencies = [ "agent-framework-core>=1.0.0b260210", - "azure-ai-projects >= 2.0.0b3", "azure-ai-agents == 1.2.0b5", "aiohttp", ] diff --git a/python/packages/core/agent_framework/azure/_responses_client.py b/python/packages/core/agent_framework/azure/_responses_client.py index cc57beb57c..0049979a6a 100644 --- a/python/packages/core/agent_framework/azure/_responses_client.py +++ b/python/packages/core/agent_framework/azure/_responses_client.py @@ -7,11 +7,14 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Generic from urllib.parse import urljoin +from azure.ai.projects.aio import AIProjectClient from azure.core.credentials import TokenCredential -from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI +from openai import AsyncOpenAI +from openai.lib.azure import AsyncAzureADTokenProvider from pydantic import ValidationError from .._middleware import ChatMiddlewareLayer +from .._telemetry import AGENT_FRAMEWORK_USER_AGENT from .._tools import FunctionInvocationConfiguration, FunctionInvocationLayer from ..exceptions import ServiceInitializationError from ..observability import ChatTelemetryLayer @@ -72,7 +75,9 @@ class AzureOpenAIResponsesClient( # type: ignore[misc] token_endpoint: str | None = None, credential: TokenCredential | None = None, default_headers: Mapping[str, str] | None = None, - async_client: AsyncAzureOpenAI | None = None, + async_client: AsyncOpenAI | None = None, + project_client: Any | None = None, + project_endpoint: str | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, instruction_role: str | None = None, @@ -82,6 +87,14 @@ class AzureOpenAIResponsesClient( # type: ignore[misc] ) -> None: """Initialize an Azure OpenAI Responses client. + The client can be created in two ways: + + 1. **Direct Azure OpenAI** (default): Provide endpoint, api_key, or credential + to connect directly to an Azure OpenAI deployment. + 2. **Foundry project endpoint**: Provide a ``project_client`` or ``project_endpoint`` + (with ``credential``) to create the client via an Azure AI Foundry project. + This requires the ``azure-ai-projects`` package to be installed. + Keyword Args: api_key: The API key. If provided, will override the value in the env vars or .env file. Can also be set via environment variable AZURE_OPENAI_API_KEY. @@ -105,6 +118,12 @@ class AzureOpenAIResponsesClient( # type: ignore[misc] default_headers: The default headers mapping of string keys to string values for HTTP requests. async_client: An existing client to use. + project_client: An existing ``AIProjectClient`` (from ``azure.ai.projects.aio``) to use. + The OpenAI client will be obtained via ``project_client.get_openai_client()``. + Requires the ``azure-ai-projects`` package. + project_endpoint: The Azure AI Foundry project endpoint URL. + When provided with ``credential``, an ``AIProjectClient`` will be created + and used to obtain the OpenAI client. Requires the ``azure-ai-projects`` package. env_file_path: Use the environment settings file as a fallback to using env vars. env_file_encoding: The encoding of the environment settings file, defaults to 'utf-8'. instruction_role: The role to use for 'instruction' messages, for example, summarization @@ -132,6 +151,27 @@ class AzureOpenAIResponsesClient( # type: ignore[misc] # Or loading from a .env file client = AzureOpenAIResponsesClient(env_file_path="path/to/.env") + # Using a Foundry project endpoint + from azure.identity import DefaultAzureCredential + + client = AzureOpenAIResponsesClient( + project_endpoint="https://your-project.services.ai.azure.com", + deployment_name="gpt-4o", + credential=DefaultAzureCredential(), + ) + + # Or using an existing AIProjectClient + from azure.ai.projects.aio import AIProjectClient + + project_client = AIProjectClient( + endpoint="https://your-project.services.ai.azure.com", + credential=DefaultAzureCredential(), + ) + client = AzureOpenAIResponsesClient( + project_client=project_client, + deployment_name="gpt-4o", + ) + # Using custom ChatOptions with type safety: from typing import TypedDict from agent_framework.azure import AzureOpenAIResponsesOptions @@ -146,6 +186,15 @@ class AzureOpenAIResponsesClient( # type: ignore[misc] """ if model_id := kwargs.pop("model_id", None) and not deployment_name: deployment_name = str(model_id) + + # Project client path: create OpenAI client from an Azure AI Foundry project + if async_client is None and (project_client is not None or project_endpoint is not None): + async_client = self._create_client_from_project( + project_client=project_client, + project_endpoint=project_endpoint, + credential=credential, + ) + try: azure_openai_settings = AzureOpenAISettings( # pydantic settings will see if there is a value, if not, will try the env var or .env file @@ -195,9 +244,48 @@ class AzureOpenAIResponsesClient( # type: ignore[misc] function_invocation_configuration=function_invocation_configuration, ) + @staticmethod + def _create_client_from_project( + *, + project_client: AIProjectClient | None, + project_endpoint: str | None, + credential: TokenCredential | None, + ) -> AsyncOpenAI: + """Create an AsyncOpenAI client from an Azure AI Foundry project. + + Args: + project_client: An existing AIProjectClient to use. + project_endpoint: The Azure AI Foundry project endpoint URL. + credential: Azure credential for authentication. + + Returns: + An AsyncAzureOpenAI client obtained from the project client. + + Raises: + ServiceInitializationError: If required parameters are missing or + the azure-ai-projects package is not installed. + """ + if project_client is not None: + return project_client.get_openai_client() + + if not project_endpoint: + raise ServiceInitializationError( + "Azure AI project endpoint is required when project_client is not provided." + ) + if not credential: + raise ServiceInitializationError( + "Azure credential is required when using project_endpoint without a project_client." + ) + project_client = AIProjectClient( + endpoint=project_endpoint, + credential=credential, # type: ignore[arg-type] + user_agent=AGENT_FRAMEWORK_USER_AGENT, + ) + return project_client.get_openai_client() + @override - def _check_model_presence(self, run_options: dict[str, Any]) -> None: - if not run_options.get("model"): + def _check_model_presence(self, options: dict[str, Any]) -> None: + if not options.get("model"): if not self.model_id: raise ValueError("deployment_name must be a non-empty string") - run_options["model"] = self.model_id + options["model"] = self.model_id diff --git a/python/packages/core/agent_framework/azure/_shared.py b/python/packages/core/agent_framework/azure/_shared.py index 8e90002a75..5ef0585f96 100644 --- a/python/packages/core/agent_framework/azure/_shared.py +++ b/python/packages/core/agent_framework/azure/_shared.py @@ -9,6 +9,7 @@ from copy import copy from typing import Any, ClassVar, Final from azure.core.credentials import TokenCredential +from openai import AsyncOpenAI from openai.lib.azure import AsyncAzureOpenAI from pydantic import SecretStr, model_validator @@ -162,7 +163,7 @@ class AzureOpenAIConfigMixin(OpenAIBase): token_endpoint: str | None = None, credential: TokenCredential | None = None, default_headers: Mapping[str, str] | None = None, - client: AsyncAzureOpenAI | None = None, + client: AsyncOpenAI | None = None, instruction_role: str | None = None, **kwargs: Any, ) -> None: diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index 5902ad0e46..f239221c49 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -901,6 +901,7 @@ class RawOpenAIResponsesClient( # type: ignore[misc] """Prepare a chat message for the OpenAI Responses API format.""" all_messages: list[dict[str, Any]] = [] args: dict[str, Any] = { + "type": "message", "role": message.role, } for content in message.contents: @@ -911,16 +912,22 @@ class RawOpenAIResponsesClient( # type: ignore[misc] case "function_result": new_args: dict[str, Any] = {} new_args.update(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore[arg-type] - all_messages.append(new_args) + if new_args: + all_messages.append(new_args) case "function_call": function_call = self._prepare_content_for_openai(message.role, content, call_id_to_id) # type: ignore[arg-type] - all_messages.append(function_call) # type: ignore + if function_call: + all_messages.append(function_call) # type: ignore case "function_approval_response" | "function_approval_request": - all_messages.append(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore + prepared = self._prepare_content_for_openai(Role(message.role), content, call_id_to_id) + if prepared: + all_messages.append(prepared) # type: ignore case _: - if "content" not in args: - args["content"] = [] - args["content"].append(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore + prepared_content = self._prepare_content_for_openai(message.role, content, call_id_to_id) # type: ignore + if prepared_content: + if "content" not in args: + args["content"] = [] + args["content"].append(prepared_content) # type: ignore if "content" in args or "tool_calls" in args: all_messages.append(args) return all_messages diff --git a/python/packages/core/pyproject.toml b/python/packages/core/pyproject.toml index f4f28c898a..5a90b479e7 100644 --- a/python/packages/core/pyproject.toml +++ b/python/packages/core/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ # connectors and functions "openai>=1.99.0", "azure-identity>=1,<2", + "azure-ai-projects >= 2.0.0b3", "mcp[ws]>=1.24.0,<2", "packaging>=24.1", ] diff --git a/python/packages/core/tests/azure/test_azure_responses_client.py b/python/packages/core/tests/azure/test_azure_responses_client.py index 434674d50c..1d40c769db 100644 --- a/python/packages/core/tests/azure/test_azure_responses_client.py +++ b/python/packages/core/tests/azure/test_azure_responses_client.py @@ -3,6 +3,7 @@ import json import os from typing import Annotated, Any +from unittest.mock import MagicMock import pytest from azure.identity import AzureCliCredential @@ -115,6 +116,119 @@ def test_init_with_empty_model_id(azure_openai_unit_test_env: dict[str, str]) -> ) +def test_init_with_project_client(azure_openai_unit_test_env: dict[str, str]) -> None: + """Test initialization with an existing AIProjectClient.""" + from unittest.mock import patch + + from openai import AsyncOpenAI + + # Create a mock AIProjectClient that returns a mock AsyncOpenAI client + mock_openai_client = MagicMock(spec=AsyncOpenAI) + mock_openai_client.default_headers = {} + + mock_project_client = MagicMock() + mock_project_client.get_openai_client.return_value = mock_openai_client + + with patch( + "agent_framework.azure._responses_client.AzureOpenAIResponsesClient._create_client_from_project", + return_value=mock_openai_client, + ): + azure_responses_client = AzureOpenAIResponsesClient( + project_client=mock_project_client, + deployment_name="gpt-4o", + ) + + assert azure_responses_client.model_id == "gpt-4o" + assert azure_responses_client.client is mock_openai_client + assert isinstance(azure_responses_client, SupportsChatGetResponse) + + +def test_init_with_project_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None: + """Test initialization with a project endpoint and credential.""" + from unittest.mock import patch + + from openai import AsyncOpenAI + + mock_openai_client = MagicMock(spec=AsyncOpenAI) + mock_openai_client.default_headers = {} + + with patch( + "agent_framework.azure._responses_client.AzureOpenAIResponsesClient._create_client_from_project", + return_value=mock_openai_client, + ): + azure_responses_client = AzureOpenAIResponsesClient( + project_endpoint="https://test-project.services.ai.azure.com", + deployment_name="gpt-4o", + credential=AzureCliCredential(), + ) + + assert azure_responses_client.model_id == "gpt-4o" + assert azure_responses_client.client is mock_openai_client + assert isinstance(azure_responses_client, SupportsChatGetResponse) + + +def test_create_client_from_project_with_project_client() -> None: + """Test _create_client_from_project with an existing project client.""" + from openai import AsyncOpenAI + + mock_openai_client = MagicMock(spec=AsyncOpenAI) + mock_project_client = MagicMock() + mock_project_client.get_openai_client.return_value = mock_openai_client + + result = AzureOpenAIResponsesClient._create_client_from_project( + project_client=mock_project_client, + project_endpoint=None, + credential=None, + ) + + assert result is mock_openai_client + mock_project_client.get_openai_client.assert_called_once() + + +def test_create_client_from_project_with_endpoint() -> None: + """Test _create_client_from_project with a project endpoint.""" + from unittest.mock import patch + + from openai import AsyncOpenAI + + mock_openai_client = MagicMock(spec=AsyncOpenAI) + mock_credential = MagicMock() + + with patch("agent_framework.azure._responses_client.AIProjectClient") as MockAIProjectClient: + mock_instance = MockAIProjectClient.return_value + mock_instance.get_openai_client.return_value = mock_openai_client + + result = AzureOpenAIResponsesClient._create_client_from_project( + project_client=None, + project_endpoint="https://test-project.services.ai.azure.com", + credential=mock_credential, + ) + + assert result is mock_openai_client + MockAIProjectClient.assert_called_once() + mock_instance.get_openai_client.assert_called_once() + + +def test_create_client_from_project_missing_endpoint() -> None: + """Test _create_client_from_project raises error when endpoint is missing.""" + with pytest.raises(ServiceInitializationError, match="project endpoint is required"): + AzureOpenAIResponsesClient._create_client_from_project( + project_client=None, + project_endpoint=None, + credential=MagicMock(), + ) + + +def test_create_client_from_project_missing_credential() -> None: + """Test _create_client_from_project raises error when credential is missing.""" + with pytest.raises(ServiceInitializationError, match="credential is required"): + AzureOpenAIResponsesClient._create_client_from_project( + project_client=None, + project_endpoint="https://test-project.services.ai.azure.com", + credential=None, + ) + + def test_serialize(azure_openai_unit_test_env: dict[str, str]) -> None: default_headers = {"X-Unit-Test": "test-guid"} diff --git a/python/packages/core/tests/openai/test_openai_responses_client.py b/python/packages/core/tests/openai/test_openai_responses_client.py index e51ed4e989..e3f982d826 100644 --- a/python/packages/core/tests/openai/test_openai_responses_client.py +++ b/python/packages/core/tests/openai/test_openai_responses_client.py @@ -798,12 +798,8 @@ def test_chat_message_with_error_content() -> None: result = client._prepare_message_for_openai(message, call_id_to_id) - # Message should be prepared with empty content list since ErrorContent returns {} - assert len(result) == 1 - prepared_message = result[0] - assert prepared_message["role"] == "assistant" - # Content should be a list with empty dict since ErrorContent returns {} - assert prepared_message.get("content") == [{}] + # Message should be empty since ErrorContent is filtered out + assert len(result) == 0 def test_chat_message_with_usage_content() -> None: @@ -823,12 +819,8 @@ def test_chat_message_with_usage_content() -> None: result = client._prepare_message_for_openai(message, call_id_to_id) - # Message should be prepared with empty content list since UsageContent returns {} - assert len(result) == 1 - prepared_message = result[0] - assert prepared_message["role"] == "assistant" - # Content should be a list with empty dict since UsageContent returns {} - assert prepared_message.get("content") == [{}] + # Message should be empty since UsageContent is filtered out + assert len(result) == 0 def test_hosted_file_content_preparation() -> None: diff --git a/python/pyrightconfig.samples.json b/python/pyrightconfig.samples.json index a74e252474..5dae59f141 100644 --- a/python/pyrightconfig.samples.json +++ b/python/pyrightconfig.samples.json @@ -5,7 +5,8 @@ "**/autogen-migration/**", "**/semantic-kernel-migration/**", "**/demos/**", - "**/agent_with_foundry_tracing.py" + "**/agent_with_foundry_tracing.py", + "**/azure_responses_client_with_foundry.py" ], "typeCheckingMode": "off", "reportMissingImports": "error", diff --git a/python/samples/README.md b/python/samples/README.md index fc64dced52..eb234cdc3e 100644 --- a/python/samples/README.md +++ b/python/samples/README.md @@ -78,6 +78,7 @@ This directory contains samples demonstrating the capabilities of Microsoft Agen | [`getting_started/agents/azure_openai/azure_responses_client_image_analysis.py`](./getting_started/agents/azure_openai/azure_responses_client_image_analysis.py) | Azure OpenAI Responses Client with Image Analysis Example | | [`getting_started/agents/azure_openai/azure_responses_client_with_code_interpreter.py`](./getting_started/agents/azure_openai/azure_responses_client_with_code_interpreter.py) | Azure OpenAI Responses Client with Code Interpreter Example | | [`getting_started/agents/azure_openai/azure_responses_client_with_explicit_settings.py`](./getting_started/agents/azure_openai/azure_responses_client_with_explicit_settings.py) | Azure OpenAI Responses Client with Explicit Settings Example | +| [`getting_started/agents/azure_openai/azure_responses_client_with_foundry.py`](./getting_started/agents/azure_openai/azure_responses_client_with_foundry.py) | Azure OpenAI Responses Client with Foundry Project Example | | [`getting_started/agents/azure_openai/azure_responses_client_with_function_tools.py`](./getting_started/agents/azure_openai/azure_responses_client_with_function_tools.py) | Azure OpenAI Responses Client with Function Tools Example | | [`getting_started/agents/azure_openai/azure_responses_client_with_hosted_mcp.py`](./getting_started/agents/azure_openai/azure_responses_client_with_hosted_mcp.py) | Azure OpenAI Responses Client with Hosted Model Context Protocol (MCP) Example | | [`getting_started/agents/azure_openai/azure_responses_client_with_local_mcp.py`](./getting_started/agents/azure_openai/azure_responses_client_with_local_mcp.py) | Azure OpenAI Responses Client with local Model Context Protocol (MCP) Example | diff --git a/python/samples/getting_started/agents/azure_openai/README.md b/python/samples/getting_started/agents/azure_openai/README.md index fea029c209..614e60b14d 100644 --- a/python/samples/getting_started/agents/azure_openai/README.md +++ b/python/samples/getting_started/agents/azure_openai/README.md @@ -22,6 +22,7 @@ This folder contains examples demonstrating different ways to create and use age | [`azure_responses_client_with_code_interpreter.py`](azure_responses_client_with_code_interpreter.py) | Shows how to use `AzureOpenAIResponsesClient.get_code_interpreter_tool()` with Azure agents to write and execute Python code. Includes helper methods for accessing code interpreter data from response chunks. | | [`azure_responses_client_with_explicit_settings.py`](azure_responses_client_with_explicit_settings.py) | Shows how to initialize an agent with a specific responses client, configuring settings explicitly including endpoint and deployment name. | | [`azure_responses_client_with_file_search.py`](azure_responses_client_with_file_search.py) | Demonstrates using `AzureOpenAIResponsesClient.get_file_search_tool()` with Azure OpenAI Responses Client for direct document-based question answering and information retrieval from vector stores. | +| [`azure_responses_client_with_foundry.py`](azure_responses_client_with_foundry.py) | Shows how to create an agent using an Azure AI Foundry project endpoint instead of a direct Azure OpenAI endpoint. Requires the `azure-ai-projects` package. | | [`azure_responses_client_with_function_tools.py`](azure_responses_client_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). | | [`azure_responses_client_with_hosted_mcp.py`](azure_responses_client_with_hosted_mcp.py) | Shows how to integrate Azure OpenAI Responses Client with hosted Model Context Protocol (MCP) servers using `AzureOpenAIResponsesClient.get_mcp_tool()` for extended functionality. | | [`azure_responses_client_with_local_mcp.py`](azure_responses_client_with_local_mcp.py) | Shows how to integrate Azure OpenAI Responses Client with local Model Context Protocol (MCP) servers using MCPStreamableHTTPTool for extended functionality. | @@ -35,6 +36,9 @@ Make sure to set the following environment variables before running the examples - `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME`: The name of your Azure OpenAI chat model deployment - `AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME`: The name of your Azure OpenAI Responses deployment +For the Foundry project sample (`azure_responses_client_with_foundry.py`), also set: +- `AZURE_AI_PROJECT_ENDPOINT`: Your Azure AI Foundry project endpoint + Optionally, you can set: - `AZURE_OPENAI_API_VERSION`: The API version to use (default is `2024-02-15-preview`) - `AZURE_OPENAI_API_KEY`: Your Azure OpenAI API key (if not using `AzureCliCredential`) diff --git a/python/samples/getting_started/agents/azure_openai/azure_responses_client_with_foundry.py b/python/samples/getting_started/agents/azure_openai/azure_responses_client_with_foundry.py new file mode 100644 index 0000000000..7020121db9 --- /dev/null +++ b/python/samples/getting_started/agents/azure_openai/azure_responses_client_with_foundry.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import os +from random import randint +from typing import Annotated + +from agent_framework import tool +from agent_framework.azure import AzureOpenAIResponsesClient +from azure.identity import AzureCliCredential +from dotenv import load_dotenv +from pydantic import Field + +""" +Azure OpenAI Responses Client with Foundry Project Example + +This sample demonstrates how to create an AzureOpenAIResponsesClient using an +Azure AI Foundry project endpoint. Instead of providing an Azure OpenAI endpoint +directly, you provide a Foundry project endpoint and the client is created via +the Azure AI Foundry project SDK. + +This requires: +- The `azure-ai-projects` package to be installed. +- The `AZURE_AI_PROJECT_ENDPOINT` environment variable set to your Foundry project endpoint. +- The `AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME` environment variable set to the model deployment name. +""" + +load_dotenv() # Load environment variables from .env file if present + + +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/getting_started/tools/function_tool_with_approval.py and samples/getting_started/tools/function_tool_with_approval_and_threads.py. +@tool(approval_mode="never_require") +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + + +async def non_streaming_example() -> None: + """Example of non-streaming response (get the complete result at once).""" + print("=== Non-streaming Response Example ===") + + # 1. Create the AzureOpenAIResponsesClient using a Foundry project endpoint. + # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred + # authentication option. + credential = AzureCliCredential() + agent = AzureOpenAIResponsesClient( + project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], + deployment_name=os.environ["AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME"], + credential=credential, + ).as_agent( + instructions="You are a helpful weather agent.", + tools=get_weather, + ) + + # 2. Run a query and print the result. + query = "What's the weather like in Seattle?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Result: {result}\n") + + +async def streaming_example() -> None: + """Example of streaming response (get results as they are generated).""" + print("=== Streaming Response Example ===") + + # 1. Create the AzureOpenAIResponsesClient using a Foundry project endpoint. + # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred + # authentication option. + credential = AzureCliCredential() + agent = AzureOpenAIResponsesClient( + project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], + deployment_name=os.environ["AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME"], + credential=credential, + ).as_agent( + instructions="You are a helpful weather agent.", + tools=get_weather, + ) + + # 2. Stream the response and print each chunk as it arrives. + query = "What's the weather like in Portland?" + print(f"User: {query}") + print("Agent: ", end="", flush=True) + async for chunk in agent.run(query, stream=True): + if chunk.text: + print(chunk.text, end="", flush=True) + print("\n") + + +async def main() -> None: + print("=== Azure OpenAI Responses Client with Foundry Project Example ===") + + await non_streaming_example() + await streaming_example() + + +if __name__ == "__main__": + asyncio.run(main()) + + +""" +Sample output: +=== Azure OpenAI Responses Client with Foundry Project Example === +=== Non-streaming Response Example === +User: What's the weather like in Seattle? +Result: The weather in Seattle is cloudy with a high of 18°C. + +=== Streaming Response Example === +User: What's the weather like in Portland? +Agent: The weather in Portland is sunny with a high of 25°C. +""" diff --git a/python/uv.lock b/python/uv.lock index fac55c8e21..24396ef396 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -209,7 +209,6 @@ dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "azure-ai-agents", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "azure-ai-projects", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] [package.metadata] @@ -217,7 +216,6 @@ requires-dist = [ { name = "agent-framework-core", editable = "packages/core" }, { name = "aiohttp" }, { name = "azure-ai-agents", specifier = "==1.2.0b5" }, - { name = "azure-ai-projects", specifier = ">=2.0.0b3" }, ] [[package]] @@ -324,6 +322,7 @@ name = "agent-framework-core" version = "1.0.0b260210" source = { editable = "packages/core" } dependencies = [ + { name = "azure-ai-projects", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "azure-identity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "mcp", extra = ["ws"], marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -378,6 +377,7 @@ requires-dist = [ { name = "agent-framework-orchestrations", marker = "extra == 'all'", editable = "packages/orchestrations" }, { name = "agent-framework-purview", marker = "extra == 'all'", editable = "packages/purview" }, { name = "agent-framework-redis", marker = "extra == 'all'", editable = "packages/redis" }, + { name = "azure-ai-projects", specifier = ">=2.0.0b3" }, { name = "azure-identity", specifier = ">=1,<2" }, { name = "mcp", extras = ["ws"], specifier = ">=1.24.0,<2" }, { name = "openai", specifier = ">=1.99.0" },