mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fixed middleware and multimodal input samples (#4022)
* Fix streaming branch in weather override middleware sample The streaming branch of weather_override_middleware only prefixed the original weather data via a transform hook instead of replacing the content with the 'perfect weather' override like the non-streaming branch does. Replace with a new ResponseStream that yields the override content as ChatResponseUpdate chunks. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fixed exception handling middleware sample * Fixed runtime context delegation middleware example * Fixed multimodal input examples * Small update --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
df58775d64
commit
2dd731f90f
@@ -1024,6 +1024,11 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
await self._async_exit_stack.enter_async_context(mcp_server)
|
||||
final_tools.extend(mcp_server.functions)
|
||||
|
||||
# Merge runtime kwargs into additional_function_arguments so they're available
|
||||
# in function middleware context and tool invocation.
|
||||
existing_additional_args = opts.pop("additional_function_arguments", None) or {}
|
||||
additional_function_arguments = {**kwargs, **existing_additional_args}
|
||||
|
||||
# Build options dict from run() options merged with provided options
|
||||
run_opts: dict[str, Any] = {
|
||||
"model_id": opts.pop("model_id", None),
|
||||
@@ -1031,7 +1036,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
if active_session
|
||||
else opts.pop("conversation_id", None),
|
||||
"allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None),
|
||||
"additional_function_arguments": opts.pop("additional_function_arguments", None),
|
||||
"additional_function_arguments": additional_function_arguments or None,
|
||||
"frequency_penalty": opts.pop("frequency_penalty", None),
|
||||
"logit_bias": opts.pop("logit_bias", None),
|
||||
"max_tokens": opts.pop("max_tokens", None),
|
||||
|
||||
@@ -2871,12 +2871,7 @@ class _ChatOptionsBase(TypedDict, total=False):
|
||||
presence_penalty: float
|
||||
|
||||
# Tool configuration (forward reference to avoid circular import)
|
||||
tools: (
|
||||
ToolTypes
|
||||
| Callable[..., Any]
|
||||
| Sequence[ToolTypes | Callable[..., Any]]
|
||||
| None
|
||||
)
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"]
|
||||
allow_multiple_tool_calls: bool
|
||||
|
||||
|
||||
@@ -769,6 +769,179 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
assert modified_kwargs["new_param"] == "added_by_middleware"
|
||||
assert modified_kwargs["custom_param"] == "test_value"
|
||||
|
||||
async def test_run_kwargs_available_in_function_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
|
||||
"""Test that kwargs passed directly to agent.run() appear in FunctionInvocationContext.kwargs,
|
||||
including complex nested values like dicts."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@function_middleware
|
||||
async def capture_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
captured_kwargs.update(context.kwargs)
|
||||
await call_next()
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
|
||||
]
|
||||
|
||||
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
|
||||
|
||||
session_metadata = {"tenant": "acme-corp", "region": "us-west"}
|
||||
await agent.run(
|
||||
[Message(role="user", text="Get weather")],
|
||||
user_id="user-456",
|
||||
session_metadata=session_metadata,
|
||||
)
|
||||
|
||||
assert "user_id" in captured_kwargs, f"Expected 'user_id' in kwargs: {captured_kwargs}"
|
||||
assert captured_kwargs["user_id"] == "user-456"
|
||||
assert captured_kwargs["session_metadata"] == {"tenant": "acme-corp", "region": "us-west"}
|
||||
|
||||
async def test_run_kwargs_merged_with_additional_function_arguments(
|
||||
self, chat_client_base: "MockBaseChatClient"
|
||||
) -> None:
|
||||
"""Test that explicit additional_function_arguments in options take precedence over run kwargs."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@function_middleware
|
||||
async def capture_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
captured_kwargs.update(context.kwargs)
|
||||
await call_next()
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
|
||||
]
|
||||
|
||||
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
|
||||
|
||||
await agent.run(
|
||||
[Message(role="user", text="Get weather")],
|
||||
# This kwarg should be overridden by additional_function_arguments
|
||||
user_id="from-kwargs",
|
||||
tenant_id="from-kwargs",
|
||||
options={
|
||||
"additional_function_arguments": {
|
||||
"user_id": "from-options",
|
||||
"extra_key": "only-in-options",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# additional_function_arguments takes precedence for overlapping keys
|
||||
assert captured_kwargs["user_id"] == "from-options"
|
||||
# Non-overlapping kwargs from run() still come through
|
||||
assert captured_kwargs["tenant_id"] == "from-kwargs"
|
||||
# Keys only in additional_function_arguments are present
|
||||
assert captured_kwargs["extra_key"] == "only-in-options"
|
||||
|
||||
async def test_run_kwargs_consistent_across_multiple_tool_calls(
|
||||
self, chat_client_base: "MockBaseChatClient"
|
||||
) -> None:
|
||||
"""Test that kwargs are consistent across multiple tool invocations in a single run."""
|
||||
invocation_kwargs: list[dict[str, Any]] = []
|
||||
|
||||
@function_middleware
|
||||
async def capture_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
invocation_kwargs.append(dict(context.kwargs))
|
||||
await call_next()
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
|
||||
),
|
||||
Content.from_function_call(
|
||||
call_id="call_2", name="sample_tool_function", arguments='{"location": "Portland"}'
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
|
||||
]
|
||||
|
||||
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
|
||||
|
||||
await agent.run(
|
||||
[Message(role="user", text="Get weather for both cities")],
|
||||
user_id="user-456",
|
||||
request_id="req-001",
|
||||
)
|
||||
|
||||
assert len(invocation_kwargs) == 2
|
||||
for kw in invocation_kwargs:
|
||||
assert kw["user_id"] == "user-456"
|
||||
assert kw["request_id"] == "req-001"
|
||||
|
||||
async def test_run_without_kwargs_produces_empty_context_kwargs(
|
||||
self, chat_client_base: "MockBaseChatClient"
|
||||
) -> None:
|
||||
"""Test that when no kwargs are passed to run(), FunctionInvocationContext.kwargs is empty."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@function_middleware
|
||||
async def capture_middleware(
|
||||
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
captured_kwargs.update(context.kwargs)
|
||||
await call_next()
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
|
||||
]
|
||||
|
||||
agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])
|
||||
|
||||
await agent.run([Message(role="user", text="Get weather")])
|
||||
|
||||
# No runtime kwargs should be present
|
||||
assert "user_id" not in captured_kwargs
|
||||
|
||||
|
||||
class TestMiddlewareDynamicRebuild:
|
||||
"""Test cases for dynamic middleware pipeline rebuilding with Agent."""
|
||||
|
||||
Reference in New Issue
Block a user