Python: Fixed middleware and multimodal input samples (#4022)

* Fix streaming branch in weather override middleware sample

The streaming branch of weather_override_middleware only prefixed the
original weather data via a transform hook instead of replacing the
content with the 'perfect weather' override like the non-streaming
branch does. Replace with a new ResponseStream that yields the override
content as ChatResponseUpdate chunks.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fixed exception handling middleware sample

* Fixed runtime context delegation middleware example

* Fixed multimodal input examples

* Small update

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Dmytro Struk
2026-02-17 19:49:33 -05:00
committed by GitHub
Unverified
parent df58775d64
commit 2dd731f90f
8 changed files with 204 additions and 33 deletions
@@ -1024,6 +1024,11 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
await self._async_exit_stack.enter_async_context(mcp_server)
final_tools.extend(mcp_server.functions)
# Merge runtime kwargs into additional_function_arguments so they're available
# in function middleware context and tool invocation.
existing_additional_args = opts.pop("additional_function_arguments", None) or {}
additional_function_arguments = {**kwargs, **existing_additional_args}
# Build options dict from run() options merged with provided options
run_opts: dict[str, Any] = {
"model_id": opts.pop("model_id", None),
@@ -1031,7 +1036,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
if active_session
else opts.pop("conversation_id", None),
"allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None),
"additional_function_arguments": opts.pop("additional_function_arguments", None),
"additional_function_arguments": additional_function_arguments or None,
"frequency_penalty": opts.pop("frequency_penalty", None),
"logit_bias": opts.pop("logit_bias", None),
"max_tokens": opts.pop("max_tokens", None),
@@ -2871,12 +2871,7 @@ class _ChatOptionsBase(TypedDict, total=False):
presence_penalty: float
# Tool configuration (forward reference to avoid circular import)
tools: (
ToolTypes
| Callable[..., Any]
| Sequence[ToolTypes | Callable[..., Any]]
| None
)
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None
tool_choice: ToolMode | Literal["auto", "required", "none"]
allow_multiple_tool_calls: bool
@@ -769,6 +769,179 @@ class TestChatAgentFunctionMiddlewareWithTools:
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value"
async def test_run_kwargs_available_in_function_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that kwargs passed directly to agent.run() appear in FunctionInvocationContext.kwargs,
including complex nested values like dicts."""
captured_kwargs: dict[str, Any] = {}
@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await call_next()
chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
session_metadata = {"tenant": "acme-corp", "region": "us-west"}
await agent.run(
[Message(role="user", text="Get weather")],
user_id="user-456",
session_metadata=session_metadata,
)
assert "user_id" in captured_kwargs, f"Expected 'user_id' in kwargs: {captured_kwargs}"
assert captured_kwargs["user_id"] == "user-456"
assert captured_kwargs["session_metadata"] == {"tenant": "acme-corp", "region": "us-west"}
async def test_run_kwargs_merged_with_additional_function_arguments(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that explicit additional_function_arguments in options take precedence over run kwargs."""
captured_kwargs: dict[str, Any] = {}
@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await call_next()
chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
await agent.run(
[Message(role="user", text="Get weather")],
# This kwarg should be overridden by additional_function_arguments
user_id="from-kwargs",
tenant_id="from-kwargs",
options={
"additional_function_arguments": {
"user_id": "from-options",
"extra_key": "only-in-options",
}
},
)
# additional_function_arguments takes precedence for overlapping keys
assert captured_kwargs["user_id"] == "from-options"
# Non-overlapping kwargs from run() still come through
assert captured_kwargs["tenant_id"] == "from-kwargs"
# Keys only in additional_function_arguments are present
assert captured_kwargs["extra_key"] == "only-in-options"
async def test_run_kwargs_consistent_across_multiple_tool_calls(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that kwargs are consistent across multiple tool invocations in a single run."""
invocation_kwargs: list[dict[str, Any]] = []
@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
invocation_kwargs.append(dict(context.kwargs))
await call_next()
chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
),
Content.from_function_call(
call_id="call_2", name="sample_tool_function", arguments='{"location": "Portland"}'
),
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
await agent.run(
[Message(role="user", text="Get weather for both cities")],
user_id="user-456",
request_id="req-001",
)
assert len(invocation_kwargs) == 2
for kw in invocation_kwargs:
assert kw["user_id"] == "user-456"
assert kw["request_id"] == "req-001"
async def test_run_without_kwargs_produces_empty_context_kwargs(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that when no kwargs are passed to run(), FunctionInvocationContext.kwargs is empty."""
captured_kwargs: dict[str, Any] = {}
@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await call_next()
chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
await agent.run([Message(role="user", text="Get weather")])
# No runtime kwargs should be present
assert "user_id" not in captured_kwargs
class TestMiddlewareDynamicRebuild:
"""Test cases for dynamic middleware pipeline rebuilding with Agent."""
@@ -47,8 +47,8 @@ async def exception_handling_middleware(
print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}")
# Override function result to provide custom message in response.
context.result = (
"Request Timeout: The data service is taking longer than expected to respond.",
"Respond with message - 'Sorry for the inconvenience, please try again later.'",
"Request Timeout: The data service is taking longer than expected to respond."
"Respond with message - 'Sorry for the inconvenience, please try again later.'"
)
@@ -2,7 +2,7 @@
import asyncio
import re
from collections.abc import Awaitable, Callable
from collections.abc import AsyncIterable, Awaitable, Callable
from random import randint
from typing import Annotated
@@ -13,9 +13,9 @@ from agent_framework import (
ChatContext,
ChatResponse,
ChatResponseUpdate,
Content,
Message,
ResponseStream,
Role,
tool,
)
from agent_framework.openai import OpenAIResponsesClient
@@ -66,22 +66,20 @@ async def weather_override_middleware(context: ChatContext, call_next: Callable[
]
if context.stream and isinstance(context.result, ResponseStream):
index = {"value": 0}
def _update_hook(update: ChatResponseUpdate) -> ChatResponseUpdate:
for content in update.contents or []:
if not content.text:
continue
content.text = f"Weather Advisory: [{index['value']}] {content.text}"
index["value"] += 1
return update
async def _override_stream() -> AsyncIterable[ChatResponseUpdate]:
for i, chunk_text in enumerate(chunks):
yield ChatResponseUpdate(
contents=[Content.from_text(text=f"Weather Advisory: [{i}] {chunk_text}")],
role="assistant",
)
context.result.with_transform_hook(_update_hook)
context.result = ResponseStream(_override_stream())
else:
# For non-streaming: just replace with a new message
current_text = context.result.text if isinstance(context.result, ChatResponse) else ""
custom_message = f"Weather Advisory: [0] {''.join(chunks)} Original message was: {current_text}"
context.result = ChatResponse(messages=[Message(role=Role.ASSISTANT, text=custom_message)])
context.result = ChatResponse(messages=[Message(role="assistant", text=custom_message)])
async def validate_weather_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
@@ -96,12 +94,12 @@ async def validate_weather_middleware(context: ChatContext, call_next: Callable[
if context.stream and isinstance(context.result, ResponseStream):
def _append_validation_note(response: ChatResponse) -> ChatResponse:
response.messages.append(Message(role=Role.ASSISTANT, text=validation_note))
response.messages.append(Message(role="assistant", text=validation_note))
return response
context.result.with_finalizer(_append_validation_note)
elif isinstance(context.result, ChatResponse):
context.result.messages.append(Message(role=Role.ASSISTANT, text=validation_note))
context.result.messages.append(Message(role="assistant", text=validation_note))
async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
@@ -154,7 +152,7 @@ async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[]
if not found_validation:
raise RuntimeError("Expected validation note not found in agent response.")
cleaned_messages.append(Message(role=Role.ASSISTANT, text=" Agent: OK"))
cleaned_messages.append(Message(role="assistant", text=" Agent: OK"))
response.messages = cleaned_messages
return response
@@ -9,7 +9,7 @@ from azure.identity import AzureCliCredential
def create_sample_image() -> str:
"""Create a simple 1x1 pixel PNG image for testing."""
# This is a tiny red pixel in PNG format
# This is a tiny yellow pixel in PNG format
png_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
return f"data:image/png;base64,{png_data}"
@@ -32,7 +32,7 @@ async def test_image() -> None:
],
)
response = await client.get_response(message)
response = await client.get_response([message])
print(f"Image Response: {response}")
@@ -18,7 +18,7 @@ def load_sample_pdf() -> bytes:
def create_sample_image() -> str:
"""Create a simple 1x1 pixel PNG image for testing."""
# This is a tiny red pixel in PNG format
# This is a tiny yellow pixel in PNG format
png_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
return f"data:image/png;base64,{png_data}"
@@ -41,7 +41,7 @@ async def test_image() -> None:
],
)
response = await client.get_response(message)
response = await client.get_response([message])
print(f"Image Response: {response}")
@@ -62,7 +62,7 @@ async def test_pdf() -> None:
],
)
response = await client.get_response(message)
response = await client.get_response([message])
print(f"PDF Response: {response}")
@@ -19,7 +19,7 @@ def load_sample_pdf() -> bytes:
def create_sample_image() -> str:
"""Create a simple 1x1 pixel PNG image for testing."""
# This is a tiny red pixel in PNG format
# This is a tiny yellow pixel in PNG format
png_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
return f"data:image/png;base64,{png_data}"
@@ -53,7 +53,7 @@ async def test_image() -> None:
],
)
response = await client.get_response(message)
response = await client.get_response([message])
print(f"Image Response: {response}")
@@ -70,7 +70,7 @@ async def test_audio() -> None:
],
)
response = await client.get_response(message)
response = await client.get_response([message])
print(f"Audio Response: {response}")
@@ -89,7 +89,7 @@ async def test_pdf() -> None:
],
)
response = await client.get_response(message)
response = await client.get_response([message])
print(f"PDF Response: {response}")