Removed context parameter from call_next (#3829)

This commit is contained in:
Dmytro Struk
2026-02-11 02:47:41 -08:00
committed by GitHub
Unverified
parent 38f22ef006
commit 1fdc4be88d
29 changed files with 451 additions and 583 deletions
+8 -8
View File
@@ -267,7 +267,7 @@ class TerminatingMiddleware(FunctionMiddleware):
if self.should_terminate(context):
context.result = "terminated by middleware"
raise MiddlewareTermination # Exit function invocation loop
await call_next(context)
await call_next()
```
## Arguments Added/Altered at Each Layer
@@ -347,7 +347,7 @@ class CachingMiddleware(FunctionMiddleware):
return # Upstream post-processing still runs
# Option B: Call call_next, then return normally
await call_next(context)
await call_next()
self.cache[context.function.name] = context.result
return # Normal completion
```
@@ -362,7 +362,7 @@ class BlockedFunctionMiddleware(FunctionMiddleware):
if context.function.name in self.blocked_functions:
context.result = "Function blocked by policy"
raise MiddlewareTermination("Blocked") # Skips ALL post-processing
await call_next(context)
await call_next()
```
### 3. Raise Any Other Exception
@@ -374,7 +374,7 @@ class ValidationMiddleware(FunctionMiddleware):
async def process(self, context: FunctionInvocationContext, call_next):
if not self.is_valid(context.arguments):
raise ValueError("Invalid arguments") # Bubbles up to user
await call_next(context)
await call_next()
```
## `return` vs `raise MiddlewareTermination`
@@ -385,7 +385,7 @@ The key difference is what happens to **upstream middleware's post-processing**:
class MiddlewareA(AgentMiddleware):
async def process(self, context, call_next):
print("A: before")
await call_next(context)
await call_next()
print("A: after") # Does this run?
class MiddlewareB(AgentMiddleware):
@@ -410,7 +410,7 @@ With middleware registered as `[MiddlewareA, MiddlewareB]`:
## Calling `call_next()` or Not
The decision to call `call_next(context)` determines whether downstream middleware and the actual operation execute:
The decision to call `call_next()` determines whether downstream middleware and the actual operation execute:
### Without calling `call_next()` - Skip downstream
@@ -430,7 +430,7 @@ async def process(self, context, call_next):
```python
async def process(self, context, call_next):
# Pre-processing
await call_next(context) # Execute downstream + actual operation
await call_next() # Execute downstream + actual operation
# Post-processing (context.result now contains real result)
return
```
@@ -450,7 +450,7 @@ async def process(self, context, call_next):
| `raise MiddlewareTermination` | Yes | ✅ | ✅ | ❌ No |
| `raise OtherException` | Either | Depends | Depends | ❌ No (exception propagates) |
> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next(context)` without an explicit `return` statement achieves this pattern.
> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next()` without an explicit `return` statement achieves this pattern.
## Streaming vs Non-Streaming
@@ -20,13 +20,13 @@ multiple specialized agents, each focusing on specific tasks.
async def logging_middleware(
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""MiddlewareTypes that logs tool invocations to show the delegation flow."""
print(f"[Calling tool: {context.function.name}]")
print(f"[Request: {context.arguments}]")
await call_next(context)
await call_next()
print(f"[Response: {context.result}]")
@@ -29,7 +29,7 @@ response generation, showing both streaming and non-streaming responses.
@chat_middleware
async def security_and_override_middleware(
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Function-based middleware that implements security filtering and response override."""
print("[SecurityMiddleware] Processing input...")
@@ -60,7 +60,7 @@ async def security_and_override_middleware(
raise MiddlewareTermination(result=context.result)
# Continue to next middleware or AI execution
await call_next(context)
await call_next()
print("[SecurityMiddleware] Response generated.")
print(type(context.result))
@@ -19,13 +19,13 @@ multiple specialized agents, each focusing on specific tasks.
async def logging_middleware(
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""MiddlewareTypes that logs tool invocations to show the delegation flow."""
print(f"[Calling tool: {context.function.name}]")
print(f"[Request: {context.arguments}]")
await call_next(context)
await call_next()
print(f"[Response: {context.result}]")
@@ -38,7 +38,7 @@ def cleanup_resources():
@chat_middleware
async def security_filter_middleware(
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Chat middleware that blocks requests containing sensitive information."""
blocked_terms = ["password", "secret", "api_key", "token"]
@@ -80,13 +80,13 @@ async def security_filter_middleware(
raise MiddlewareTermination(result=context.result)
await call_next(context)
await call_next()
@function_middleware
async def atlantis_location_filter_middleware(
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Function middleware that blocks weather requests for Atlantis."""
# Check if location parameter is "atlantis"
@@ -98,7 +98,7 @@ async def atlantis_location_filter_middleware(
)
raise MiddlewareTermination(result=context.result)
await call_next(context)
await call_next()
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/getting_started/tools/function_tool_with_approval.py and samples/getting_started/tools/function_tool_with_approval_and_threads.py.
@@ -68,7 +68,7 @@ def get_weather(
class SecurityAgentMiddleware(AgentMiddleware):
"""Agent-level security middleware that validates all requests."""
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
print("[SecurityMiddleware] Checking security for all requests...")
# Check for security violations in the last user message
@@ -81,18 +81,18 @@ class SecurityAgentMiddleware(AgentMiddleware):
print("[SecurityMiddleware] Security check passed.")
context.metadata["security_validated"] = True
await call_next(context)
await call_next()
async def performance_monitor_middleware(
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Agent-level performance monitoring for all runs."""
print("[PerformanceMonitor] Starting performance monitoring...")
start_time = time.time()
await call_next(context)
await call_next()
end_time = time.time()
duration = end_time - start_time
@@ -104,7 +104,7 @@ async def performance_monitor_middleware(
class HighPriorityMiddleware(AgentMiddleware):
"""Run-level middleware for high priority requests."""
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
print("[HighPriority] Processing high priority request with expedited handling...")
# Read metadata set by agent-level middleware
@@ -115,13 +115,13 @@ class HighPriorityMiddleware(AgentMiddleware):
context.metadata["priority"] = "high"
context.metadata["expedited"] = True
await call_next(context)
await call_next()
print("[HighPriority] High priority processing completed")
async def debugging_middleware(
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Run-level debugging middleware for troubleshooting specific runs."""
print("[Debug] Debug mode enabled for this run")
@@ -134,7 +134,7 @@ async def debugging_middleware(
context.metadata["debug_enabled"] = True
await call_next(context)
await call_next()
print("[Debug] Debug information collected")
@@ -145,7 +145,7 @@ class CachingMiddleware(AgentMiddleware):
def __init__(self) -> None:
self.cache: dict[str, AgentResponse] = {}
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Create a simple cache key from the last message
last_message = context.messages[-1] if context.messages else None
cache_key: str = last_message.text if last_message and last_message.text else "no_message"
@@ -158,7 +158,7 @@ class CachingMiddleware(AgentMiddleware):
print(f"[Cache] Cache MISS for: '{cache_key[:30]}...'")
context.metadata["cache_key"] = cache_key
await call_next(context)
await call_next()
# Cache the result if we have one
if context.result:
@@ -168,14 +168,14 @@ class CachingMiddleware(AgentMiddleware):
async def function_logging_middleware(
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Function middleware that logs all function calls."""
function_name = context.function.name
args = context.arguments
print(f"[FunctionLog] Calling function: {function_name} with args: {args}")
await call_next(context)
await call_next()
print(f"[FunctionLog] Function {function_name} completed")
@@ -275,7 +275,7 @@ async def main() -> None:
query = "What's the secret weather password for Berlin?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'Request was blocked by security middleware'}")
print(f"Agent: {result.text if result and result.text else 'Request was blocked by security middleware'}")
print()
# Run 7: Normal query again (no run-level middleware interference)
@@ -57,7 +57,7 @@ class InputObserverMiddleware(ChatMiddleware):
async def process(
self,
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Observe and modify input messages before they are sent to AI."""
print("[InputObserverMiddleware] Observing input messages:")
@@ -91,7 +91,7 @@ class InputObserverMiddleware(ChatMiddleware):
context.messages[:] = modified_messages
# Continue to next middleware or AI execution
await call_next(context)
await call_next()
# Observe that processing is complete
print("[InputObserverMiddleware] Processing completed")
@@ -100,7 +100,7 @@ class InputObserverMiddleware(ChatMiddleware):
@chat_middleware
async def security_and_override_middleware(
context: ChatContext,
call_next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Function-based middleware that implements security filtering and response override."""
print("[SecurityMiddleware] Processing input...")
@@ -131,7 +131,7 @@ async def security_and_override_middleware(
raise MiddlewareTermination
# Continue to next middleware or AI execution
await call_next(context)
await call_next()
async def class_based_chat_middleware() -> None:
@@ -50,7 +50,7 @@ class SecurityAgentMiddleware(AgentMiddleware):
async def process(
self,
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
# Check for potential security violations in the query
# Look at the last user message
@@ -67,7 +67,7 @@ class SecurityAgentMiddleware(AgentMiddleware):
return
print("[SecurityAgentMiddleware] Security check passed.")
await call_next(context)
await call_next()
class LoggingFunctionMiddleware(FunctionMiddleware):
@@ -76,14 +76,14 @@ class LoggingFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
function_name = context.function.name
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
start_time = time.time()
await call_next(context)
await call_next()
end_time = time.time()
duration = end_time - start_time
@@ -53,7 +53,7 @@ def get_current_time() -> str:
async def simple_agent_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
"""Agent middleware that runs before and after agent execution."""
print("[Agent MiddlewareTypes] Before agent execution")
await call_next(context)
await call_next()
print("[Agent MiddlewareTypes] After agent execution")
@@ -61,7 +61,7 @@ async def simple_agent_middleware(context, call_next): # type: ignore - paramet
async def simple_function_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
"""Function middleware that runs before and after function calls."""
print(f"[Function MiddlewareTypes] Before calling: {context.function.name}") # type: ignore
await call_next(context)
await call_next()
print(f"[Function MiddlewareTypes] After calling: {context.function.name}") # type: ignore
@@ -35,13 +35,13 @@ def unstable_data_service(
async def exception_handling_middleware(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
function_name = context.function.name
try:
print(f"[ExceptionHandlingMiddleware] Executing function: {function_name}")
await call_next(context)
await call_next()
print(f"[ExceptionHandlingMiddleware] Function {function_name} completed successfully.")
except TimeoutError as e:
print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}")
@@ -43,7 +43,7 @@ def get_weather(
async def security_agent_middleware(
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Agent middleware that checks for security violations."""
# Check for potential security violations in the query
@@ -57,12 +57,12 @@ async def security_agent_middleware(
return
print("[SecurityAgentMiddleware] Security check passed.")
await call_next(context)
await call_next()
async def logging_function_middleware(
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Function middleware that logs function calls."""
function_name = context.function.name
@@ -70,7 +70,7 @@ async def logging_function_middleware(
start_time = time.time()
await call_next(context)
await call_next()
end_time = time.time()
duration = end_time - start_time
@@ -105,7 +105,7 @@ async def main() -> None:
query = "What's the secret weather password?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response'}\n")
print(f"Agent: {result.text if result and result.text else 'No response'}\n")
if __name__ == "__main__":
@@ -49,7 +49,7 @@ class PreTerminationMiddleware(AgentMiddleware):
async def process(
self,
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
# Check if the user message contains any blocked words
last_message = context.messages[-1] if context.messages else None
@@ -75,7 +75,7 @@ class PreTerminationMiddleware(AgentMiddleware):
# Terminate to prevent further processing
raise MiddlewareTermination(result=context.result)
await call_next(context)
await call_next()
class PostTerminationMiddleware(AgentMiddleware):
@@ -88,7 +88,7 @@ class PostTerminationMiddleware(AgentMiddleware):
async def process(
self,
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})")
@@ -101,7 +101,7 @@ class PostTerminationMiddleware(AgentMiddleware):
raise MiddlewareTermination
# Allow the agent to process normally
await call_next(context)
await call_next()
# Increment response count after processing
self.response_count += 1
@@ -158,14 +158,14 @@ async def post_termination_middleware() -> None:
query = "What about the weather in London?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response (terminated)'}")
print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}")
# Third run (should also be terminated)
print("\n3. Third run (should also be terminated):")
query = "And New York?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response (terminated)'}")
print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}")
async def main() -> None:
@@ -49,11 +49,11 @@ def get_weather(
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
async def weather_override_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def weather_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
"""Chat middleware that overrides weather results for both streaming and non-streaming cases."""
# Let the original agent execution complete first
await call_next(context)
await call_next()
# Check if there's a result to override (agent called weather function)
if context.result is not None:
@@ -84,9 +84,9 @@ async def weather_override_middleware(context: ChatContext, call_next: Callable[
context.result = ChatResponse(messages=[Message(role=Role.ASSISTANT, text=custom_message)])
async def validate_weather_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
async def validate_weather_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
"""Chat middleware that simulates result validation for both streaming and non-streaming cases."""
await call_next(context)
await call_next()
validation_note = "Validation: weather data verified."
@@ -104,9 +104,9 @@ async def validate_weather_middleware(context: ChatContext, call_next: Callable[
context.result.messages.append(Message(role=Role.ASSISTANT, text=validation_note))
async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
"""Agent middleware that validates chat middleware effects and cleans the result."""
await call_next(context)
await call_next()
if context.result is None:
return
@@ -54,7 +54,7 @@ class SessionContextContainer:
async def inject_context_middleware(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""MiddlewareTypes that extracts runtime context from kwargs and stores in container.
@@ -74,7 +74,7 @@ class SessionContextContainer:
print(f" - Session Metadata Keys: {list(self.session_metadata.keys())}")
# Continue to tool execution
await call_next(context)
await call_next()
# Create a container instance that will be shared via closure
@@ -278,19 +278,19 @@ async def pattern_2_hierarchical_with_kwargs_propagation() -> None:
@function_middleware
async def email_kwargs_tracker(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
email_agent_kwargs.update(context.kwargs)
print(f"[EmailAgent] Received runtime context: {list(context.kwargs.keys())}")
await call_next(context)
await call_next()
@function_middleware
async def sms_kwargs_tracker(
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
sms_agent_kwargs.update(context.kwargs)
print(f"[SMSAgent] Received runtime context: {list(context.kwargs.keys())}")
await call_next(context)
await call_next()
client = OpenAIChatClient(model_id="gpt-4o-mini")
@@ -359,7 +359,7 @@ class AuthContextMiddleware:
self.validated_tokens: list[str] = []
async def validate_and_track(
self, context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
self, context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
"""Validate API token and track usage."""
api_token = context.kwargs.get("api_token")
@@ -375,7 +375,7 @@ class AuthContextMiddleware:
else:
print("[AuthMiddleware] No API token provided")
await call_next(context)
await call_next()
@tool(approval_mode="never_require")
@@ -57,7 +57,7 @@ class MiddlewareContainer:
async def call_counter_middleware(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""First middleware: increments call count in shared state."""
# Increment the shared call count
@@ -66,18 +66,18 @@ class MiddlewareContainer:
print(f"[CallCounter] This is function call #{self.call_count}")
# Call the next middleware/function
await call_next(context)
await call_next()
async def result_enhancer_middleware(
self,
context: FunctionInvocationContext,
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Second middleware: uses shared call count to enhance function results."""
print(f"[ResultEnhancer] Current total calls so far: {self.call_count}")
# Call the next middleware/function
await call_next(context)
await call_next()
# After function execution, enhance the result using shared state
if context.result:
@@ -46,7 +46,7 @@ def get_weather(
async def thread_tracking_middleware(
context: AgentContext,
call_next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[], Awaitable[None]],
) -> None:
"""MiddlewareTypes that tracks and logs thread behavior across runs."""
thread_messages = []
@@ -57,7 +57,7 @@ async def thread_tracking_middleware(
print(f"[MiddlewareTypes pre-execution] Thread history messages: {len(thread_messages)}")
# Call call_next to execute the agent
await call_next(context)
await call_next()
# Check thread state after agent execution
updated_thread_messages = []