Python: Fixed middleware and multimodal input samples (#4022)

* Fix streaming branch in weather override middleware sample

The streaming branch of weather_override_middleware only prefixed the
original weather data via a transform hook instead of replacing the
content with the 'perfect weather' override like the non-streaming
branch does. Replace with a new ResponseStream that yields the override
content as ChatResponseUpdate chunks.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fixed exception handling middleware sample

* Fixed runtime context delegation middleware example

* Fixed multimodal input examples

* Small update

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Dmytro Struk
2026-02-17 19:49:33 -05:00
committed by GitHub
Unverified
parent df58775d64
commit 2dd731f90f
8 changed files with 204 additions and 33 deletions
@@ -1024,6 +1024,11 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
await self._async_exit_stack.enter_async_context(mcp_server)
final_tools.extend(mcp_server.functions)
# Merge runtime kwargs into additional_function_arguments so they're available
# in function middleware context and tool invocation.
existing_additional_args = opts.pop("additional_function_arguments", None) or {}
additional_function_arguments = {**kwargs, **existing_additional_args}
# Build options dict from run() options merged with provided options
run_opts: dict[str, Any] = {
"model_id": opts.pop("model_id", None),
@@ -1031,7 +1036,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
if active_session
else opts.pop("conversation_id", None),
"allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None),
"additional_function_arguments": opts.pop("additional_function_arguments", None),
"additional_function_arguments": additional_function_arguments or None,
"frequency_penalty": opts.pop("frequency_penalty", None),
"logit_bias": opts.pop("logit_bias", None),
"max_tokens": opts.pop("max_tokens", None),
@@ -2871,12 +2871,7 @@ class _ChatOptionsBase(TypedDict, total=False):
presence_penalty: float
# Tool configuration (forward reference to avoid circular import)
tools: (
ToolTypes
| Callable[..., Any]
| Sequence[ToolTypes | Callable[..., Any]]
| None
)
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None
tool_choice: ToolMode | Literal["auto", "required", "none"]
allow_multiple_tool_calls: bool
@@ -769,6 +769,179 @@ class TestChatAgentFunctionMiddlewareWithTools:
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value"
async def test_run_kwargs_available_in_function_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that kwargs passed directly to agent.run() appear in FunctionInvocationContext.kwargs,
including complex nested values like dicts."""
captured_kwargs: dict[str, Any] = {}
@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await call_next()
chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
session_metadata = {"tenant": "acme-corp", "region": "us-west"}
await agent.run(
[Message(role="user", text="Get weather")],
user_id="user-456",
session_metadata=session_metadata,
)
assert "user_id" in captured_kwargs, f"Expected 'user_id' in kwargs: {captured_kwargs}"
assert captured_kwargs["user_id"] == "user-456"
assert captured_kwargs["session_metadata"] == {"tenant": "acme-corp", "region": "us-west"}
async def test_run_kwargs_merged_with_additional_function_arguments(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that explicit additional_function_arguments in options take precedence over run kwargs."""
captured_kwargs: dict[str, Any] = {}
@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await call_next()
chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
await agent.run(
[Message(role="user", text="Get weather")],
# This kwarg should be overridden by additional_function_arguments
user_id="from-kwargs",
tenant_id="from-kwargs",
options={
"additional_function_arguments": {
"user_id": "from-options",
"extra_key": "only-in-options",
}
},
)
# additional_function_arguments takes precedence for overlapping keys
assert captured_kwargs["user_id"] == "from-options"
# Non-overlapping kwargs from run() still come through
assert captured_kwargs["tenant_id"] == "from-kwargs"
# Keys only in additional_function_arguments are present
assert captured_kwargs["extra_key"] == "only-in-options"
async def test_run_kwargs_consistent_across_multiple_tool_calls(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that kwargs are consistent across multiple tool invocations in a single run."""
invocation_kwargs: list[dict[str, Any]] = []
@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
invocation_kwargs.append(dict(context.kwargs))
await call_next()
chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
),
Content.from_function_call(
call_id="call_2", name="sample_tool_function", arguments='{"location": "Portland"}'
),
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
await agent.run(
[Message(role="user", text="Get weather for both cities")],
user_id="user-456",
request_id="req-001",
)
assert len(invocation_kwargs) == 2
for kw in invocation_kwargs:
assert kw["user_id"] == "user-456"
assert kw["request_id"] == "req-001"
async def test_run_without_kwargs_produces_empty_context_kwargs(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that when no kwargs are passed to run(), FunctionInvocationContext.kwargs is empty."""
captured_kwargs: dict[str, Any] = {}
@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await call_next()
chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
await agent.run([Message(role="user", text="Get weather")])
# No runtime kwargs should be present
assert "user_id" not in captured_kwargs
class TestMiddlewareDynamicRebuild:
"""Test cases for dynamic middleware pipeline rebuilding with Agent."""