mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Strip reserved kwargs in AgentExecutor to prevent duplicate-argument TypeError (#4298)
* Python: Strip reserved kwargs in AgentExecutor to prevent collision (#4295) workflow.run(session=...) passed 'session' through to agent.run() via **run_kwargs while AgentExecutor also passes session=self._session explicitly, causing TypeError: got multiple values for keyword argument. _prepare_agent_run_args now strips reserved params (session, stream, messages) from run_kwargs and logs a warning when they are present. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address PR review feedback for #4295 - Use _RESERVED_RUN_PARAMS constant in stripping loop instead of hardcoded tuple to maintain single source of truth - Trim frozenset to only stripped keys (session, stream, messages); options and additional_function_arguments have separate merge logic - Fix caplog type annotation to use TYPE_CHECKING pattern - Assert options return value in reserved-kwarg stripping test - Add test for multiple reserved kwargs supplied simultaneously - Add integration test for messages= kwarg via workflow.run() Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
97b24990d9
commit
823e714ccf
@@ -415,6 +415,10 @@ class AgentExecutor(Executor):
|
||||
|
||||
return response
|
||||
|
||||
# Parameters that are explicitly passed to agent.run() by AgentExecutor
|
||||
# and must not appear in **run_kwargs to avoid TypeError from duplicate values.
|
||||
_RESERVED_RUN_PARAMS: frozenset[str] = frozenset({"session", "stream", "messages"})
|
||||
|
||||
@staticmethod
|
||||
def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]:
|
||||
"""Prepare kwargs and options for agent.run(), avoiding duplicate option passing.
|
||||
@@ -423,8 +427,23 @@ class AgentExecutor(Executor):
|
||||
`options.additional_function_arguments`. If workflow kwargs include an
|
||||
`options` key, merge it into the final options object and remove it from
|
||||
kwargs before spreading `**run_kwargs`.
|
||||
|
||||
Reserved parameters (session, stream, messages) that are explicitly
|
||||
managed by AgentExecutor are stripped from run_kwargs to prevent
|
||||
``TypeError: got multiple values for keyword argument`` collisions.
|
||||
"""
|
||||
run_kwargs = dict(raw_run_kwargs)
|
||||
|
||||
# Strip reserved params that AgentExecutor passes explicitly to agent.run().
|
||||
for key in AgentExecutor._RESERVED_RUN_PARAMS:
|
||||
if key in run_kwargs:
|
||||
logger.warning(
|
||||
"Workflow kwarg '%s' is reserved by AgentExecutor and will be ignored. "
|
||||
"Remove it from workflow.run() kwargs to silence this warning.",
|
||||
key,
|
||||
)
|
||||
run_kwargs.pop(key)
|
||||
|
||||
options_from_workflow = run_kwargs.pop("options", None)
|
||||
workflow_additional_args = run_kwargs.pop("additional_function_arguments", None)
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterable, Awaitable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
@@ -18,6 +21,9 @@ from agent_framework._workflows._agent_executor import AgentExecutorResponse
|
||||
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _pytest.logging import LogCaptureFixture
|
||||
|
||||
|
||||
class _CountingAgent(BaseAgent):
|
||||
"""Agent that echoes messages with a counter to verify session state persistence."""
|
||||
@@ -251,3 +257,85 @@ async def test_agent_executor_save_and_restore_state_directly() -> None:
|
||||
# Verify session was restored with correct session_id
|
||||
restored_session = new_executor._session # type: ignore[reportPrivateUsage]
|
||||
assert restored_session.session_id == session.session_id
|
||||
|
||||
|
||||
async def test_agent_executor_run_with_session_kwarg_does_not_raise() -> None:
|
||||
"""Passing session= via workflow.run() should not cause a duplicate-keyword TypeError (#4295)."""
|
||||
agent = _CountingAgent(id="session_kwarg_agent", name="SessionKwargAgent")
|
||||
executor = AgentExecutor(agent, id="session_kwarg_exec")
|
||||
workflow = SequentialBuilder(participants=[executor]).build()
|
||||
|
||||
# This previously raised: TypeError: run() got multiple values for keyword argument 'session'
|
||||
result = await workflow.run("hello", session="user-supplied-value")
|
||||
assert result is not None
|
||||
assert agent.call_count == 1
|
||||
|
||||
|
||||
async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -> None:
|
||||
"""Passing stream= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
|
||||
agent = _CountingAgent(id="stream_kwarg_agent", name="StreamKwargAgent")
|
||||
executor = AgentExecutor(agent, id="stream_kwarg_exec")
|
||||
workflow = SequentialBuilder(participants=[executor]).build()
|
||||
|
||||
# stream=True at workflow level triggers streaming mode (returns async iterable)
|
||||
events = []
|
||||
async for event in workflow.run("hello", stream=True):
|
||||
events.append(event)
|
||||
assert len(events) > 0
|
||||
assert agent.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"])
|
||||
async def test_prepare_agent_run_args_strips_reserved_kwargs(
|
||||
reserved_kwarg: str, caplog: "LogCaptureFixture"
|
||||
) -> None:
|
||||
"""_prepare_agent_run_args must remove reserved kwargs and log a warning."""
|
||||
raw = {reserved_kwarg: "should-be-stripped", "custom_key": "keep-me"}
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
|
||||
|
||||
assert reserved_kwarg not in run_kwargs
|
||||
assert "custom_key" in run_kwargs
|
||||
assert options is not None
|
||||
assert options["additional_function_arguments"]["custom_key"] == "keep-me"
|
||||
assert any(reserved_kwarg in record.message for record in caplog.records)
|
||||
|
||||
|
||||
async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None:
|
||||
"""Non-reserved workflow kwargs should pass through unchanged."""
|
||||
raw = {"custom_param": "value", "another": 42}
|
||||
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
|
||||
assert run_kwargs["custom_param"] == "value"
|
||||
assert run_kwargs["another"] == 42
|
||||
|
||||
|
||||
async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once(
|
||||
caplog: "LogCaptureFixture",
|
||||
) -> None:
|
||||
"""All reserved kwargs should be stripped when supplied together, each emitting a warning."""
|
||||
raw = {"session": "x", "stream": True, "messages": [], "custom": 1}
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
|
||||
|
||||
assert "session" not in run_kwargs
|
||||
assert "stream" not in run_kwargs
|
||||
assert "messages" not in run_kwargs
|
||||
assert run_kwargs["custom"] == 1
|
||||
assert options is not None
|
||||
assert options["additional_function_arguments"]["custom"] == 1
|
||||
|
||||
warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()}
|
||||
assert warned_keys == {"session", "stream", "messages"}
|
||||
|
||||
|
||||
async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None:
|
||||
"""Passing messages= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
|
||||
agent = _CountingAgent(id="messages_kwarg_agent", name="MessagesKwargAgent")
|
||||
executor = AgentExecutor(agent, id="messages_kwarg_exec")
|
||||
workflow = SequentialBuilder(participants=[executor]).build()
|
||||
|
||||
result = await workflow.run("hello", messages=["stale"])
|
||||
assert result is not None
|
||||
assert agent.call_count == 1
|
||||
|
||||
Reference in New Issue
Block a user