From 09f59b21ad004eabb86bd5f7db0e830027f4ca05 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Thu, 5 Feb 2026 22:47:51 -0800 Subject: [PATCH] Python: [BREAKING] Renamed AgentRunContext to AgentContext (#3714) * Renamed AgentRunContext to AgentContext * Update python/packages/core/AGENTS.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- python/packages/core/AGENTS.md | 6 +- .../core/agent_framework/_middleware.py | 44 ++--- .../core/agent_framework/_serialization.py | 10 +- .../core/test_as_tool_kwargs_propagation.py | 30 +-- .../core/tests/core/test_middleware.py | 182 +++++++----------- .../core/test_middleware_context_result.py | 52 ++--- .../tests/core/test_middleware_with_agent.py | 92 +++------ .../agent_framework_purview/_middleware.py | 6 +- .../packages/purview/tests/test_middleware.py | 50 ++--- python/samples/concepts/tools/README.md | 4 +- .../getting_started/middleware/README.md | 2 +- .../agent_and_run_level_middleware.py | 16 +- .../middleware/class_based_middleware.py | 10 +- .../middleware/decorator_middleware.py | 4 +- .../middleware/function_based_middleware.py | 6 +- .../middleware/middleware_termination.py | 10 +- .../override_result_with_middleware.py | 6 +- .../middleware/thread_behavior_middleware.py | 8 +- 18 files changed, 219 insertions(+), 319 deletions(-) diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index 946d077c8b..a41f5ed42f 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -55,7 +55,7 @@ agent_framework/ - **`AgentMiddleware`** - Intercepts agent `run()` calls - **`ChatMiddleware`** - Intercepts chat client `get_response()` calls - **`FunctionMiddleware`** - Intercepts function/tool invocations -- **`AgentRunContext`** / **`ChatContext`** / **`FunctionInvocationContext`** - Context objects passed through middleware +- **`AgentContext`** / **`ChatContext`** / **`FunctionInvocationContext`** - Context objects passed through middleware ### Threads (`_threads.py`) @@ -114,10 +114,10 @@ agent = OpenAIChatClient().as_agent( ### Middleware Pipeline ```python -from agent_framework import ChatAgent, AgentMiddleware, AgentRunContext +from agent_framework import ChatAgent, AgentMiddleware, AgentContext class LoggingMiddleware(AgentMiddleware): - async def invoke(self, context: AgentRunContext, next) -> AgentResponse: + async def process(self, context: AgentContext, next) -> AgentResponse: print(f"Input: {context.messages}") response = await next(context) print(f"Output: {response}") diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 44a55b13b3..7f6619570e 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -43,10 +43,10 @@ if TYPE_CHECKING: TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel) __all__ = [ + "AgentContext", "AgentMiddleware", "AgentMiddlewareLayer", "AgentMiddlewareTypes", - "AgentRunContext", "ChatAndFunctionMiddlewareTypes", "ChatContext", "ChatMiddleware", @@ -109,7 +109,7 @@ class MiddlewareType(str, Enum): CHAT = "chat" -class AgentRunContext: +class AgentContext: """Context object for agent middleware invocations. This context is passed through the agent middleware pipeline and contains all information @@ -131,11 +131,11 @@ class AgentRunContext: Examples: .. code-block:: python - from agent_framework import AgentMiddleware, AgentRunContext + from agent_framework import AgentMiddleware, AgentContext class LoggingMiddleware(AgentMiddleware): - async def process(self, context: AgentRunContext, next): + async def process(self, context: AgentContext, next): print(f"Agent: {context.agent.name}") print(f"Messages: {len(context.messages)}") print(f"Thread: {context.thread}") @@ -170,7 +170,7 @@ class AgentRunContext: | None = None, stream_cleanup_hooks: Sequence[Callable[[], Awaitable[None] | None]] | None = None, ) -> None: - """Initialize the AgentRunContext. + """Initialize the AgentContext. Args: agent: The agent being invoked. @@ -356,14 +356,14 @@ class AgentMiddleware(ABC): Examples: .. code-block:: python - from agent_framework import AgentMiddleware, AgentRunContext, ChatAgent + from agent_framework import AgentMiddleware, AgentContext, ChatAgent class RetryMiddleware(AgentMiddleware): def __init__(self, max_retries: int = 3): self.max_retries = max_retries - async def process(self, context: AgentRunContext, next): + async def process(self, context: AgentContext, next): for attempt in range(self.max_retries): await next(context) if context.result and not context.result.is_error: @@ -378,8 +378,8 @@ class AgentMiddleware(ABC): @abstractmethod async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: """Process an agent invocation. @@ -531,7 +531,7 @@ class ChatMiddleware(ABC): # Pure function type definitions for convenience -AgentMiddlewareCallable = Callable[[AgentRunContext, Callable[[AgentRunContext], Awaitable[None]]], Awaitable[None]] +AgentMiddlewareCallable = Callable[[AgentContext, Callable[[AgentContext], Awaitable[None]]], Awaitable[None]] AgentMiddlewareTypes: TypeAlias = AgentMiddleware | AgentMiddlewareCallable FunctionMiddlewareCallable = Callable[ @@ -561,7 +561,7 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable: """Decorator to mark a function as agent middleware. This decorator explicitly identifies a function as agent middleware, - which processes AgentRunContext objects. + which processes AgentContext objects. Args: func: The middleware function to mark as agent middleware. @@ -572,11 +572,11 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable: Examples: .. code-block:: python - from agent_framework import agent_middleware, AgentRunContext, ChatAgent + from agent_framework import agent_middleware, AgentContext, ChatAgent @agent_middleware - async def logging_middleware(context: AgentRunContext, next): + async def logging_middleware(context: AgentContext, next): print(f"Before: {context.agent.name}") await next(context) print(f"After: {context.result}") @@ -752,9 +752,9 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): async def execute( self, - context: AgentRunContext, + context: AgentContext, final_handler: Callable[ - [AgentRunContext], Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse] + [AgentContext], Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse] ], ) -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None: """Execute the agent middleware pipeline for streaming or non-streaming. @@ -772,17 +772,17 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): context.result = await context.result return context.result - def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[None]]: + def create_next_handler(index: int) -> Callable[[AgentContext], Awaitable[None]]: if index >= len(self._middleware): - async def final_wrapper(c: AgentRunContext) -> None: + async def final_wrapper(c: AgentContext) -> None: c.result = final_handler(c) # type: ignore[assignment] if inspect.isawaitable(c.result): c.result = await c.result return final_wrapper - async def current_handler(c: AgentRunContext) -> None: + async def current_handler(c: AgentContext) -> None: # MiddlewareTermination bubbles up to execute() to skip post-processing await self._middleware[index].process(c, create_next_handler(index + 1)) @@ -1161,7 +1161,7 @@ class AgentMiddlewareLayer: if not pipeline.has_middlewares: return super().run(messages, stream=stream, thread=thread, options=options, **combined_kwargs) # type: ignore[misc, no-any-return] - context = AgentRunContext( + context = AgentContext( agent=self, # type: ignore[arg-type] messages=prepare_messages(messages), # type: ignore[arg-type] thread=thread, @@ -1194,7 +1194,7 @@ class AgentMiddlewareLayer: return _execute() # type: ignore[return-value] def _middleware_handler( - self, context: AgentRunContext + self, context: AgentContext ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: return super().run( # type: ignore[misc, no-any-return] context.messages, @@ -1231,7 +1231,7 @@ def _determine_middleware_type(middleware: Any) -> MiddlewareType: first_param = params[0] if hasattr(first_param.annotation, "__name__"): annotation_name = first_param.annotation.__name__ - if annotation_name == "AgentRunContext": + if annotation_name == "AgentContext": param_type = MiddlewareType.AGENT elif annotation_name == "FunctionInvocationContext": param_type = MiddlewareType.FUNCTION @@ -1270,7 +1270,7 @@ def _determine_middleware_type(middleware: Any) -> MiddlewareType: raise MiddlewareException( f"Cannot determine middleware type for function {middleware.__name__}. " f"Please either use @agent_middleware/@function_middleware/@chat_middleware decorators " - f"or specify parameter types (AgentRunContext, FunctionInvocationContext, or ChatContext)." + f"or specify parameter types (AgentContext, FunctionInvocationContext, or ChatContext)." ) diff --git a/python/packages/core/agent_framework/_serialization.py b/python/packages/core/agent_framework/_serialization.py index 0e9a34fed4..dd6b8f871f 100644 --- a/python/packages/core/agent_framework/_serialization.py +++ b/python/packages/core/agent_framework/_serialization.py @@ -477,12 +477,12 @@ class SerializationMixin: .. code-block:: python - from agent_framework._middleware import AgentRunContext + from agent_framework._middleware import AgentContext from agent_framework import BaseAgent - # AgentRunContext has INJECTABLE = {"agent", "result"} + # AgentContext has INJECTABLE = {"agent", "result"} context_data = { - "type": "agent_run_context", + "type": "agent_context", "messages": [{"role": "user", "text": "Hello"}], "stream": False, "metadata": {"session_id": "abc123"}, @@ -492,14 +492,14 @@ class SerializationMixin: # Inject agent and result during middleware processing my_agent = BaseAgent(name="test-agent") dependencies = { - "agent_run_context": { + "agent_context": { "agent": my_agent, "result": None, # Will be populated during execution } } # Reconstruct context with agent dependency for middleware chain - context = AgentRunContext.from_dict(context_data, dependencies=dependencies) + context = AgentContext.from_dict(context_data, dependencies=dependencies) # MiddlewareTypes can now access context.agent and process the execution This injection system allows the agent framework to maintain clean separation diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index 8d262a5c23..8a2c4ceb5b 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable from typing import Any from agent_framework import ChatAgent, ChatMessage, ChatResponse, Content, agent_middleware -from agent_framework._middleware import AgentRunContext +from agent_framework._middleware import AgentContext from .conftest import MockChatClient @@ -19,9 +19,7 @@ class TestAsToolKwargsPropagation: captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Capture kwargs passed to the sub-agent captured_kwargs.update(context.kwargs) await next(context) @@ -62,9 +60,7 @@ class TestAsToolKwargsPropagation: captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) await next(context) @@ -99,9 +95,7 @@ class TestAsToolKwargsPropagation: captured_kwargs_list: list[dict[str, Any]] = [] @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Capture kwargs at each level captured_kwargs_list.append(dict(context.kwargs)) await next(context) @@ -162,9 +156,7 @@ class TestAsToolKwargsPropagation: captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) await next(context) @@ -224,9 +216,7 @@ class TestAsToolKwargsPropagation: captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) await next(context) @@ -266,9 +256,7 @@ class TestAsToolKwargsPropagation: call_count = 0 @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: nonlocal call_count call_count += 1 if call_count == 1: @@ -318,9 +306,7 @@ class TestAsToolKwargsPropagation: captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) await next(context) diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index f6a0267500..e6403fa2e2 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -18,9 +18,9 @@ from agent_framework import ( ResponseStream, ) from agent_framework._middleware import ( + AgentContext, AgentMiddleware, AgentMiddlewarePipeline, - AgentRunContext, ChatContext, ChatMiddleware, ChatMiddlewarePipeline, @@ -32,13 +32,13 @@ from agent_framework._middleware import ( from agent_framework._tools import FunctionTool -class TestAgentRunContext: - """Test cases for AgentRunContext.""" +class TestAgentContext: + """Test cases for AgentContext.""" def test_init_with_defaults(self, mock_agent: AgentProtocol) -> None: - """Test AgentRunContext initialization with default values.""" + """Test AgentContext initialization with default values.""" messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) assert context.agent is mock_agent assert context.messages == messages @@ -46,10 +46,10 @@ class TestAgentRunContext: assert context.metadata == {} def test_init_with_custom_values(self, mock_agent: AgentProtocol) -> None: - """Test AgentRunContext initialization with custom values.""" + """Test AgentContext initialization with custom values.""" messages = [ChatMessage(role="user", text="test")] metadata = {"key": "value"} - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True, metadata=metadata) + context = AgentContext(agent=mock_agent, messages=messages, stream=True, metadata=metadata) assert context.agent is mock_agent assert context.messages == messages @@ -57,12 +57,12 @@ class TestAgentRunContext: assert context.metadata == metadata def test_init_with_thread(self, mock_agent: AgentProtocol) -> None: - """Test AgentRunContext initialization with thread parameter.""" + """Test AgentContext initialization with thread parameter.""" from agent_framework import AgentThread messages = [ChatMessage(role="user", text="test")] thread = AgentThread() - context = AgentRunContext(agent=mock_agent, messages=messages, thread=thread) + context = AgentContext(agent=mock_agent, messages=messages, thread=thread) assert context.agent is mock_agent assert context.messages == messages @@ -135,11 +135,11 @@ class TestAgentMiddlewarePipeline: """Test cases for AgentMiddlewarePipeline.""" class PreNextTerminateMiddleware(AgentMiddleware): - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: raise MiddlewareTermination class PostNextTerminateMiddleware(AgentMiddleware): - async def process(self, context: AgentRunContext, next: Any) -> None: + async def process(self, context: AgentContext, next: Any) -> None: await next(context) raise MiddlewareTermination @@ -157,7 +157,7 @@ class TestAgentMiddlewarePipeline: def test_init_with_function_middleware(self) -> None: """Test AgentMiddlewarePipeline initialization with function-based middleware.""" - async def test_middleware(context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def test_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: await next(context) pipeline = AgentMiddlewarePipeline(test_middleware) @@ -167,11 +167,11 @@ class TestAgentMiddlewarePipeline: """Test pipeline execution with no middleware.""" pipeline = AgentMiddlewarePipeline() messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) expected_response = AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: return expected_response result = await pipeline.execute(context, final_handler) @@ -185,9 +185,7 @@ class TestAgentMiddlewarePipeline: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") await next(context) execution_order.append(f"{self.name}_after") @@ -195,11 +193,11 @@ class TestAgentMiddlewarePipeline: middleware = OrderTrackingMiddleware("test") pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) expected_response = AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: execution_order.append("handler") return expected_response @@ -211,9 +209,9 @@ class TestAgentMiddlewarePipeline: """Test pipeline streaming execution with no middleware.""" pipeline = AgentMiddlewarePipeline() messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) yield AgentResponseUpdate(contents=[Content.from_text(text="chunk2")]) @@ -238,9 +236,7 @@ class TestAgentMiddlewarePipeline: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") await next(context) execution_order.append(f"{self.name}_after") @@ -248,9 +244,9 @@ class TestAgentMiddlewarePipeline: middleware = StreamOrderTrackingMiddleware("test") pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: execution_order.append("handler_start") yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) @@ -274,10 +270,10 @@ class TestAgentMiddlewarePipeline: middleware = self.PreNextTerminateMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: # Handler should not be executed when terminated before next() execution_order.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -292,10 +288,10 @@ class TestAgentMiddlewarePipeline: middleware = self.PostNextTerminateMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: execution_order.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -310,10 +306,10 @@ class TestAgentMiddlewarePipeline: middleware = self.PreNextTerminateMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: # Handler should not be executed when terminated before next() execution_order.append("handler_start") @@ -338,10 +334,10 @@ class TestAgentMiddlewarePipeline: middleware = self.PostNextTerminateMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: execution_order.append("handler_start") yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) @@ -367,9 +363,7 @@ class TestAgentMiddlewarePipeline: captured_thread = None class ThreadCapturingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: nonlocal captured_thread captured_thread = context.thread await next(context) @@ -378,11 +372,11 @@ class TestAgentMiddlewarePipeline: pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] thread = AgentThread() - context = AgentRunContext(agent=mock_agent, messages=messages, thread=thread) + context = AgentContext(agent=mock_agent, messages=messages, thread=thread) expected_response = AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: return expected_response result = await pipeline.execute(context, final_handler) @@ -394,9 +388,7 @@ class TestAgentMiddlewarePipeline: captured_thread = "not_none" # Use string to distinguish from None class ThreadCapturingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: nonlocal captured_thread captured_thread = context.thread await next(context) @@ -404,11 +396,11 @@ class TestAgentMiddlewarePipeline: middleware = ThreadCapturingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, thread=None) + context = AgentContext(agent=mock_agent, messages=messages, thread=None) expected_response = AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: return expected_response result = await pipeline.execute(context, final_handler) @@ -774,9 +766,7 @@ class TestClassBasedMiddleware: metadata_updates: list[str] = [] class MetadataAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: context.metadata["before"] = True metadata_updates.append("before") await next(context) @@ -786,9 +776,9 @@ class TestClassBasedMiddleware: middleware = MetadataAgentMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: metadata_updates.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -839,9 +829,7 @@ class TestFunctionBasedMiddleware: """Test function-based agent middleware.""" execution_order: list[str] = [] - async def test_agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def test_agent_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("function_before") context.metadata["function_middleware"] = True await next(context) @@ -849,9 +837,9 @@ class TestFunctionBasedMiddleware: pipeline = AgentMiddlewarePipeline(test_agent_middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: execution_order.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -896,25 +884,21 @@ class TestMixedMiddleware: execution_order: list[str] = [] class ClassMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("class_before") await next(context) execution_order.append("class_after") - async def function_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def function_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("function_before") await next(context) execution_order.append("function_after") pipeline = AgentMiddlewarePipeline(ClassMiddleware(), function_middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: execution_order.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -997,25 +981,19 @@ class TestMultipleMiddlewareOrdering: execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("first_before") await next(context) execution_order.append("first_after") class SecondMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("second_before") await next(context) execution_order.append("second_after") class ThirdMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("third_before") await next(context) execution_order.append("third_after") @@ -1023,9 +1001,9 @@ class TestMultipleMiddlewareOrdering: middleware = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()] pipeline = AgentMiddlewarePipeline(*middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: execution_order.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -1136,9 +1114,7 @@ class TestContextContentValidation: """Test that agent context contains expected data.""" class ContextValidationMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Verify context has all expected attributes assert hasattr(context, "agent") assert hasattr(context, "messages") @@ -1161,9 +1137,9 @@ class TestContextContentValidation: middleware = ContextValidationMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: # Verify metadata was set by middleware assert ctx.metadata.get("validated") is True return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -1260,9 +1236,7 @@ class TestStreamingScenarios: streaming_flags: list[bool] = [] class StreamingFlagMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: streaming_flags.append(context.stream) await next(context) @@ -1271,18 +1245,18 @@ class TestStreamingScenarios: messages = [ChatMessage(role="user", text="test")] # Test non-streaming - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: streaming_flags.append(ctx.stream) return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) await pipeline.execute(context, final_handler) # Test streaming - context_stream = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context_stream = AgentContext(agent=mock_agent, messages=messages, stream=True) - async def final_stream_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_stream_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: streaming_flags.append(ctx.stream) yield AgentResponseUpdate(contents=[Content.from_text(text="chunk")]) @@ -1302,9 +1276,7 @@ class TestStreamingScenarios: chunks_processed: list[str] = [] class StreamProcessingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: chunks_processed.append("before_stream") await next(context) chunks_processed.append("after_stream") @@ -1312,9 +1284,9 @@ class TestStreamingScenarios: middleware = StreamProcessingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) - async def final_stream_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_stream_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: chunks_processed.append("stream_start") yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) @@ -1436,7 +1408,7 @@ class FunctionTestArgs(BaseModel): class TestAgentMiddleware(AgentMiddleware): """Test implementation of AgentMiddleware.""" - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: await next(context) @@ -1469,20 +1441,18 @@ class TestMiddlewareExecutionControl: """Test that when agent middleware doesn't call next(), no execution happens.""" class NoNextMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass middleware = NoNextMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) handler_called = False - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: nonlocal handler_called handler_called = True return AgentResponse(messages=[ChatMessage(role="assistant", text="should not execute")]) @@ -1498,20 +1468,18 @@ class TestMiddlewareExecutionControl: """Test that when agent middleware doesn't call next(), no streaming execution happens.""" class NoNextStreamingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass middleware = NoNextStreamingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) handler_called = False - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: nonlocal handler_called handler_called = True @@ -1566,26 +1534,22 @@ class TestMiddlewareExecutionControl: execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("first") # Don't call next() - this should stop the pipeline class SecondMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("second") await next(context) pipeline = AgentMiddlewarePipeline(FirstMiddleware(), SecondMiddleware()) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) handler_called = False - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: nonlocal handler_called handler_called = True return AgentResponse(messages=[ChatMessage(role="assistant", text="should not execute")]) diff --git a/python/packages/core/tests/core/test_middleware_context_result.py b/python/packages/core/tests/core/test_middleware_context_result.py index 64eec8dc3b..29bb2e3aa2 100644 --- a/python/packages/core/tests/core/test_middleware_context_result.py +++ b/python/packages/core/tests/core/test_middleware_context_result.py @@ -17,9 +17,9 @@ from agent_framework import ( ResponseStream, ) from agent_framework._middleware import ( + AgentContext, AgentMiddleware, AgentMiddlewarePipeline, - AgentRunContext, FunctionInvocationContext, FunctionMiddleware, FunctionMiddlewarePipeline, @@ -43,9 +43,7 @@ class TestResultOverrideMiddleware: override_response = AgentResponse(messages=[ChatMessage(role="assistant", text="overridden response")]) class ResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Execute the pipeline first, then override the response await next(context) context.result = override_response @@ -53,11 +51,11 @@ class TestResultOverrideMiddleware: middleware = ResponseOverrideMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) handler_called = False - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: nonlocal handler_called handler_called = True return AgentResponse(messages=[ChatMessage(role="assistant", text="original response")]) @@ -79,9 +77,7 @@ class TestResultOverrideMiddleware: yield AgentResponseUpdate(contents=[Content.from_text(text=" stream")]) class StreamResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Execute the pipeline first, then override the response stream await next(context) context.result = ResponseStream(override_stream()) @@ -89,9 +85,9 @@ class TestResultOverrideMiddleware: middleware = StreamResponseOverrideMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate(contents=[Content.from_text(text="original")]) @@ -145,9 +141,7 @@ class TestResultOverrideMiddleware: mock_chat_client = MockChatClient() class ChatAgentResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Always call next() first to allow execution await next(context) # Then conditionally override based on content @@ -184,9 +178,7 @@ class TestResultOverrideMiddleware: yield AgentResponseUpdate(contents=[Content.from_text(text=" response!")]) class ChatAgentStreamOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Check if we want to override BEFORE calling next to avoid creating unused streams if any("custom stream" in msg.text for msg in context.messages if msg.text): context.result = ResponseStream(custom_stream()) @@ -223,9 +215,7 @@ class TestResultOverrideMiddleware: """Test that when agent middleware conditionally doesn't call next(), no execution happens.""" class ConditionalNoNextMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Only call next() if message contains "execute" if any("execute" in msg.text for msg in context.messages if msg.text): await next(context) @@ -236,14 +226,14 @@ class TestResultOverrideMiddleware: handler_called = False - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: nonlocal handler_called handler_called = True return AgentResponse(messages=[ChatMessage(role="assistant", text="executed response")]) # Test case where next() is NOT called no_execute_messages = [ChatMessage(role="user", text="Don't run this")] - no_execute_context = AgentRunContext(agent=mock_agent, messages=no_execute_messages, stream=False) + no_execute_context = AgentContext(agent=mock_agent, messages=no_execute_messages, stream=False) no_execute_result = await pipeline.execute(no_execute_context, final_handler) # When middleware doesn't call next(), result should be empty AgentResponse @@ -255,7 +245,7 @@ class TestResultOverrideMiddleware: # Test case where next() IS called execute_messages = [ChatMessage(role="user", text="Please execute this")] - execute_context = AgentRunContext(agent=mock_agent, messages=execute_messages, stream=False) + execute_context = AgentContext(agent=mock_agent, messages=execute_messages, stream=False) execute_result = await pipeline.execute(execute_context, final_handler) assert execute_result is not None @@ -318,9 +308,7 @@ class TestResultObservability: observed_responses: list[AgentResponse] = [] class ObservabilityMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Context should be empty before next() assert context.result is None @@ -335,9 +323,9 @@ class TestResultObservability: middleware = ObservabilityMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=False) + context = AgentContext(agent=mock_agent, messages=messages, stream=False) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: return AgentResponse(messages=[ChatMessage(role="assistant", text="executed response")]) result = await pipeline.execute(context, final_handler) @@ -386,9 +374,7 @@ class TestResultObservability: """Test that middleware can override response after observing execution.""" class PostExecutionOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Call next to execute first await next(context) @@ -405,9 +391,9 @@ class TestResultObservability: middleware = PostExecutionOverrideMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=False) + context = AgentContext(agent=mock_agent, messages=messages, stream=False) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: return AgentResponse(messages=[ChatMessage(role="assistant", text="response to modify")]) result = await pipeline.execute(context, final_handler) diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 50146ab008..1bb91137e7 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -6,9 +6,9 @@ from typing import Any import pytest from agent_framework import ( + AgentContext, AgentMiddleware, AgentResponseUpdate, - AgentRunContext, ChatAgent, ChatClientProtocol, ChatContext, @@ -44,9 +44,7 @@ class TestChatAgentClassBasedMiddleware: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") await next(context) execution_order.append(f"{self.name}_after") @@ -122,9 +120,7 @@ class TestChatAgentFunctionBasedMiddleware: execution_order: list[str] = [] class PreTerminationMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("middleware_before") raise MiddlewareTermination # Code after raise is unreachable @@ -153,9 +149,7 @@ class TestChatAgentFunctionBasedMiddleware: execution_order: list[str] = [] class PostTerminationMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("middleware_before") await next(context) execution_order.append("middleware_after") @@ -225,7 +219,7 @@ class TestChatAgentFunctionBasedMiddleware: execution_order: list[str] = [] async def tracking_agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + context: AgentContext, next: Callable[[AgentContext], Awaitable[None]] ) -> None: execution_order.append("agent_function_before") await next(context) @@ -290,9 +284,7 @@ class TestChatAgentStreamingMiddleware: streaming_flags: list[bool] = [] class StreamingTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("middleware_before") streaming_flags.append(context.stream) await next(context) @@ -334,9 +326,7 @@ class TestChatAgentStreamingMiddleware: streaming_flags: list[bool] = [] class FlagTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: streaming_flags.append(context.stream) await next(context) @@ -368,9 +358,7 @@ class TestChatAgentMultipleMiddlewareOrdering: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") await next(context) execution_order.append(f"{self.name}_after") @@ -400,15 +388,13 @@ class TestChatAgentMultipleMiddlewareOrdering: execution_order: list[str] = [] class ClassAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("class_agent_before") await next(context) execution_order.append("class_agent_after") async def function_agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + context: AgentContext, next: Callable[[AgentContext], Awaitable[None]] ) -> None: execution_order.append("function_agent_before") await next(context) @@ -447,15 +433,13 @@ class TestChatAgentMultipleMiddlewareOrdering: execution_order: list[str] = [] class ClassAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("class_agent_before") await next(context) execution_order.append("class_agent_after") async def function_agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + context: AgentContext, next: Callable[[AgentContext], Awaitable[None]] ) -> None: execution_order.append("function_agent_before") await next(context) @@ -646,8 +630,8 @@ class TestChatAgentFunctionMiddlewareWithTools: class TrackingAgentMiddleware(AgentMiddleware): async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: execution_order.append("agent_middleware_before") await next(context) @@ -801,7 +785,7 @@ class TestMiddlewareDynamicRebuild: self.name = name self.execution_log = execution_log - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: self.execution_log.append(f"{self.name}_start") await next(context) self.execution_log.append(f"{self.name}_end") @@ -924,7 +908,7 @@ class TestRunLevelMiddleware: self.name = name self.execution_log = execution_log - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: self.execution_log.append(f"{self.name}_start") await next(context) self.execution_log.append(f"{self.name}_end") @@ -976,9 +960,7 @@ class TestRunLevelMiddleware: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") # Set metadata to pass information to run middleware context.metadata[f"{self.name}_key"] = f"{self.name}_value" @@ -989,9 +971,7 @@ class TestRunLevelMiddleware: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") # Read metadata set by agent middleware for key, value in context.metadata.items(): @@ -1049,9 +1029,7 @@ class TestRunLevelMiddleware: def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") streaming_flags.append(context.stream) await next(context) @@ -1093,9 +1071,7 @@ class TestRunLevelMiddleware: # Agent-level middleware class AgentLevelAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_log.append("agent_level_agent_start") context.metadata["agent_level_agent"] = "processed" await next(context) @@ -1114,9 +1090,7 @@ class TestRunLevelMiddleware: # Run-level middleware class RunLevelAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_log.append("run_level_agent_start") # Verify agent-level middleware metadata is available assert "agent_level_agent" in context.metadata @@ -1218,7 +1192,7 @@ class TestMiddlewareDecoratorLogic: @agent_middleware async def matching_agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + context: AgentContext, next: Callable[[AgentContext], Awaitable[None]] ) -> None: execution_order.append("decorator_type_match_agent") await next(context) @@ -1346,7 +1320,7 @@ class TestMiddlewareDecoratorLogic: execution_order: list[str] = [] # No decorator - async def type_only_agent(context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def type_only_agent(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("type_only_agent") await next(context) @@ -1440,16 +1414,14 @@ class TestMiddlewareDecoratorLogic: class TestChatAgentThreadBehavior: - """Test cases for thread behavior in AgentRunContext across multiple runs.""" + """Test cases for thread behavior in AgentContext across multiple runs.""" - async def test_agent_run_context_thread_behavior_across_multiple_runs(self, chat_client: "MockChatClient") -> None: - """Test that AgentRunContext.thread property behaves correctly across multiple agent runs.""" + async def test_agent_context_thread_behavior_across_multiple_runs(self, chat_client: "MockChatClient") -> None: + """Test that AgentContext.thread property behaves correctly across multiple agent runs.""" thread_states: list[dict[str, Any]] = [] class ThreadTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Capture state before next() call thread_messages = [] if context.thread and context.thread.message_store: @@ -1804,9 +1776,7 @@ class TestChatAgentChatMiddleware: """Test ChatAgent with combined middleware types.""" execution_order: list[str] = [] - async def agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def agent_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("agent_middleware_before") await next(context) execution_order.append("agent_middleware_after") @@ -1844,9 +1814,7 @@ class TestChatAgentChatMiddleware: modified_kwargs: dict[str, Any] = {} @agent_middleware - async def kwargs_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def kwargs_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Capture the original kwargs captured_kwargs.update(context.kwargs) @@ -1897,7 +1865,7 @@ class TestChatAgentChatMiddleware: # class TrackingMiddleware(AgentMiddleware): # async def process( -# self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] +# self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]] # ) -> None: # execution_order.append("before") # await next(context) diff --git a/python/packages/purview/agent_framework_purview/_middleware.py b/python/packages/purview/agent_framework_purview/_middleware.py index 2aabd5a57b..dba7a3f649 100644 --- a/python/packages/purview/agent_framework_purview/_middleware.py +++ b/python/packages/purview/agent_framework_purview/_middleware.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable, Callable -from agent_framework import AgentMiddleware, AgentRunContext, ChatContext, ChatMiddleware, MiddlewareTermination +from agent_framework import AgentContext, AgentMiddleware, ChatContext, ChatMiddleware, MiddlewareTermination from agent_framework._logging import get_logger from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential @@ -47,8 +47,8 @@ class PurviewPolicyMiddleware(AgentMiddleware): async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: # type: ignore[override] resolved_user_id: str | None = None try: diff --git a/python/packages/purview/tests/test_middleware.py b/python/packages/purview/tests/test_middleware.py index 7c9edacd1a..b0aadd8cd5 100644 --- a/python/packages/purview/tests/test_middleware.py +++ b/python/packages/purview/tests/test_middleware.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agent_framework import AgentResponse, AgentRunContext, ChatMessage, MiddlewareTermination +from agent_framework import AgentContext, AgentResponse, ChatMessage, MiddlewareTermination from azure.core.credentials import AccessToken from agent_framework_purview import PurviewPolicyMiddleware, PurviewSettings @@ -49,12 +49,12 @@ class TestPurviewPolicyMiddleware: self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock ) -> None: """Test middleware allows prompt that passes policy check.""" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello, how are you?")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello, how are you?")]) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): next_called = False - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: nonlocal next_called next_called = True ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="I'm good, thanks!")]) @@ -68,12 +68,12 @@ class TestPurviewPolicyMiddleware: self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock ) -> None: """Test middleware blocks prompt that violates policy.""" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Sensitive information")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Sensitive information")]) with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")): next_called = False - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: nonlocal next_called next_called = True @@ -88,7 +88,7 @@ class TestPurviewPolicyMiddleware: async def test_middleware_checks_response(self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock) -> None: """Test middleware checks agent response for policy violations.""" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) call_count = 0 @@ -100,7 +100,7 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse( messages=[ChatMessage(role="assistant", text="Here's some sensitive information")] ) @@ -120,11 +120,11 @@ class TestPurviewPolicyMiddleware: # Set ignore_exceptions to True so AttributeError is caught and logged middleware._settings.ignore_exceptions = True - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = "Some non-standard result" await middleware.process(context, mock_next) @@ -137,11 +137,11 @@ class TestPurviewPolicyMiddleware: """Test middleware passes correct activity type to processor.""" from agent_framework_purview._models import Activity - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process: - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -154,12 +154,12 @@ class TestPurviewPolicyMiddleware: self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock ) -> None: """Test that streaming results skip post-check evaluation.""" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) context.stream = True with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="streaming")]) await middleware.process(context, mock_next) @@ -172,7 +172,7 @@ class TestPurviewPolicyMiddleware: """Test that 402 in pre-check is raised when ignore_payment_required=False.""" from agent_framework_purview._exceptions import PurviewPaymentRequiredError - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) with patch.object( middleware._processor, @@ -180,7 +180,7 @@ class TestPurviewPolicyMiddleware: side_effect=PurviewPaymentRequiredError("Payment required"), ): - async def mock_next(_: AgentRunContext) -> None: + async def mock_next(_: AgentContext) -> None: raise AssertionError("next should not be called") with pytest.raises(PurviewPaymentRequiredError): @@ -192,7 +192,7 @@ class TestPurviewPolicyMiddleware: """Test that 402 in post-check is raised when ignore_payment_required=False.""" from agent_framework_purview._exceptions import PurviewPaymentRequiredError - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) call_count = 0 @@ -205,7 +205,7 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="OK")]) with pytest.raises(PurviewPaymentRequiredError): @@ -217,7 +217,7 @@ class TestPurviewPolicyMiddleware: """Test that post-check exceptions are propagated when ignore_exceptions=False.""" middleware._settings.ignore_exceptions = False - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) call_count = 0 @@ -230,7 +230,7 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="OK")]) with pytest.raises(ValueError, match="Post-check blew up"): @@ -243,13 +243,13 @@ class TestPurviewPolicyMiddleware: # Set ignore_exceptions to True middleware._settings.ignore_exceptions = True - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) with patch.object( middleware._processor, "process_messages", side_effect=Exception("Pre-check error") ) as mock_process: - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -266,7 +266,7 @@ class TestPurviewPolicyMiddleware: # Set ignore_exceptions to True middleware._settings.ignore_exceptions = True - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) call_count = 0 @@ -279,7 +279,7 @@ class TestPurviewPolicyMiddleware: with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -297,7 +297,7 @@ class TestPurviewPolicyMiddleware: mock_agent = MagicMock() mock_agent.name = "test-agent" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) # Mock processor to raise an exception async def mock_process_messages(*args, **kwargs): @@ -321,7 +321,7 @@ class TestPurviewPolicyMiddleware: mock_agent = MagicMock() mock_agent.name = "test-agent" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) # Mock processor to raise an exception async def mock_process_messages(*args, **kwargs): diff --git a/python/samples/concepts/tools/README.md b/python/samples/concepts/tools/README.md index 3a270b25aa..04b7c04569 100644 --- a/python/samples/concepts/tools/README.md +++ b/python/samples/concepts/tools/README.md @@ -34,7 +34,7 @@ sequenceDiagram Note over Agent,AML: Agent Middleware Layer Agent->>AML: run() with middleware param AML->>AML: categorize_middleware() → split by type - AML->>AMP: execute(AgentRunContext) + AML->>AMP: execute(AgentContext) loop Agent Middleware Chain AMP->>AMP: middleware[i].process(context, next) @@ -127,7 +127,7 @@ sequenceDiagram **Entry Point:** `Agent.run(messages, thread, options, middleware)` -**Context Object:** `AgentRunContext` +**Context Object:** `AgentContext` | Field | Type | Description | |-------|------|-------------| diff --git a/python/samples/getting_started/middleware/README.md b/python/samples/getting_started/middleware/README.md index 3d1bd61d27..659e81647a 100644 --- a/python/samples/getting_started/middleware/README.md +++ b/python/samples/getting_started/middleware/README.md @@ -13,7 +13,7 @@ This folder contains examples demonstrating various middleware patterns with the | [`exception_handling_with_middleware.py`](exception_handling_with_middleware.py) | Demonstrates how to use middleware for centralized exception handling in function calls. Shows how to catch exceptions from functions, provide graceful error responses, and override function results when errors occur to provide user-friendly messages. | | [`override_result_with_middleware.py`](override_result_with_middleware.py) | Shows how to use middleware to intercept and modify function results after execution, supporting both regular and streaming agent responses. Demonstrates result filtering, formatting, enhancement, and custom streaming response generation. | | [`shared_state_middleware.py`](shared_state_middleware.py) | Demonstrates how to implement function-based middleware within a class to share state between multiple middleware functions. Shows how middleware can work together by sharing state, including call counting and result enhancement. | -| [`thread_behavior_middleware.py`](thread_behavior_middleware.py) | Demonstrates how middleware can access and track thread state across multiple agent runs. Shows how `AgentRunContext.thread` behaves differently before and after the `next()` call, how conversation history accumulates in threads, and timing of thread message updates. Essential for understanding conversation flow in middleware. | +| [`thread_behavior_middleware.py`](thread_behavior_middleware.py) | Demonstrates how middleware can access and track thread state across multiple agent runs. Shows how `AgentContext.thread` behaves differently before and after the `next()` call, how conversation history accumulates in threads, and timing of thread message updates. Essential for understanding conversation flow in middleware. | | [`agent_and_run_level_middleware.py`](agent_and_run_level_middleware.py) | Explains the difference between agent-level middleware (applied to ALL runs of the agent) and run-level middleware (applied to specific runs only). Shows security validation, performance monitoring, and context-specific middleware patterns. | | [`chat_middleware.py`](chat_middleware.py) | Demonstrates how to use chat middleware to observe and override inputs sent to AI models. Shows how to intercept chat requests, log and modify input messages, and override entire responses before they reach the underlying AI service. | diff --git a/python/samples/getting_started/middleware/agent_and_run_level_middleware.py b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py index 32fd7a2e52..c90dd1936b 100644 --- a/python/samples/getting_started/middleware/agent_and_run_level_middleware.py +++ b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py @@ -7,9 +7,9 @@ from random import randint from typing import Annotated from agent_framework import ( + AgentContext, AgentMiddleware, AgentResponse, - AgentRunContext, FunctionInvocationContext, tool, ) @@ -49,7 +49,7 @@ def get_weather( class SecurityAgentMiddleware(AgentMiddleware): """Agent-level security middleware that validates all requests.""" - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: print("[SecurityMiddleware] Checking security for all requests...") # Check for security violations in the last user message @@ -66,8 +66,8 @@ class SecurityAgentMiddleware(AgentMiddleware): async def performance_monitor_middleware( - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: """Agent-level performance monitoring for all runs.""" print("[PerformanceMonitor] Starting performance monitoring...") @@ -85,7 +85,7 @@ async def performance_monitor_middleware( class HighPriorityMiddleware(AgentMiddleware): """Run-level middleware for high priority requests.""" - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: print("[HighPriority] Processing high priority request with expedited handling...") # Read metadata set by agent-level middleware @@ -101,8 +101,8 @@ class HighPriorityMiddleware(AgentMiddleware): async def debugging_middleware( - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: """Run-level debugging middleware for troubleshooting specific runs.""" print("[Debug] Debug mode enabled for this run") @@ -126,7 +126,7 @@ class CachingMiddleware(AgentMiddleware): def __init__(self) -> None: self.cache: dict[str, AgentResponse] = {} - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Create a simple cache key from the last message last_message = context.messages[-1] if context.messages else None cache_key: str = last_message.text if last_message and last_message.text else "no_message" diff --git a/python/samples/getting_started/middleware/class_based_middleware.py b/python/samples/getting_started/middleware/class_based_middleware.py index 65fa279f19..727c0a2821 100644 --- a/python/samples/getting_started/middleware/class_based_middleware.py +++ b/python/samples/getting_started/middleware/class_based_middleware.py @@ -7,9 +7,9 @@ from random import randint from typing import Annotated from agent_framework import ( + AgentContext, AgentMiddleware, AgentResponse, - AgentRunContext, ChatMessage, FunctionInvocationContext, FunctionMiddleware, @@ -49,8 +49,8 @@ class SecurityAgentMiddleware(AgentMiddleware): async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: # Check for potential security violations in the query # Look at the last user message @@ -61,9 +61,7 @@ class SecurityAgentMiddleware(AgentMiddleware): print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.") # Override the result with warning message context.result = AgentResponse( - messages=[ - ChatMessage("assistant", ["Detected sensitive information, the request is blocked."]) - ] + messages=[ChatMessage("assistant", ["Detected sensitive information, the request is blocked."])] ) # Simply don't call next() to prevent execution return diff --git a/python/samples/getting_started/middleware/decorator_middleware.py b/python/samples/getting_started/middleware/decorator_middleware.py index f16407918c..2ea1196bc3 100644 --- a/python/samples/getting_started/middleware/decorator_middleware.py +++ b/python/samples/getting_started/middleware/decorator_middleware.py @@ -20,7 +20,7 @@ to explicitly mark middleware functions without requiring type annotations. The framework supports the following middleware detection scenarios: 1. Both decorator and parameter type specified: - - Validates that they match (e.g., @agent_middleware with AgentRunContext) + - Validates that they match (e.g., @agent_middleware with AgentContext) - Throws exception if they don't match for safety 2. Only decorator specified: @@ -28,7 +28,7 @@ The framework supports the following middleware detection scenarios: - No type annotations needed - framework handles context types automatically 3. Only parameter type specified: - - Uses type annotations (AgentRunContext, FunctionInvocationContext) for detection + - Uses type annotations (AgentContext, FunctionInvocationContext) for detection 4. Neither decorator nor parameter type specified: - Throws exception requiring either decorator or type annotation diff --git a/python/samples/getting_started/middleware/function_based_middleware.py b/python/samples/getting_started/middleware/function_based_middleware.py index 21defef491..1616aa5fc3 100644 --- a/python/samples/getting_started/middleware/function_based_middleware.py +++ b/python/samples/getting_started/middleware/function_based_middleware.py @@ -7,7 +7,7 @@ from random import randint from typing import Annotated from agent_framework import ( - AgentRunContext, + AgentContext, FunctionInvocationContext, tool, ) @@ -42,8 +42,8 @@ def get_weather( async def security_agent_middleware( - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: """Agent middleware that checks for security violations.""" # Check for potential security violations in the query diff --git a/python/samples/getting_started/middleware/middleware_termination.py b/python/samples/getting_started/middleware/middleware_termination.py index ea32bc606b..69fa5766d9 100644 --- a/python/samples/getting_started/middleware/middleware_termination.py +++ b/python/samples/getting_started/middleware/middleware_termination.py @@ -6,9 +6,9 @@ from random import randint from typing import Annotated from agent_framework import ( + AgentContext, AgentMiddleware, AgentResponse, - AgentRunContext, ChatMessage, tool, ) @@ -47,8 +47,8 @@ class PreTerminationMiddleware(AgentMiddleware): async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: # Check if the user message contains any blocked words last_message = context.messages[-1] if context.messages else None @@ -87,8 +87,8 @@ class PostTerminationMiddleware(AgentMiddleware): async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})") diff --git a/python/samples/getting_started/middleware/override_result_with_middleware.py b/python/samples/getting_started/middleware/override_result_with_middleware.py index 06351d1803..8aef8f8e3b 100644 --- a/python/samples/getting_started/middleware/override_result_with_middleware.py +++ b/python/samples/getting_started/middleware/override_result_with_middleware.py @@ -7,9 +7,9 @@ from random import randint from typing import Annotated from agent_framework import ( + AgentContext, AgentResponse, AgentResponseUpdate, - AgentRunContext, ChatContext, ChatMessage, ChatResponse, @@ -104,9 +104,7 @@ async def validate_weather_middleware(context: ChatContext, next: Callable[[Chat context.result.messages.append(ChatMessage(role=Role.ASSISTANT, text=validation_note)) -async def agent_cleanup_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] -) -> None: +async def agent_cleanup_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: """Agent middleware that validates chat middleware effects and cleans the result.""" await next(context) diff --git a/python/samples/getting_started/middleware/thread_behavior_middleware.py b/python/samples/getting_started/middleware/thread_behavior_middleware.py index 93f72d567a..0665d23720 100644 --- a/python/samples/getting_started/middleware/thread_behavior_middleware.py +++ b/python/samples/getting_started/middleware/thread_behavior_middleware.py @@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable from typing import Annotated from agent_framework import ( - AgentRunContext, + AgentContext, ChatMessageStore, tool, ) @@ -19,7 +19,7 @@ Thread Behavior MiddlewareTypes Example This sample demonstrates how middleware can access and track thread state across multiple agent runs. The example shows: -- How AgentRunContext.thread property behaves across multiple runs +- How AgentContext.thread property behaves across multiple runs - How middleware can access conversation history through the thread - The timing of when thread messages are populated (before vs after next() call) - How to track thread state changes across runs @@ -45,8 +45,8 @@ def get_weather( async def thread_tracking_middleware( - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: """MiddlewareTypes that tracks and logs thread behavior across runs.""" thread_messages = []