Files
agent-framework/python/samples/getting_started/middleware/class_based_middleware.py
T
Eduard van Valkenburg 0521f5bed8 Python: [BREAKING] Simplify API: ChatAgent -> Agent, ChatMessage -> Message (#3747)
* [BREAKING] Rename ChatAgent -> Agent, ChatMessage -> Message, ChatClientProtocol -> SupportsChatGetResponse

Simplify the public API by removing redundant 'Chat' prefix from core types:
- ChatAgent -> Agent
- RawChatAgent -> RawAgent
- ChatMessage -> Message
- ChatClientProtocol -> SupportsChatGetResponse

Also renamed internal WorkflowMessage (was Message in _runner_context) to avoid collision.

No backward compatibility aliases - this is a clean breaking change.

* [BREAKING] Rename Agent chat_client parameter to client

* Fix rebase issues: WorkflowMessage references and broken markdown links

* Fix formatting and lint issues from code quality checks

* Fix import ordering in workflow sample files

* fixed rebase

* Fix test failures: use WorkflowMessage and A2AMessage after ChatMessage→Message rename

- Replace Message(data=..., source_id=...) with WorkflowMessage(...) in workflow tests
- Fix isinstance check in A2A agent to use A2AMessage instead of Message
- Fix import in test_workflow_observability.py (Message→WorkflowMessage)

* Fix lint, fmt, and sample errors after ChatMessage→Message rename

- Auto-fix 70+ ruff lint issues across samples (ChatMessage→Message refs)
- Fix HostedVectorStoreContent→Content.from_hosted_vector_store in file search sample
- Fix _normalize_messages→normalize_messages in custom agent sample
- Fix context.terminate→raise MiddlewareTermination in middleware samples
- Fix with_update_hook→with_transform_hook in override middleware sample
- Add TOptions_co import back to custom_chat_client sample
- Add noqa for FastAPI File() default in chatkit sample
- Fix B023 loop variable capture in weather agent sample

* fix: update Agent constructor calls from chat_client to client in declaration-only tool tests

* fix: add register_cleanup to devui lazy-loading proxy and type stub

* fixed tests and updated new pieces

* fix agui typevar

* fix merge errors

* fix merge conflicts

* fix merge

* Remove unused links

---------

Co-authored-by: Evan Mattson <evan.mattson@microsoft.com>
2026-02-10 23:04:32 +00:00

126 lines
4.6 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentContext,
AgentMiddleware,
AgentResponse,
FunctionInvocationContext,
FunctionMiddleware,
Message,
tool,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Class-based MiddlewareTypes Example
This sample demonstrates how to implement middleware using a class-based approach by inheriting
from AgentMiddleware and FunctionMiddleware base classes. The example includes:
- SecurityAgentMiddleware: Checks for security violations in user queries and blocks requests
containing sensitive information like passwords or secrets
- LoggingFunctionMiddleware: Logs function execution details including timing and parameters
This approach is useful when you need stateful middleware or complex logic that benefits
from object-oriented design patterns.
"""
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/getting_started/tools/function_tool_with_approval.py and samples/getting_started/tools/function_tool_with_approval_and_threads.py.
@tool(approval_mode="never_require")
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Return a randomly generated weather report for ``location``.

    The report is made up: a condition is drawn from a fixed list and the
    high temperature is a random integer between 10 and 30 (inclusive).
    """
    conditions = ["sunny", "cloudy", "rainy", "stormy"]
    # Draw the condition before the temperature so the order of random
    # draws stays the same as in a single left-to-right f-string.
    condition = conditions[randint(0, 3)]
    high = randint(10, 30)
    return f"The weather in {location} is {condition} with a high of {high}°C."
class SecurityAgentMiddleware(AgentMiddleware):
    """Agent middleware that blocks requests mentioning sensitive terms.

    Only the last entry of ``context.messages`` is inspected. If its text
    contains "password" or "secret" (case-insensitive), the middleware sets
    ``context.result`` to a canned refusal and does not invoke the rest of
    the pipeline.
    """

    async def process(
        self,
        context: AgentContext,
        call_next: Callable[[AgentContext], Awaitable[None]],
    ) -> None:
        """Screen the latest message, then either refuse or continue.

        Args:
            context: The agent invocation context; its ``messages`` are read
                and, on a violation, its ``result`` is overwritten.
            call_next: Continuation that runs the remainder of the pipeline.
        """
        newest = context.messages[-1] if context.messages else None

        flagged = False
        if newest and newest.text:
            # Lowercase once so the term checks are case-insensitive.
            lowered = newest.text.lower()
            flagged = any(term in lowered for term in ("password", "secret"))

        if flagged:
            print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
            # Replace the result with a warning and skip call_next() so the
            # request is not executed.
            context.result = AgentResponse(
                messages=[Message("assistant", ["Detected sensitive information, the request is blocked."])]
            )
            return

        print("[SecurityAgentMiddleware] Security check passed.")
        await call_next(context)
class LoggingFunctionMiddleware(FunctionMiddleware):
    """Function middleware that logs each function call and its duration."""

    async def process(
        self,
        context: FunctionInvocationContext,
        call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
    ) -> None:
        """Log before and after the wrapped function invocation.

        Args:
            context: The function invocation context; ``context.function.name``
                is used in the log lines.
            call_next: Continuation that runs the remainder of the pipeline.
        """
        function_name = context.function.name
        print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")

        # Use a monotonic clock for elapsed time: time.time() follows the
        # wall clock and can jump (even backwards) if the system clock is
        # adjusted while the call is in flight.
        start_time = time.perf_counter()

        await call_next(context)

        duration = time.perf_counter() - start_time

        print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")
async def main() -> None:
    """Run the class-based middleware demo against a weather agent.

    Two queries are sent: an ordinary weather question, and one containing
    the word "password", which exercises the security middleware.
    """
    print("=== Class-based MiddlewareTypes Example ===")

    # Each scenario is (banner printed before the query, query text).
    scenarios = (
        ("\n--- Normal Query ---", "What's the weather like in Seattle?"),
        ("--- Security Test ---", "What's the password for the weather service?"),
    )

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with AzureCliCredential() as credential:
        client = AzureAIAgentClient(credential=credential)
        async with client.as_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()],
        ) as agent:
            for banner, query in scenarios:
                print(banner)
                print(f"User: {query}")
                result = await agent.run(query)
                print(f"Agent: {result.text}\n")
if __name__ == "__main__":
    # Start the event loop only when executed as a script, not on import.
    asyncio.run(main())