[BREAKING] Python: Refactor workflows kwargs (#5010)

* Refactor workflows kwargs usage

* Update sample

* Add tests

* Update samples

* Fix formatting

* Comments

* Comments 2

* Comments 3

* Fix test and typing
This commit is contained in:
Tao Chen
2026-04-02 02:40:39 -07:00
committed by GitHub
Unverified
parent fd253c0b0e
commit 62595b233f
10 changed files with 1092 additions and 826 deletions
@@ -6,7 +6,7 @@ import json
import logging
import sys
import uuid
from collections.abc import AsyncIterable, Awaitable, Sequence
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
@@ -152,7 +152,8 @@ class WorkflowAgent(BaseAgent):
session: AgentSession | None = None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ...
@overload
@@ -164,7 +165,8 @@ class WorkflowAgent(BaseAgent):
session: AgentSession | None = None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> AgentResponse: ...
def run(
@@ -175,7 +177,8 @@ class WorkflowAgent(BaseAgent):
session: AgentSession | None = None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> ResponseStream[AgentResponseUpdate, AgentResponse] | Awaitable[AgentResponse]:
"""Get a response from the workflow agent.
@@ -192,8 +195,12 @@ class WorkflowAgent(BaseAgent):
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
used to load and restore the checkpoint. When provided without checkpoint_id,
enables checkpointing for this run.
**kwargs: Additional keyword arguments passed through to underlying workflow
and tool functions.
function_invocation_kwargs: Keyword arguments forwarded to tool invocations in
subagents. Either a mapping of agent name/executor id to kwargs, or a flat
mapping of kwargs for all tool invocations.
client_kwargs: Keyword arguments forwarded to chat client calls in
subagents. Either a mapping of agent name/executor id to kwargs, or a flat
mapping of kwargs for all chat client calls.
Returns:
When stream=True: An AsyncIterable[AgentResponseUpdate] for streaming updates.
@@ -208,10 +215,26 @@ class WorkflowAgent(BaseAgent):
response_id = str(uuid.uuid4())
if stream:
return ResponseStream(
self._run_stream_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs),
self._run_stream_impl(
messages,
response_id,
session,
checkpoint_id,
checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
),
finalizer=AgentResponse.from_updates,
)
return self._run_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs)
return self._run_impl(
messages,
response_id,
session,
checkpoint_id,
checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
async def _run_impl(
self,
@@ -220,7 +243,8 @@ class WorkflowAgent(BaseAgent):
session: AgentSession | None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> AgentResponse:
"""Internal implementation of non-streaming execution.
@@ -230,8 +254,8 @@ class WorkflowAgent(BaseAgent):
session: The agent session for conversation context.
checkpoint_id: ID of checkpoint to restore from.
checkpoint_storage: Runtime checkpoint storage.
**kwargs: Additional keyword arguments passed through to the underlying
workflow and tool functions.
function_invocation_kwargs: Optional kwargs for tool invocations.
client_kwargs: Optional kwargs for chat client calls.
Returns:
An AgentResponse representing the workflow execution results.
@@ -264,7 +288,12 @@ class WorkflowAgent(BaseAgent):
output_events: list[WorkflowEvent[Any]] = []
async for event in self._run_core(
session_messages, checkpoint_id, checkpoint_storage, streaming=False, **kwargs
session_messages,
checkpoint_id,
checkpoint_storage,
streaming=False,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
if event.type == "output" or event.type == "request_info":
output_events.append(event)
@@ -285,7 +314,8 @@ class WorkflowAgent(BaseAgent):
session: AgentSession | None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> AsyncIterable[AgentResponseUpdate]:
"""Internal implementation of streaming execution.
@@ -295,8 +325,8 @@ class WorkflowAgent(BaseAgent):
session: The agent session for conversation context.
checkpoint_id: ID of checkpoint to restore from.
checkpoint_storage: Runtime checkpoint storage.
**kwargs: Additional keyword arguments passed through to the underlying
workflow and tool functions.
function_invocation_kwargs: Optional kwargs for tool invocations.
client_kwargs: Optional kwargs for chat client calls.
Yields:
AgentResponseUpdate objects representing the workflow execution progress.
@@ -329,7 +359,12 @@ class WorkflowAgent(BaseAgent):
session_messages: list[Message] = session_context.get_messages(include_input=True)
all_updates: list[AgentResponseUpdate] = []
async for event in self._run_core(
session_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs
session_messages,
checkpoint_id,
checkpoint_storage,
streaming=True,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
updates = self._convert_workflow_event_to_agent_response_updates(response_id, event)
for update in updates:
@@ -349,7 +384,8 @@ class WorkflowAgent(BaseAgent):
checkpoint_id: str | None,
checkpoint_storage: CheckpointStorage | None,
streaming: bool,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> AsyncIterable[WorkflowEvent]:
"""Core implementation that yields workflow events for both streaming and non-streaming modes.
@@ -358,8 +394,8 @@ class WorkflowAgent(BaseAgent):
checkpoint_id: ID of checkpoint to restore from.
checkpoint_storage: Runtime checkpoint storage.
streaming: Whether to use streaming workflow methods.
**kwargs: Additional keyword arguments passed through to the underlying
workflow and tool functions.
function_invocation_kwargs: Optional kwargs for tool invocations.
client_kwargs: Optional kwargs for chat client calls.
Yields:
WorkflowEvent objects from the workflow execution.
@@ -371,10 +407,19 @@ class WorkflowAgent(BaseAgent):
if bool(self.pending_requests):
function_responses = self._process_pending_requests(input_messages)
if streaming:
async for event in self.workflow.run(responses=function_responses, stream=True, **kwargs):
async for event in self.workflow.run(
responses=function_responses,
stream=True,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
for event in await self.workflow.run(responses=function_responses, **kwargs):
for event in await self.workflow.run(
responses=function_responses,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
elif checkpoint_id is not None:
@@ -383,14 +428,16 @@ class WorkflowAgent(BaseAgent):
stream=True,
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
**kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
for event in await self.workflow.run(
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
**kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
@@ -400,14 +447,16 @@ class WorkflowAgent(BaseAgent):
message=input_messages,
stream=True,
checkpoint_storage=checkpoint_storage,
**kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
for event in await self.workflow.run(
message=input_messages,
checkpoint_storage=checkpoint_storage,
**kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
@@ -2,7 +2,7 @@
import logging
import sys
from collections.abc import Awaitable, Callable, Mapping
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any, Literal, cast
@@ -14,7 +14,7 @@ from .._agents import SupportsAgentRun
from .._sessions import AgentSession
from .._types import AgentResponse, AgentResponseUpdate, Message, ResponseStream
from ._agent_utils import resolve_agent_id
from ._const import WORKFLOW_RUN_KWARGS_KEY
from ._const import GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY
from ._executor import Executor, handler
from ._message_utils import normalize_messages_input
from ._request_info_mixin import response_handler
@@ -335,15 +335,17 @@ class AgentExecutor(Executor):
Returns:
The complete AgentResponse, or None if waiting for user input.
"""
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}))
function_invocation_kwargs, client_kwargs = self._prepare_agent_run_args(
ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})
)
run_agent = cast(Callable[..., Awaitable[AgentResponse[Any]]], self._agent.run)
response = await run_agent(
self._cache,
stream=False,
session=self._session,
options=options,
**run_kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
await ctx.yield_output(response)
@@ -365,7 +367,9 @@ class AgentExecutor(Executor):
Returns:
The complete AgentResponse, or None if waiting for user input.
"""
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}))
function_invocation_kwargs, client_kwargs = self._prepare_agent_run_args(
ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})
)
updates: list[AgentResponseUpdate] = []
streamed_user_input_requests: list[Content] = []
@@ -374,8 +378,8 @@ class AgentExecutor(Executor):
self._cache,
stream=True,
session=self._session,
options=options,
**run_kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
async for update in stream:
updates.append(update)
@@ -421,74 +425,58 @@ class AgentExecutor(Executor):
return response
# Parameters that are explicitly passed to agent.run() by AgentExecutor
# and must not appear in **run_kwargs to avoid TypeError from duplicate values.
_RESERVED_RUN_PARAMS: frozenset[str] = frozenset({"session", "stream", "messages"})
def _prepare_agent_run_args(
self,
raw_run_kwargs: dict[str, Any],
) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
"""Prepare function_invocation_kwargs and client_kwargs for agent.run().
@staticmethod
def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]:
"""Prepare kwargs and options for agent.run(), avoiding duplicate option passing.
Extracts ``function_invocation_kwargs`` and ``client_kwargs`` from the
workflow state dict, resolving per-executor entries using ``self.id``. The
``__global__`` sentinel key (set by ``Workflow._resolve_invocation_kwargs``) denotes
global kwargs that apply to all executors. Per-executor dicts use executor IDs as
keys; this executor extracts only its own entry.
Workflow-level kwargs are propagated to tool calls through
`options.additional_function_arguments`. If workflow kwargs include an
`options` key, merge it into the final options object and remove it from
kwargs before spreading `**run_kwargs`.
Reserved parameters (session, stream, messages) that are explicitly
managed by AgentExecutor are stripped from run_kwargs to prevent
``TypeError: got multiple values for keyword argument`` collisions.
Returns:
A 2-tuple of (function_invocation_kwargs, client_kwargs).
"""
run_kwargs = dict(raw_run_kwargs)
fi_resolved = raw_run_kwargs.get("function_invocation_kwargs")
ci_resolved = raw_run_kwargs.get("client_kwargs")
# Strip reserved params that AgentExecutor passes explicitly to agent.run().
for key in AgentExecutor._RESERVED_RUN_PARAMS:
if key in run_kwargs:
logger.warning(
"Workflow kwarg '%s' is reserved by AgentExecutor and will be ignored. "
"Remove it from workflow.run() kwargs to silence this warning.",
key,
)
run_kwargs.pop(key)
function_invocation_kwargs = self._resolve_executor_kwargs(fi_resolved)
client_kwargs = self._resolve_executor_kwargs(ci_resolved)
options_from_workflow = run_kwargs.pop("options", None)
workflow_additional_args = run_kwargs.pop("additional_function_arguments", None)
return function_invocation_kwargs, client_kwargs
options: dict[str, Any] = {}
if options_from_workflow is not None:
if isinstance(options_from_workflow, Mapping):
options_from_workflow_map = cast(Mapping[str, Any], options_from_workflow)
for key, value in options_from_workflow_map.items():
options[key] = value
else:
logger.warning(
"Ignoring non-mapping workflow 'options' kwarg of type %s for AgentExecutor %s.",
type(options_from_workflow).__name__,
AgentExecutor.__name__,
)
def _resolve_executor_kwargs(self, resolved: dict[str, Any] | None) -> dict[str, Any] | None:
"""Extract this executor's kwargs from a resolved invocation kwargs dict.
existing_additional_args = options.get("additional_function_arguments")
additional_args: dict[str, Any]
if isinstance(existing_additional_args, Mapping):
existing_additional_args_map = cast(Mapping[str, Any], existing_additional_args)
additional_args = {key: value for key, value in existing_additional_args_map.items()}
Args:
resolved: The resolved dict produced by ``Workflow._resolve_invocation_kwargs``,
containing either a ``__global__`` key (global kwargs) or executor-ID keys
(per-executor kwargs). May also be ``None``.
Returns:
The kwargs for this executor, or ``None`` if not applicable.
"""
if not isinstance(resolved, dict):
return None
# Use explicit key-presence checks so that an empty per-executor dict is
# honoured (e.g. to clear kwargs) instead of falling through to global.
if self.id in resolved:
executor_kwargs = resolved[self.id]
elif GLOBAL_KWARGS_KEY in resolved:
executor_kwargs = resolved[GLOBAL_KWARGS_KEY]
else:
additional_args = {}
return None
if workflow_additional_args is not None:
if isinstance(workflow_additional_args, Mapping):
workflow_additional_args_map = cast(Mapping[str, Any], workflow_additional_args)
additional_args.update({key: value for key, value in workflow_additional_args_map.items()})
else:
logger.warning(
"Ignoring non-mapping workflow 'additional_function_arguments' kwarg of type %s for AgentExecutor %s.", # noqa: E501
type(workflow_additional_args).__name__,
AgentExecutor.__name__,
)
if not isinstance(executor_kwargs, dict):
logger.warning(
"Executor %s expected a dict for its kwargs, but got %s. Ignoring.",
self.id,
type(executor_kwargs), # type: ignore
)
if run_kwargs:
additional_args.update(run_kwargs)
return None
if additional_args:
options["additional_function_arguments"] = additional_args
return run_kwargs, options or None
return executor_kwargs # type: ignore
@@ -14,6 +14,10 @@ INTERNAL_SOURCE_PREFIX = "internal"
# to pass kwargs from workflow.run() through to agent.run() and @tool functions.
WORKFLOW_RUN_KWARGS_KEY = "_workflow_run_kwargs"
# Sentinel key used in resolved invocation kwargs dicts to denote global kwargs
# that apply to all executors (as opposed to per-executor keyed entries).
GLOBAL_KWARGS_KEY = "__global__"
def INTERNAL_SOURCE_ID(executor_id: str) -> str:
"""Generate an internal source ID for a given executor."""
@@ -10,14 +10,14 @@ import json
import logging
import types
import uuid
from collections.abc import AsyncIterable, Awaitable, Callable, Sequence
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence
from typing import Any, Literal, overload
from .._types import ResponseStream
from ..observability import OtelAttr, capture_exception, create_workflow_span
from ._agent import WorkflowAgent
from ._checkpoint import CheckpointStorage
from ._const import DEFAULT_MAX_ITERATIONS, WORKFLOW_RUN_KWARGS_KEY
from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY
from ._edge import (
EdgeGroup,
FanOutEdgeGroup,
@@ -180,7 +180,6 @@ class Workflow(DictConvertible):
description: str | None = None,
max_iterations: int = DEFAULT_MAX_ITERATIONS,
output_executors: list[str] | None = None,
**kwargs: Any,
):
"""Initialize the workflow with a list of edges.
@@ -198,7 +197,6 @@ class Workflow(DictConvertible):
WorkflowBuilder, this will be the description of the builder.
output_executors: Optional list of executor IDs whose outputs will be considered workflow outputs.
If None or empty, all executor outputs are treated as workflow outputs.
kwargs: Additional keyword arguments. Unused in this implementation.
"""
self.edge_groups = list(edge_groups)
self.executors = dict(executors)
@@ -300,7 +298,8 @@ class Workflow(DictConvertible):
initial_executor_fn: Callable[[], Awaitable[None]] | None = None,
reset_context: bool = True,
streaming: bool = False,
run_kwargs: dict[str, Any] | None = None,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> AsyncIterable[WorkflowEvent]:
"""Private method to run workflow with proper tracing.
@@ -311,7 +310,10 @@ class Workflow(DictConvertible):
initial_executor_fn: Optional function to execute initial executor
reset_context: Whether to reset the context for a new run
streaming: Whether to enable streaming mode for agents
run_kwargs: Optional kwargs to store in State for agent invocations
function_invocation_kwargs: Optional kwargs to store in State for function
invocations in subagents
client_kwargs: Optional kwargs to store in State for chat client
invocations in subagents
Yields:
WorkflowEvent: The events generated during the workflow execution.
@@ -350,8 +352,17 @@ class Workflow(DictConvertible):
# Only overwrite when new kwargs are explicitly provided or state was
# just cleared (fresh run). On continuation (reset_context=False) with
# no new kwargs, preserve the kwargs from the original run.
if run_kwargs is not None:
self._state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs)
if function_invocation_kwargs is not None or client_kwargs is not None:
combined_kwargs: dict[str, Any] = {}
if function_invocation_kwargs is not None:
combined_kwargs["function_invocation_kwargs"] = self._resolve_invocation_kwargs(
function_invocation_kwargs, "function_invocation_kwargs"
)
if client_kwargs is not None:
combined_kwargs["client_kwargs"] = self._resolve_invocation_kwargs(
client_kwargs, "client_kwargs"
)
self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs)
elif reset_context:
self._state.set(WORKFLOW_RUN_KWARGS_KEY, {})
self._state.commit() # Commit immediately so kwargs are available
@@ -459,10 +470,11 @@ class Workflow(DictConvertible):
message: Any | None = None,
*,
stream: Literal[True],
responses: dict[str, Any] | None = None,
responses: Mapping[str, Any] | None = None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> ResponseStream[WorkflowEvent, WorkflowRunResult]: ...
@overload
@@ -471,11 +483,12 @@ class Workflow(DictConvertible):
message: Any | None = None,
*,
stream: Literal[False] = ...,
responses: dict[str, Any] | None = None,
responses: Mapping[str, Any] | None = None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
include_status_events: bool = False,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[WorkflowRunResult]: ...
def run(
@@ -483,11 +496,12 @@ class Workflow(DictConvertible):
message: Any | None = None,
*,
stream: bool = False,
responses: dict[str, Any] | None = None,
responses: Mapping[str, Any] | None = None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
include_status_events: bool = False,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> ResponseStream[WorkflowEvent, WorkflowRunResult] | Awaitable[WorkflowRunResult]:
"""Run the workflow, optionally streaming events.
@@ -509,7 +523,12 @@ class Workflow(DictConvertible):
(restore then send responses).
checkpoint_storage: Runtime checkpoint storage.
include_status_events: Whether to include status events (non-streaming only).
**kwargs: Additional keyword arguments to pass through to agent invocations.
function_invocation_kwargs: Keyword arguments forwarded to tool invocations in
subagents. Either a mapping for agent name or agent executor id to kwargs,
or a flat mapping of kwargs for all tool invocations.
client_kwargs: Keyword arguments forwarded to chat client calls in
subagents. Either a mapping for agent name or agent executor id to kwargs,
or a flat mapping of kwargs for all chat client calls.
Returns:
When stream=True: A ResponseStream[WorkflowEvent, WorkflowRunResult] for
@@ -530,7 +549,8 @@ class Workflow(DictConvertible):
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
streaming=stream,
**kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
),
finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events),
cleanup_hooks=[
@@ -546,11 +566,12 @@ class Workflow(DictConvertible):
self,
message: Any | None = None,
*,
responses: dict[str, Any] | None = None,
responses: Mapping[str, Any] | None = None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
streaming: bool = False,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> AsyncIterable[WorkflowEvent]:
"""Single core execution path for both streaming and non-streaming modes.
@@ -569,11 +590,8 @@ class Workflow(DictConvertible):
initial_executor_fn=initial_executor_fn,
reset_context=reset_context,
streaming=streaming,
# Empty **kwargs (no caller-provided kwargs) is collapsed to None so that
# continuation calls without explicit kwargs preserve the original run's kwargs.
# A non-empty kwargs dict (even one with empty values like {"key": {}})
# is passed through and will overwrite stored kwargs.
run_kwargs=kwargs if kwargs else None,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
if event.type == "output" and not self._should_yield_output_event(event):
continue
@@ -624,7 +642,7 @@ class Workflow(DictConvertible):
@staticmethod
def _validate_run_params(
message: Any | None,
responses: dict[str, Any] | None,
responses: Mapping[str, Any] | None,
checkpoint_id: str | None,
) -> None:
"""Validate parameter combinations for run().
@@ -650,7 +668,7 @@ class Workflow(DictConvertible):
def _resolve_execution_mode(
self,
message: Any | None,
responses: dict[str, Any] | None,
responses: Mapping[str, Any] | None,
checkpoint_id: str | None,
checkpoint_storage: CheckpointStorage | None,
) -> tuple[Callable[[], Awaitable[None]], bool]:
@@ -680,7 +698,7 @@ class Workflow(DictConvertible):
self,
checkpoint_id: str,
checkpoint_storage: CheckpointStorage | None,
responses: dict[str, Any],
responses: Mapping[str, Any],
) -> None:
"""Restore from a checkpoint then send responses to pending requests.
@@ -700,7 +718,7 @@ class Workflow(DictConvertible):
await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage)
await self._send_responses_internal(responses)
async def _send_responses_internal(self, responses: dict[str, Any]) -> None:
async def _send_responses_internal(self, responses: Mapping[str, Any]) -> None:
"""Internal method to validate and send responses to the executors."""
pending_requests = await self._runner_context.get_pending_request_info_events()
if not pending_requests:
@@ -739,6 +757,44 @@ class Workflow(DictConvertible):
raise ValueError(f"Executor with ID {executor_id} not found.")
return self.executors[executor_id]
def _resolve_invocation_kwargs(
self,
kwargs: Mapping[str, Any],
param_name: str,
) -> dict[str, Any]:
"""Resolve invocation kwargs into a normalized per-executor or global format.
Detects whether the provided kwargs dict uses per-executor targeting by checking
if any top-level key matches a known executor ID in the workflow. If at least one
key matches, all entries are treated as per-executor. Otherwise the dict is treated
as global kwargs that apply to every executor.
Args:
kwargs: The raw invocation kwargs from the caller.
param_name: The parameter name (for logging), e.g. ``"function_invocation_kwargs"``.
Returns:
A dict with either:
- ``{"__global__": <original dict>}`` for global kwargs, or
- The original dict unchanged for per-executor kwargs.
"""
executor_ids = set(self.executors.keys())
matched_ids = kwargs.keys() & executor_ids
if matched_ids:
logger.info(
"Detected per-executor %s: executor ID(s) %s found in keys. "
"All entries will be treated as per-executor.",
param_name,
matched_ids,
)
return dict(kwargs)
logger.info(
"No executor IDs found in %s keys; treating as global kwargs for all executors.",
param_name,
)
return {GLOBAL_KWARGS_KEY: dict(kwargs)}
def _should_yield_output_event(self, event: WorkflowEvent[Any]) -> bool:
"""Determine if an output event should be yielded as a workflow output.
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
from ._workflow import Workflow
from ._checkpoint_encoding import decode_checkpoint_value
from ._const import WORKFLOW_RUN_KWARGS_KEY
from ._const import GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY
from ._events import (
WorkflowEvent,
WorkflowRunState,
@@ -387,8 +387,28 @@ class WorkflowExecutor(Executor):
# Get kwargs from parent workflow's State to propagate to subworkflow
parent_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})
# Extract invocation kwargs recognised by Workflow.run()
# The state stores resolved format (with __global__ wrapper for global kwargs).
# Unwrap __global__ before passing to the subworkflow so it gets re-resolved
# against the subworkflow's own executor IDs.
fi_kwargs: dict[str, Any] | None = None
ci_kwargs: dict[str, Any] | None = None
for key in ("function_invocation_kwargs", "client_kwargs"):
resolved = parent_kwargs.get(key)
if isinstance(resolved, dict):
# Unwrap global sentinel; pass per-executor dicts as-is
unwrapped: dict[str, Any] = resolved.get(GLOBAL_KWARGS_KEY, resolved) # type: ignore
if key == "function_invocation_kwargs":
fi_kwargs = unwrapped # type: ignore
else:
ci_kwargs = unwrapped # type: ignore
# Run the sub-workflow and collect all events, passing parent kwargs
result = await self.workflow.run(input_data, **parent_kwargs)
result = await self.workflow.run(
input_data,
function_invocation_kwargs=fi_kwargs, # type: ignore
client_kwargs=ci_kwargs, # type: ignore
)
logger.debug(
f"WorkflowExecutor {self.id} sub-workflow {self.workflow.id} "
@@ -1,8 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
from collections.abc import AsyncIterable, Awaitable
from typing import TYPE_CHECKING, Any, Literal, overload
from typing import Any, Literal, overload
import pytest
@@ -22,9 +21,7 @@ from agent_framework import (
)
from agent_framework._workflows._agent_executor import AgentExecutorResponse
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
if TYPE_CHECKING:
from _pytest.logging import LogCaptureFixture
from agent_framework._workflows._const import GLOBAL_KWARGS_KEY
class _CountingAgent(BaseAgent):
@@ -309,87 +306,28 @@ async def test_agent_executor_save_and_restore_state_directly() -> None:
assert restored_session.session_id == session.session_id
async def test_agent_executor_run_with_session_kwarg_does_not_raise() -> None:
"""Passing session= via workflow.run() should not cause a duplicate-keyword TypeError (#4295)."""
agent = _CountingAgent(id="session_kwarg_agent", name="SessionKwargAgent")
executor = AgentExecutor(agent, id="session_kwarg_exec")
workflow = WorkflowBuilder(start_executor=executor).build()
async def test_prepare_agent_run_args_extracts_invocation_kwargs() -> None:
"""_prepare_agent_run_args extracts function_invocation_kwargs and client_kwargs."""
agent = _CountingAgent(id="test_agent", name="TestAgent")
executor = AgentExecutor(agent, id="test_exec")
# This previously raised: TypeError: run() got multiple values for keyword argument 'session'
result = await workflow.run("hello", session="user-supplied-value")
assert result is not None
assert agent.call_count == 1
async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -> None:
"""Passing stream= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
agent = _CountingAgent(id="stream_kwarg_agent", name="StreamKwargAgent")
executor = AgentExecutor(agent, id="stream_kwarg_exec")
workflow = WorkflowBuilder(start_executor=executor).build()
# stream=True at workflow level triggers streaming mode (returns async iterable)
events: list[WorkflowEvent] = []
async for event in workflow.run("hello", stream=True):
events.append(event)
assert len(events) > 0
assert agent.call_count == 1
@pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"])
async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str, caplog: "LogCaptureFixture") -> None:
"""_prepare_agent_run_args must remove reserved kwargs and log a warning."""
raw: dict[str, Any] = {
reserved_kwarg: "should-be-stripped",
"custom_key": "keep-me",
"function_invocation_kwargs": {"__global__": {"key": "fi_val"}},
"client_kwargs": {"__global__": {"key": "ci_val"}},
}
with caplog.at_level(logging.WARNING):
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert reserved_kwarg not in run_kwargs
assert "custom_key" in run_kwargs
assert options is not None
assert options["additional_function_arguments"]["custom_key"] == "keep-me"
assert any(reserved_kwarg in record.message for record in caplog.records)
fi_kwargs, ci_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert fi_kwargs == {"key": "fi_val"}
assert ci_kwargs == {"key": "ci_val"}
async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None:
"""Non-reserved workflow kwargs should pass through unchanged."""
raw: dict[str, Any] = {"custom_param": "value", "another": 42}
run_kwargs, _options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert run_kwargs["custom_param"] == "value"
assert run_kwargs["another"] == 42
async def test_prepare_agent_run_args_returns_none_when_no_kwargs() -> None:
"""_prepare_agent_run_args returns None for both when raw dict has no invocation kwargs."""
agent = _CountingAgent(id="test_agent", name="TestAgent")
executor = AgentExecutor(agent, id="test_exec")
async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once(
caplog: "LogCaptureFixture",
) -> None:
"""All reserved kwargs should be stripped when supplied together, each emitting a warning."""
raw: dict[str, Any] = {"session": "x", "stream": True, "messages": [], "custom": 1}
with caplog.at_level(logging.WARNING):
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert "session" not in run_kwargs
assert "stream" not in run_kwargs
assert "messages" not in run_kwargs
assert run_kwargs["custom"] == 1
assert options is not None
assert options["additional_function_arguments"]["custom"] == 1
warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()}
assert warned_keys == {"session", "stream", "messages"}
async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None:
"""Passing messages= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
agent = _CountingAgent(id="messages_kwarg_agent", name="MessagesKwargAgent")
executor = AgentExecutor(agent, id="messages_kwarg_exec")
workflow = WorkflowBuilder(start_executor=executor).build()
result = await workflow.run("hello", messages=["stale"])
assert result is not None
assert agent.call_count == 1
fi_kwargs, ci_kwargs = executor._prepare_agent_run_args({}) # pyright: ignore[reportPrivateUsage]
assert fi_kwargs is None
assert ci_kwargs is None
class _NonCopyableRaw:
@@ -638,3 +576,126 @@ async def test_checkpoint_restore_works_without_context_mode_in_state() -> None:
assert cache[0].text == "cached msg"
# context_mode should remain as configured in the constructor, not changed by restore
assert executor._context_mode == "last_agent" # pyright: ignore[reportPrivateUsage]
# ---------------------------------------------------------------------------
# Per-executor kwargs resolution tests
# ---------------------------------------------------------------------------
async def test_resolve_executor_kwargs_returns_global_kwargs() -> None:
"""_resolve_executor_kwargs with the global kwargs key returns the global kwargs."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, id="exec_a")
resolved = {GLOBAL_KWARGS_KEY: {"tool_param": "value"}}
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
assert result == {"tool_param": "value"}
async def test_resolve_executor_kwargs_returns_per_executor_kwargs() -> None:
"""_resolve_executor_kwargs with matching executor ID returns that executor's kwargs."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, id="exec_a")
resolved = {"exec_a": {"my_param": 42}, "exec_b": {"other_param": 99}}
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
assert result == {"my_param": 42}
async def test_resolve_executor_kwargs_returns_none_for_unmatched_per_executor() -> None:
"""_resolve_executor_kwargs returns None when per-executor dict has no matching ID."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, id="exec_c")
resolved = {"exec_a": {"my_param": 42}, "exec_b": {"other_param": 99}}
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
assert result is None
async def test_resolve_executor_kwargs_returns_none_for_none_input() -> None:
"""_resolve_executor_kwargs returns None when input is None."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, id="exec_a")
result = executor._resolve_executor_kwargs(None) # pyright: ignore[reportPrivateUsage]
assert result is None
async def test_resolve_executor_kwargs_prefers_executor_id_over_global() -> None:
"""_resolve_executor_kwargs prefers executor-specific entry over __global__."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, id="exec_a")
# Dict has both a per-executor entry and a global entry
resolved = {"exec_a": {"specific": True}, GLOBAL_KWARGS_KEY: {"global": True}}
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
assert result == {"specific": True}
async def test_prepare_agent_run_args_extracts_function_invocation_kwargs() -> None:
"""_prepare_agent_run_args extracts function_invocation_kwargs from the state dict."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, id="exec_a")
raw: dict[str, Any] = {
"function_invocation_kwargs": {GLOBAL_KWARGS_KEY: {"tool_key": "tool_val"}},
}
fi_kwargs, client_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert fi_kwargs == {"tool_key": "tool_val"}
assert client_kwargs is None
async def test_prepare_agent_run_args_extracts_client_kwargs() -> None:
"""_prepare_agent_run_args extracts client_kwargs from the state dict."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, id="exec_a")
raw: dict[str, Any] = {
"client_kwargs": {GLOBAL_KWARGS_KEY: {"model": "gpt-4"}},
}
fi_kwargs, client_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert fi_kwargs is None
assert client_kwargs == {"model": "gpt-4"}
async def test_prepare_agent_run_args_per_executor_resolution() -> None:
"""_prepare_agent_run_args resolves per-executor function_invocation_kwargs using self.id."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, id="exec_a")
raw: dict[str, Any] = {
"function_invocation_kwargs": {
"exec_a": {"my_tool_key": "my_val"},
"exec_b": {"other_tool_key": "other_val"},
},
}
fi_kwargs, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert fi_kwargs == {"my_tool_key": "my_val"}
async def test_prepare_agent_run_args_per_executor_no_match() -> None:
"""_prepare_agent_run_args returns None for function_invocation_kwargs when executor ID not found."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, id="exec_c")
raw: dict[str, Any] = {
"function_invocation_kwargs": {
"exec_a": {"my_tool_key": "my_val"},
"exec_b": {"other_tool_key": "other_val"},
},
}
fi_kwargs, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert fi_kwargs is None
async def test_resolve_executor_kwargs_empty_per_executor_does_not_fallback_to_global() -> None:
"""An explicit empty per-executor dict should not fall through to global kwargs."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, id="exec_a")
# Per-executor entry for exec_a is empty, but global has values.
# The empty dict should be honoured (no fallback to global).
resolved = {"exec_a": {}, GLOBAL_KWARGS_KEY: {"global_key": "global_val"}}
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
assert result == {}
File diff suppressed because it is too large Load Diff
@@ -1,148 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import json
import os
from typing import Annotated, Any, cast
from agent_framework import Agent, Message, tool
from agent_framework.foundry import FoundryChatClient
from agent_framework.orchestrations import SequentialBuilder
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
# Load environment variables from .env file
load_dotenv()
"""
Sample: Workflow kwargs Flow to @tool Tools
This sample demonstrates how to flow custom context (skill data, user tokens, etc.)
through any workflow pattern to @tool functions using the **kwargs pattern.
Key Concepts:
- Pass custom context as kwargs when invoking workflow.run()
- kwargs are stored in State and passed to all agent invocations
- @tool functions receive kwargs via **kwargs parameter
- Works with Sequential, Concurrent, GroupChat, Handoff, and Magentic patterns
Prerequisites:
- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint.
- FOUNDRY_MODEL must be set to your Azure OpenAI model deployment name.
"""
# Define tools that accept custom context via **kwargs
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def get_user_data(
query: Annotated[str, Field(description="What user data to retrieve")],
**kwargs: Any,
) -> str:
"""Retrieve user-specific data based on the authenticated context."""
user_token = kwargs.get("user_token", {})
user_name = user_token.get("user_name", "anonymous")
access_level = user_token.get("access_level", "none")
print(f"\n[get_user_data] Received kwargs keys: {list(kwargs.keys())}")
print(f"[get_user_data] User: {user_name}")
print(f"[get_user_data] Access level: {access_level}")
return f"Retrieved data for user {user_name} with {access_level} access: {query}"
@tool(approval_mode="never_require")
def call_api(
endpoint_name: Annotated[str, Field(description="Name of the API endpoint to call")],
**kwargs: Any,
) -> str:
"""Call an API using the configured endpoints from custom_data."""
custom_data = kwargs.get("custom_data", {})
api_config = custom_data.get("api_config", {})
base_url = api_config.get("base_url", "unknown")
endpoints = api_config.get("endpoints", {})
print(f"\n[call_api] Received kwargs keys: {list(kwargs.keys())}")
print(f"[call_api] Base URL: {base_url}")
print(f"[call_api] Available endpoints: {list(endpoints.keys())}")
if endpoint_name in endpoints:
return f"Called {base_url}{endpoints[endpoint_name]} successfully"
return f"Endpoint '{endpoint_name}' not found in configuration"
async def main() -> None:
print("=" * 70)
print("Workflow kwargs Flow Demo (SequentialBuilder)")
print("=" * 70)
# Create chat client
client = FoundryChatClient(
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
model=os.environ["FOUNDRY_MODEL"],
credential=AzureCliCredential(),
)
# Create agent with tools that use kwargs
agent = Agent(
client=client,
name="assistant",
instructions=(
"You are a helpful assistant. Use the available tools to help users. "
"When asked about user data, use get_user_data. "
"When asked to call an API, use call_api."
),
tools=[get_user_data, call_api],
)
# Build a simple sequential workflow
workflow = SequentialBuilder(participants=[agent]).build()
# Define custom context that will flow to tools via kwargs
custom_data = {
"api_config": {
"base_url": "https://api.example.com",
"endpoints": {
"users": "/v1/users",
"orders": "/v1/orders",
"products": "/v1/products",
},
},
}
user_token = {
"user_name": "bob@contoso.com",
"access_level": "admin",
}
print("\nCustom Data being passed:")
print(json.dumps(custom_data, indent=2))
print(f"\nUser: {user_token['user_name']}")
print("\n" + "-" * 70)
print("Workflow Execution (watch for [tool_name] logs showing kwargs received):")
print("-" * 70)
# Run workflow with kwargs - these will flow through to tools
async for event in workflow.run(
"Please get my user data and then call the users API endpoint.",
additional_function_arguments={"custom_data": custom_data, "user_token": user_token},
stream=True,
):
if event.type == "output":
output_data = cast(list[Message], event.data)
if isinstance(output_data, list):
for item in output_data:
if isinstance(item, Message) and item.text:
print(f"\n[Final Answer]: {item.text}")
print("\n" + "=" * 70)
print("Sample Complete")
print("=" * 70)
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,170 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import json
import os
from typing import Annotated, Any, cast
from agent_framework import Agent, Message, tool
from agent_framework.foundry import FoundryChatClient
from agent_framework.orchestrations import SequentialBuilder
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
# Load environment variables from .env file
load_dotenv()
"""
Sample: Global Workflow kwargs
This sample demonstrates how to pass the same kwargs to every agent in a
workflow using global targeting. When keys in function_invocation_kwargs do NOT
match any executor ID (agent name), the framework treats them as global and
delivers them to all agents.
Compare with workflow_kwargs_per_agent.py which targets kwargs to specific agents.
Key Concepts:
- Global function_invocation_kwargs are delivered to every agent in the workflow
- Useful when all agents share the same credentials, config, or context
- @tool functions receive kwargs via the **kwargs parameter
Prerequisites:
- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint.
- Environment variables configured
"""
# 1. Define a tool for the research agent — queries a company's internal
# database using credentials passed via global kwargs.
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def query_company_database(
query: Annotated[
str, Field(description="The database query to run, e.g. 'Q3 revenue' or 'headcount by department'")
],
**kwargs: Any,
) -> str:
"""Query the company's internal database for business metrics and data."""
db_config = kwargs.get("db_config", {})
connection_string = db_config.get("connection_string", "")
database = db_config.get("database", "")
if not connection_string or not database:
return f"ERROR: missing db_config — cannot run query '{query}'"
print(f"\n [query_company_database] Connecting to {database} at {connection_string[:30]}...")
# Simulated company data that the LLM would not know on its own
return (
f"Query results from {database}:\n"
f"- Contoso Q3 2025 revenue: $47.2M (up 12% YoY)\n"
f"- Top product line: CloudSync Pro ($18.6M)\n"
f"- Engineering headcount: 342 (up from 298 in Q2)\n"
f"- Customer churn rate: 4.1% (down from 5.3% in Q2)\n"
f"- Net new enterprise customers: 28"
)
# 2. Define a tool for the writer agent — retrieves the formatting style
# from user preferences passed via global kwargs.
@tool(approval_mode="never_require")
def get_formatting_instructions(
section_title: Annotated[str, Field(description="The title of the section or report to format")],
**kwargs: Any,
) -> str:
"""Get the formatting instructions based on user preferences."""
user_prefs = kwargs.get("user_preferences", {})
output_format = user_prefs.get("format", "plain")
language = user_prefs.get("language", "en")
print(f"\n [get_formatting_instructions] Format: {output_format}, Language: {language}")
return (
f"Formatting rules for '{section_title}':\n"
f"- Output format: {output_format}\n"
f"- Language/locale: {language}\n"
f"- Include a footer: 'Generated in {output_format} for locale {language}'"
)
async def main() -> None:
print("=" * 70)
print("Global Workflow kwargs Demo")
print("=" * 70)
# 3. Create a shared chat client.
client = FoundryChatClient(
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
model=os.environ["FOUNDRY_MODEL"],
credential=AzureCliCredential(),
)
# 4. Create two agents with different tools and responsibilities.
researcher = Agent(
client=client,
name="researcher",
instructions=(
"You are a data analyst. Call query_company_database exactly once "
"with the user's request as the query. Return the raw results."
),
tools=[query_company_database],
)
writer = Agent(
client=client,
name="writer",
instructions=(
"You are a report writer. Call get_formatting_instructions exactly once, "
"then rewrite the data you receive into a polished report following those rules."
),
tools=[get_formatting_instructions],
)
# 5. Build a sequential workflow: researcher -> writer.
workflow = SequentialBuilder(participants=[researcher, writer]).build()
# 6. Define global kwargs — every agent receives all of these.
# Because the keys ("db_config", "user_preferences") do NOT match any
# executor ID ("researcher", "writer"), the framework treats them as
# global and delivers the full dict to every agent.
global_fi_kwargs = {
"db_config": {
"connection_string": "Server=contoso-sql.database.windows.net;Database=metrics",
"database": "contoso_metrics_prod",
},
"user_preferences": {
"format": "markdown",
"language": "en-US",
},
}
print("\nGlobal function_invocation_kwargs (sent to all agents):")
print(json.dumps(global_fi_kwargs, indent=2))
print("\n" + "-" * 70)
print("Workflow Execution:")
print("-" * 70)
# 7. Run the workflow — every agent receives the same global kwargs.
async for event in workflow.run(
"Pull Contoso's Q3 2025 performance data and write an executive summary.",
function_invocation_kwargs=global_fi_kwargs,
stream=True,
):
if event.type == "output":
output_data = cast(list[Message], event.data)
if isinstance(output_data, list):
for item in output_data:
if isinstance(item, Message) and item.text:
print(f"\n[{item.author_name}]: {item.text}")
print("\n" + "=" * 70)
print("Sample Complete")
print("=" * 70)
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,222 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import json
import os
from typing import Annotated, Any, cast
from agent_framework import Agent, Message, tool
from agent_framework.foundry import FoundryChatClient
from agent_framework.orchestrations import SequentialBuilder
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
# Load environment variables from .env file
load_dotenv()
"""
Sample: Per-Agent Workflow kwargs
This sample demonstrates how to pass different kwargs to different agents in a
workflow using per-agent targeting. When keys in function_invocation_kwargs (or
client_kwargs) match executor IDs (agent names by default), each agent
receives only its own slice of the kwargs.
Key Concepts:
- Per-agent function_invocation_kwargs target specific agents by executor ID
- Agents only receive the kwargs assigned to them (not other agents' kwargs)
- Useful when different agents need different credentials, configs, or context
Prerequisites:
- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint.
- Environment variables configured
"""
# 1. Define a tool for the research agent — queries a company's internal
# database using credentials passed via per-agent kwargs.
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def query_company_database(
query: Annotated[
str, Field(description="The database query to run, e.g. 'Q3 revenue' or 'headcount by department'")
],
**kwargs: Any,
) -> str:
"""Query the company's internal database for business metrics and data."""
db_config = kwargs.get("db_config", {})
connection_string = db_config.get("connection_string", "")
database = db_config.get("database", "")
if not connection_string or not database:
return f"ERROR: missing db_config — cannot run query '{query}'"
print(f"\n [query_company_database] Connecting to {database} at {connection_string[:30]}...")
# Simulated company data that the LLM would not know on its own
return (
f"Query results from {database}:\n"
f"- Contoso Q3 2025 revenue: $47.2M (up 12% YoY)\n"
f"- Top product line: CloudSync Pro ($18.6M)\n"
f"- Engineering headcount: 342 (up from 298 in Q2)\n"
f"- Customer churn rate: 4.1% (down from 5.3% in Q2)\n"
f"- Net new enterprise customers: 28"
)
# 2. Define a tool for the writer agent — retrieves the formatting style
# from user preferences passed via per-agent kwargs.
@tool(approval_mode="never_require")
def get_formatting_instructions(
section_title: Annotated[str, Field(description="The title of the section or report to format")],
**kwargs: Any,
) -> str:
"""Get the formatting instructions based on user preferences."""
user_prefs = kwargs.get("user_preferences", {})
output_format = user_prefs.get("format", "plain")
language = user_prefs.get("language", "en")
print(f"\n [get_formatting_instructions] Format: {output_format}, Language: {language}")
return (
f"Formatting rules for '{section_title}':\n"
f"- Output format: {output_format}\n"
f"- Language/locale: {language}\n"
f"- Include a footer: 'Generated in {output_format} for locale {language}'"
)
async def main() -> None:
print("=" * 70)
print("Per-Agent Workflow kwargs Demo")
print("=" * 70)
# 3. Create a shared chat client.
client = FoundryChatClient(
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
model=os.environ["FOUNDRY_MODEL"],
credential=AzureCliCredential(),
)
# 4. Create two agents with different tools and responsibilities.
researcher = Agent(
client=client,
name="researcher",
instructions=(
"You are a data analyst. Call query_company_database exactly once "
"with the user's request as the query. Return the raw results."
),
tools=[query_company_database],
)
writer = Agent(
client=client,
name="writer",
instructions=(
"You are a report writer. Call get_formatting_instructions exactly once, "
"then rewrite the data you receive into a polished report following those rules."
),
tools=[get_formatting_instructions],
)
# 5. Build a sequential workflow: researcher -> writer.
workflow = SequentialBuilder(participants=[researcher, writer]).build()
# 6. Define per-agent kwargs — each agent gets only its own config.
# The keys ("researcher", "writer") match the agent names, which are
# used as executor IDs by default.
per_agent_fi_kwargs = {
"researcher": {
"db_config": {
"connection_string": "Server=contoso-sql.database.windows.net;Database=metrics",
"database": "contoso_metrics_prod",
},
},
"writer": {
"user_preferences": {
"format": "markdown",
"language": "en-US",
},
},
}
print("\nPer-agent function_invocation_kwargs:")
print(json.dumps(per_agent_fi_kwargs, indent=2))
print("\n" + "-" * 70)
print("Workflow Execution:")
print("-" * 70)
# 7. Run the workflow — each agent receives only its targeted kwargs.
async for event in workflow.run(
"Pull Contoso's Q3 2025 performance data and write an executive summary.",
function_invocation_kwargs=per_agent_fi_kwargs,
stream=True,
):
if event.type == "output":
output_data = cast(list[Message], event.data)
if isinstance(output_data, list):
for item in output_data:
if isinstance(item, Message) and item.text:
print(f"\n[{item.author_name}]: {item.text}")
print("\n" + "=" * 70)
print("Sample Complete")
print("=" * 70)
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output:
Per-agent function_invocation_kwargs:
{
"researcher": {
"db_config": {
"connection_string": "Server=contoso-sql.database.windows.net;Database=metrics",
"database": "contoso_metrics_prod"
}
},
"writer": {
"user_preferences": {
"format": "markdown",
"language": "en-US"
}
}
}
----------------------------------------------------------------------
Workflow Execution:
----------------------------------------------------------------------
[query_company_database] Connecting to contoso_metrics_prod at Server=contoso-sql.database.wi...
[researcher]: Here is Contoso's Q3 2025 data:
- Revenue: $47.2M (up 12% YoY)
- Top product: CloudSync Pro ($18.6M)
- Engineering headcount: 342
- Churn rate: 4.1%
- Net new enterprise customers: 28
[get_formatting_instructions] Format: markdown, Language: en-US
[writer]: # Contoso Q3 2025 Executive Summary
| Metric | Value |
|---|---|
| Revenue | $47.2M (+12% YoY) |
| Top Product | CloudSync Pro ($18.6M) |
| Engineering Headcount | 342 |
| Customer Churn | 4.1% |
| New Enterprise Customers | 28 |
Generated in markdown for locale en-US
======================================================================
Sample Complete
======================================================================
"""