mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Removed context parameter from call_next (#3829)
This commit is contained in:
committed by
GitHub
Unverified
parent
38f22ef006
commit
1fdc4be88d
@@ -267,7 +267,7 @@ class TerminatingMiddleware(FunctionMiddleware):
|
||||
if self.should_terminate(context):
|
||||
context.result = "terminated by middleware"
|
||||
raise MiddlewareTermination # Exit function invocation loop
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
```
|
||||
|
||||
## Arguments Added/Altered at Each Layer
|
||||
@@ -347,7 +347,7 @@ class CachingMiddleware(FunctionMiddleware):
|
||||
return # Upstream post-processing still runs
|
||||
|
||||
# Option B: Call call_next, then return normally
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
self.cache[context.function.name] = context.result
|
||||
return # Normal completion
|
||||
```
|
||||
@@ -362,7 +362,7 @@ class BlockedFunctionMiddleware(FunctionMiddleware):
|
||||
if context.function.name in self.blocked_functions:
|
||||
context.result = "Function blocked by policy"
|
||||
raise MiddlewareTermination("Blocked") # Skips ALL post-processing
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
```
|
||||
|
||||
### 3. Raise Any Other Exception
|
||||
@@ -374,7 +374,7 @@ class ValidationMiddleware(FunctionMiddleware):
|
||||
async def process(self, context: FunctionInvocationContext, call_next):
|
||||
if not self.is_valid(context.arguments):
|
||||
raise ValueError("Invalid arguments") # Bubbles up to user
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
```
|
||||
|
||||
## `return` vs `raise MiddlewareTermination`
|
||||
@@ -385,7 +385,7 @@ The key difference is what happens to **upstream middleware's post-processing**:
|
||||
class MiddlewareA(AgentMiddleware):
|
||||
async def process(self, context, call_next):
|
||||
print("A: before")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print("A: after") # Does this run?
|
||||
|
||||
class MiddlewareB(AgentMiddleware):
|
||||
@@ -410,7 +410,7 @@ With middleware registered as `[MiddlewareA, MiddlewareB]`:
|
||||
|
||||
## Calling `call_next()` or Not
|
||||
|
||||
The decision to call `call_next(context)` determines whether downstream middleware and the actual operation execute:
|
||||
The decision to call `call_next()` determines whether downstream middleware and the actual operation execute:
|
||||
|
||||
### Without calling `call_next()` - Skip downstream
|
||||
|
||||
@@ -430,7 +430,7 @@ async def process(self, context, call_next):
|
||||
```python
|
||||
async def process(self, context, call_next):
|
||||
# Pre-processing
|
||||
await call_next(context) # Execute downstream + actual operation
|
||||
await call_next() # Execute downstream + actual operation
|
||||
# Post-processing (context.result now contains real result)
|
||||
return
|
||||
```
|
||||
@@ -450,7 +450,7 @@ async def process(self, context, call_next):
|
||||
| `raise MiddlewareTermination` | Yes | ✅ | ✅ | ❌ No |
|
||||
| `raise OtherException` | Either | Depends | Depends | ❌ No (exception propagates) |
|
||||
|
||||
> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next(context)` without an explicit `return` statement achieves this pattern.
|
||||
> **Note:** The first row (`return` after calling `call_next()`) is the default behavior. Python functions implicitly return `None` at the end, so simply calling `await call_next()` without an explicit `return` statement achieves this pattern.
|
||||
|
||||
## Streaming vs Non-Streaming
|
||||
|
||||
|
||||
@@ -20,13 +20,13 @@ multiple specialized agents, each focusing on specific tasks.
|
||||
|
||||
async def logging_middleware(
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""MiddlewareTypes that logs tool invocations to show the delegation flow."""
|
||||
print(f"[Calling tool: {context.function.name}]")
|
||||
print(f"[Request: {context.arguments}]")
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
print(f"[Response: {context.result}]")
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ response generation, showing both streaming and non-streaming responses.
|
||||
@chat_middleware
|
||||
async def security_and_override_middleware(
|
||||
context: ChatContext,
|
||||
call_next: Callable[[ChatContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function-based middleware that implements security filtering and response override."""
|
||||
print("[SecurityMiddleware] Processing input...")
|
||||
@@ -60,7 +60,7 @@ async def security_and_override_middleware(
|
||||
raise MiddlewareTermination(result=context.result)
|
||||
|
||||
# Continue to next middleware or AI execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
print("[SecurityMiddleware] Response generated.")
|
||||
print(type(context.result))
|
||||
|
||||
+2
-2
@@ -19,13 +19,13 @@ multiple specialized agents, each focusing on specific tasks.
|
||||
|
||||
async def logging_middleware(
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""MiddlewareTypes that logs tool invocations to show the delegation flow."""
|
||||
print(f"[Calling tool: {context.function.name}]")
|
||||
print(f"[Request: {context.arguments}]")
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
print(f"[Response: {context.result}]")
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ def cleanup_resources():
|
||||
@chat_middleware
|
||||
async def security_filter_middleware(
|
||||
context: ChatContext,
|
||||
call_next: Callable[[ChatContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Chat middleware that blocks requests containing sensitive information."""
|
||||
blocked_terms = ["password", "secret", "api_key", "token"]
|
||||
@@ -80,13 +80,13 @@ async def security_filter_middleware(
|
||||
|
||||
raise MiddlewareTermination(result=context.result)
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
@function_middleware
|
||||
async def atlantis_location_filter_middleware(
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function middleware that blocks weather requests for Atlantis."""
|
||||
# Check if location parameter is "atlantis"
|
||||
@@ -98,7 +98,7 @@ async def atlantis_location_filter_middleware(
|
||||
)
|
||||
raise MiddlewareTermination(result=context.result)
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/getting_started/tools/function_tool_with_approval.py and samples/getting_started/tools/function_tool_with_approval_and_threads.py.
|
||||
|
||||
@@ -68,7 +68,7 @@ def get_weather(
|
||||
class SecurityAgentMiddleware(AgentMiddleware):
|
||||
"""Agent-level security middleware that validates all requests."""
|
||||
|
||||
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
print("[SecurityMiddleware] Checking security for all requests...")
|
||||
|
||||
# Check for security violations in the last user message
|
||||
@@ -81,18 +81,18 @@ class SecurityAgentMiddleware(AgentMiddleware):
|
||||
|
||||
print("[SecurityMiddleware] Security check passed.")
|
||||
context.metadata["security_validated"] = True
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
async def performance_monitor_middleware(
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Agent-level performance monitoring for all runs."""
|
||||
print("[PerformanceMonitor] Starting performance monitoring...")
|
||||
start_time = time.time()
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
@@ -104,7 +104,7 @@ async def performance_monitor_middleware(
|
||||
class HighPriorityMiddleware(AgentMiddleware):
|
||||
"""Run-level middleware for high priority requests."""
|
||||
|
||||
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
print("[HighPriority] Processing high priority request with expedited handling...")
|
||||
|
||||
# Read metadata set by agent-level middleware
|
||||
@@ -115,13 +115,13 @@ class HighPriorityMiddleware(AgentMiddleware):
|
||||
context.metadata["priority"] = "high"
|
||||
context.metadata["expedited"] = True
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print("[HighPriority] High priority processing completed")
|
||||
|
||||
|
||||
async def debugging_middleware(
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Run-level debugging middleware for troubleshooting specific runs."""
|
||||
print("[Debug] Debug mode enabled for this run")
|
||||
@@ -134,7 +134,7 @@ async def debugging_middleware(
|
||||
|
||||
context.metadata["debug_enabled"] = True
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
print("[Debug] Debug information collected")
|
||||
|
||||
@@ -145,7 +145,7 @@ class CachingMiddleware(AgentMiddleware):
|
||||
def __init__(self) -> None:
|
||||
self.cache: dict[str, AgentResponse] = {}
|
||||
|
||||
async def process(self, context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Create a simple cache key from the last message
|
||||
last_message = context.messages[-1] if context.messages else None
|
||||
cache_key: str = last_message.text if last_message and last_message.text else "no_message"
|
||||
@@ -158,7 +158,7 @@ class CachingMiddleware(AgentMiddleware):
|
||||
print(f"[Cache] Cache MISS for: '{cache_key[:30]}...'")
|
||||
context.metadata["cache_key"] = cache_key
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Cache the result if we have one
|
||||
if context.result:
|
||||
@@ -168,14 +168,14 @@ class CachingMiddleware(AgentMiddleware):
|
||||
|
||||
async def function_logging_middleware(
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function middleware that logs all function calls."""
|
||||
function_name = context.function.name
|
||||
args = context.arguments
|
||||
print(f"[FunctionLog] Calling function: {function_name} with args: {args}")
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
print(f"[FunctionLog] Function {function_name} completed")
|
||||
|
||||
@@ -275,7 +275,7 @@ async def main() -> None:
|
||||
query = "What's the secret weather password for Berlin?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text if result.text else 'Request was blocked by security middleware'}")
|
||||
print(f"Agent: {result.text if result and result.text else 'Request was blocked by security middleware'}")
|
||||
print()
|
||||
|
||||
# Run 7: Normal query again (no run-level middleware interference)
|
||||
|
||||
@@ -57,7 +57,7 @@ class InputObserverMiddleware(ChatMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: ChatContext,
|
||||
call_next: Callable[[ChatContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Observe and modify input messages before they are sent to AI."""
|
||||
print("[InputObserverMiddleware] Observing input messages:")
|
||||
@@ -91,7 +91,7 @@ class InputObserverMiddleware(ChatMiddleware):
|
||||
context.messages[:] = modified_messages
|
||||
|
||||
# Continue to next middleware or AI execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Observe that processing is complete
|
||||
print("[InputObserverMiddleware] Processing completed")
|
||||
@@ -100,7 +100,7 @@ class InputObserverMiddleware(ChatMiddleware):
|
||||
@chat_middleware
|
||||
async def security_and_override_middleware(
|
||||
context: ChatContext,
|
||||
call_next: Callable[[ChatContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function-based middleware that implements security filtering and response override."""
|
||||
print("[SecurityMiddleware] Processing input...")
|
||||
@@ -131,7 +131,7 @@ async def security_and_override_middleware(
|
||||
raise MiddlewareTermination
|
||||
|
||||
# Continue to next middleware or AI execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
async def class_based_chat_middleware() -> None:
|
||||
|
||||
@@ -50,7 +50,7 @@ class SecurityAgentMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
# Check for potential security violations in the query
|
||||
# Look at the last user message
|
||||
@@ -67,7 +67,7 @@ class SecurityAgentMiddleware(AgentMiddleware):
|
||||
return
|
||||
|
||||
print("[SecurityAgentMiddleware] Security check passed.")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
class LoggingFunctionMiddleware(FunctionMiddleware):
|
||||
@@ -76,14 +76,14 @@ class LoggingFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
function_name = context.function.name
|
||||
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
|
||||
@@ -53,7 +53,7 @@ def get_current_time() -> str:
|
||||
async def simple_agent_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
|
||||
"""Agent middleware that runs before and after agent execution."""
|
||||
print("[Agent MiddlewareTypes] Before agent execution")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print("[Agent MiddlewareTypes] After agent execution")
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ async def simple_agent_middleware(context, call_next): # type: ignore - paramet
|
||||
async def simple_function_middleware(context, call_next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
|
||||
"""Function middleware that runs before and after function calls."""
|
||||
print(f"[Function MiddlewareTypes] Before calling: {context.function.name}") # type: ignore
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print(f"[Function MiddlewareTypes] After calling: {context.function.name}") # type: ignore
|
||||
|
||||
|
||||
|
||||
@@ -35,13 +35,13 @@ def unstable_data_service(
|
||||
|
||||
|
||||
async def exception_handling_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
function_name = context.function.name
|
||||
|
||||
try:
|
||||
print(f"[ExceptionHandlingMiddleware] Executing function: {function_name}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
print(f"[ExceptionHandlingMiddleware] Function {function_name} completed successfully.")
|
||||
except TimeoutError as e:
|
||||
print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}")
|
||||
|
||||
@@ -43,7 +43,7 @@ def get_weather(
|
||||
|
||||
async def security_agent_middleware(
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Agent middleware that checks for security violations."""
|
||||
# Check for potential security violations in the query
|
||||
@@ -57,12 +57,12 @@ async def security_agent_middleware(
|
||||
return
|
||||
|
||||
print("[SecurityAgentMiddleware] Security check passed.")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
async def logging_function_middleware(
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function middleware that logs function calls."""
|
||||
function_name = context.function.name
|
||||
@@ -70,7 +70,7 @@ async def logging_function_middleware(
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
@@ -105,7 +105,7 @@ async def main() -> None:
|
||||
query = "What's the secret weather password?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text if result.text else 'No response'}\n")
|
||||
print(f"Agent: {result.text if result and result.text else 'No response'}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -49,7 +49,7 @@ class PreTerminationMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
# Check if the user message contains any blocked words
|
||||
last_message = context.messages[-1] if context.messages else None
|
||||
@@ -75,7 +75,7 @@ class PreTerminationMiddleware(AgentMiddleware):
|
||||
# Terminate to prevent further processing
|
||||
raise MiddlewareTermination(result=context.result)
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
class PostTerminationMiddleware(AgentMiddleware):
|
||||
@@ -88,7 +88,7 @@ class PostTerminationMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})")
|
||||
|
||||
@@ -101,7 +101,7 @@ class PostTerminationMiddleware(AgentMiddleware):
|
||||
raise MiddlewareTermination
|
||||
|
||||
# Allow the agent to process normally
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Increment response count after processing
|
||||
self.response_count += 1
|
||||
@@ -158,14 +158,14 @@ async def post_termination_middleware() -> None:
|
||||
query = "What about the weather in London?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text if result.text else 'No response (terminated)'}")
|
||||
print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}")
|
||||
|
||||
# Third run (should also be terminated)
|
||||
print("\n3. Third run (should also be terminated):")
|
||||
query = "And New York?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text if result.text else 'No response (terminated)'}")
|
||||
print(f"Agent: {result.text if result and result.text else 'No response (terminated)'}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
|
||||
@@ -49,11 +49,11 @@ def get_weather(
|
||||
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
|
||||
|
||||
|
||||
async def weather_override_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def weather_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
"""Chat middleware that overrides weather results for both streaming and non-streaming cases."""
|
||||
|
||||
# Let the original agent execution complete first
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Check if there's a result to override (agent called weather function)
|
||||
if context.result is not None:
|
||||
@@ -84,9 +84,9 @@ async def weather_override_middleware(context: ChatContext, call_next: Callable[
|
||||
context.result = ChatResponse(messages=[Message(role=Role.ASSISTANT, text=custom_message)])
|
||||
|
||||
|
||||
async def validate_weather_middleware(context: ChatContext, call_next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
async def validate_weather_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
"""Chat middleware that simulates result validation for both streaming and non-streaming cases."""
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
validation_note = "Validation: weather data verified."
|
||||
|
||||
@@ -104,9 +104,9 @@ async def validate_weather_middleware(context: ChatContext, call_next: Callable[
|
||||
context.result.messages.append(Message(role=Role.ASSISTANT, text=validation_note))
|
||||
|
||||
|
||||
async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]) -> None:
|
||||
async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
"""Agent middleware that validates chat middleware effects and cleans the result."""
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
if context.result is None:
|
||||
return
|
||||
|
||||
@@ -54,7 +54,7 @@ class SessionContextContainer:
|
||||
async def inject_context_middleware(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""MiddlewareTypes that extracts runtime context from kwargs and stores in container.
|
||||
|
||||
@@ -74,7 +74,7 @@ class SessionContextContainer:
|
||||
print(f" - Session Metadata Keys: {list(self.session_metadata.keys())}")
|
||||
|
||||
# Continue to tool execution
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
# Create a container instance that will be shared via closure
|
||||
@@ -278,19 +278,19 @@ async def pattern_2_hierarchical_with_kwargs_propagation() -> None:
|
||||
|
||||
@function_middleware
|
||||
async def email_kwargs_tracker(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
email_agent_kwargs.update(context.kwargs)
|
||||
print(f"[EmailAgent] Received runtime context: {list(context.kwargs.keys())}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
@function_middleware
|
||||
async def sms_kwargs_tracker(
|
||||
context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
sms_agent_kwargs.update(context.kwargs)
|
||||
print(f"[SMSAgent] Received runtime context: {list(context.kwargs.keys())}")
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
client = OpenAIChatClient(model_id="gpt-4o-mini")
|
||||
|
||||
@@ -359,7 +359,7 @@ class AuthContextMiddleware:
|
||||
self.validated_tokens: list[str] = []
|
||||
|
||||
async def validate_and_track(
|
||||
self, context: FunctionInvocationContext, call_next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
self, context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Validate API token and track usage."""
|
||||
api_token = context.kwargs.get("api_token")
|
||||
@@ -375,7 +375,7 @@ class AuthContextMiddleware:
|
||||
else:
|
||||
print("[AuthMiddleware] No API token provided")
|
||||
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
|
||||
@@ -57,7 +57,7 @@ class MiddlewareContainer:
|
||||
async def call_counter_middleware(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""First middleware: increments call count in shared state."""
|
||||
# Increment the shared call count
|
||||
@@ -66,18 +66,18 @@ class MiddlewareContainer:
|
||||
print(f"[CallCounter] This is function call #{self.call_count}")
|
||||
|
||||
# Call the next middleware/function
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
async def result_enhancer_middleware(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Second middleware: uses shared call count to enhance function results."""
|
||||
print(f"[ResultEnhancer] Current total calls so far: {self.call_count}")
|
||||
|
||||
# Call the next middleware/function
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# After function execution, enhance the result using shared state
|
||||
if context.result:
|
||||
|
||||
@@ -46,7 +46,7 @@ def get_weather(
|
||||
|
||||
async def thread_tracking_middleware(
|
||||
context: AgentContext,
|
||||
call_next: Callable[[AgentContext], Awaitable[None]],
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""MiddlewareTypes that tracks and logs thread behavior across runs."""
|
||||
thread_messages = []
|
||||
@@ -57,7 +57,7 @@ async def thread_tracking_middleware(
|
||||
print(f"[MiddlewareTypes pre-execution] Thread history messages: {len(thread_messages)}")
|
||||
|
||||
# Call call_next to execute the agent
|
||||
await call_next(context)
|
||||
await call_next()
|
||||
|
||||
# Check thread state after agent execution
|
||||
updated_thread_messages = []
|
||||
|
||||
Reference in New Issue
Block a user