mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
[BREAKING] Python: Refactor workflows kwargs (#5010)
* Refactor workflows kwargs usage * Update sample * Add tests * Update samples * Fix formatting * Comments * Comments 2 * Comments 3 * Fix test and typing
This commit is contained in:
committed by
GitHub
Unverified
parent
fd253c0b0e
commit
62595b233f
@@ -6,7 +6,7 @@ import json
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, Awaitable, Sequence
|
||||
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
|
||||
@@ -152,7 +152,8 @@ class WorkflowAgent(BaseAgent):
|
||||
session: AgentSession | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ...
|
||||
|
||||
@overload
|
||||
@@ -164,7 +165,8 @@ class WorkflowAgent(BaseAgent):
|
||||
session: AgentSession | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
) -> AgentResponse: ...
|
||||
|
||||
def run(
|
||||
@@ -175,7 +177,8 @@ class WorkflowAgent(BaseAgent):
|
||||
session: AgentSession | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse] | Awaitable[AgentResponse]:
|
||||
"""Get a response from the workflow agent.
|
||||
|
||||
@@ -192,8 +195,12 @@ class WorkflowAgent(BaseAgent):
|
||||
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
|
||||
used to load and restore the checkpoint. When provided without checkpoint_id,
|
||||
enables checkpointing for this run.
|
||||
**kwargs: Additional keyword arguments passed through to underlying workflow
|
||||
and tool functions.
|
||||
function_invocation_kwargs: Keyword arguments forwarded to tool invocations in
|
||||
subagents. Either a mapping of agent name/executor id to kwargs, or a flat
|
||||
mapping of kwargs for all tool invocations.
|
||||
client_kwargs: Keyword arguments forwarded to chat client calls in
|
||||
subagents. Either a mapping of agent name/executor id to kwargs, or a flat
|
||||
mapping of kwargs for all chat client calls.
|
||||
|
||||
Returns:
|
||||
When stream=True: An AsyncIterable[AgentResponseUpdate] for streaming updates.
|
||||
@@ -208,10 +215,26 @@ class WorkflowAgent(BaseAgent):
|
||||
response_id = str(uuid.uuid4())
|
||||
if stream:
|
||||
return ResponseStream(
|
||||
self._run_stream_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs),
|
||||
self._run_stream_impl(
|
||||
messages,
|
||||
response_id,
|
||||
session,
|
||||
checkpoint_id,
|
||||
checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
),
|
||||
finalizer=AgentResponse.from_updates,
|
||||
)
|
||||
return self._run_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs)
|
||||
return self._run_impl(
|
||||
messages,
|
||||
response_id,
|
||||
session,
|
||||
checkpoint_id,
|
||||
checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
|
||||
async def _run_impl(
|
||||
self,
|
||||
@@ -220,7 +243,8 @@ class WorkflowAgent(BaseAgent):
|
||||
session: AgentSession | None,
|
||||
checkpoint_id: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
) -> AgentResponse:
|
||||
"""Internal implementation of non-streaming execution.
|
||||
|
||||
@@ -230,8 +254,8 @@ class WorkflowAgent(BaseAgent):
|
||||
session: The agent session for conversation context.
|
||||
checkpoint_id: ID of checkpoint to restore from.
|
||||
checkpoint_storage: Runtime checkpoint storage.
|
||||
**kwargs: Additional keyword arguments passed through to the underlying
|
||||
workflow and tool functions.
|
||||
function_invocation_kwargs: Optional kwargs for tool invocations.
|
||||
client_kwargs: Optional kwargs for chat client calls.
|
||||
|
||||
Returns:
|
||||
An AgentResponse representing the workflow execution results.
|
||||
@@ -264,7 +288,12 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
output_events: list[WorkflowEvent[Any]] = []
|
||||
async for event in self._run_core(
|
||||
session_messages, checkpoint_id, checkpoint_storage, streaming=False, **kwargs
|
||||
session_messages,
|
||||
checkpoint_id,
|
||||
checkpoint_storage,
|
||||
streaming=False,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
if event.type == "output" or event.type == "request_info":
|
||||
output_events.append(event)
|
||||
@@ -285,7 +314,8 @@ class WorkflowAgent(BaseAgent):
|
||||
session: AgentSession | None,
|
||||
checkpoint_id: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
) -> AsyncIterable[AgentResponseUpdate]:
|
||||
"""Internal implementation of streaming execution.
|
||||
|
||||
@@ -295,8 +325,8 @@ class WorkflowAgent(BaseAgent):
|
||||
session: The agent session for conversation context.
|
||||
checkpoint_id: ID of checkpoint to restore from.
|
||||
checkpoint_storage: Runtime checkpoint storage.
|
||||
**kwargs: Additional keyword arguments passed through to the underlying
|
||||
workflow and tool functions.
|
||||
function_invocation_kwargs: Optional kwargs for tool invocations.
|
||||
client_kwargs: Optional kwargs for chat client calls.
|
||||
|
||||
Yields:
|
||||
AgentResponseUpdate objects representing the workflow execution progress.
|
||||
@@ -329,7 +359,12 @@ class WorkflowAgent(BaseAgent):
|
||||
session_messages: list[Message] = session_context.get_messages(include_input=True)
|
||||
all_updates: list[AgentResponseUpdate] = []
|
||||
async for event in self._run_core(
|
||||
session_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs
|
||||
session_messages,
|
||||
checkpoint_id,
|
||||
checkpoint_storage,
|
||||
streaming=True,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
updates = self._convert_workflow_event_to_agent_response_updates(response_id, event)
|
||||
for update in updates:
|
||||
@@ -349,7 +384,8 @@ class WorkflowAgent(BaseAgent):
|
||||
checkpoint_id: str | None,
|
||||
checkpoint_storage: CheckpointStorage | None,
|
||||
streaming: bool,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
) -> AsyncIterable[WorkflowEvent]:
|
||||
"""Core implementation that yields workflow events for both streaming and non-streaming modes.
|
||||
|
||||
@@ -358,8 +394,8 @@ class WorkflowAgent(BaseAgent):
|
||||
checkpoint_id: ID of checkpoint to restore from.
|
||||
checkpoint_storage: Runtime checkpoint storage.
|
||||
streaming: Whether to use streaming workflow methods.
|
||||
**kwargs: Additional keyword arguments passed through to the underlying
|
||||
workflow and tool functions.
|
||||
function_invocation_kwargs: Optional kwargs for tool invocations.
|
||||
client_kwargs: Optional kwargs for chat client calls.
|
||||
|
||||
Yields:
|
||||
WorkflowEvent objects from the workflow execution.
|
||||
@@ -371,10 +407,19 @@ class WorkflowAgent(BaseAgent):
|
||||
if bool(self.pending_requests):
|
||||
function_responses = self._process_pending_requests(input_messages)
|
||||
if streaming:
|
||||
async for event in self.workflow.run(responses=function_responses, stream=True, **kwargs):
|
||||
async for event in self.workflow.run(
|
||||
responses=function_responses,
|
||||
stream=True,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
else:
|
||||
for event in await self.workflow.run(responses=function_responses, **kwargs):
|
||||
for event in await self.workflow.run(
|
||||
responses=function_responses,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
|
||||
elif checkpoint_id is not None:
|
||||
@@ -383,14 +428,16 @@ class WorkflowAgent(BaseAgent):
|
||||
stream=True,
|
||||
checkpoint_id=checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
**kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
else:
|
||||
for event in await self.workflow.run(
|
||||
checkpoint_id=checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
**kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -400,14 +447,16 @@ class WorkflowAgent(BaseAgent):
|
||||
message=input_messages,
|
||||
stream=True,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
**kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
else:
|
||||
for event in await self.workflow.run(
|
||||
message=input_messages,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
**kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
@@ -14,7 +14,7 @@ from .._agents import SupportsAgentRun
|
||||
from .._sessions import AgentSession
|
||||
from .._types import AgentResponse, AgentResponseUpdate, Message, ResponseStream
|
||||
from ._agent_utils import resolve_agent_id
|
||||
from ._const import WORKFLOW_RUN_KWARGS_KEY
|
||||
from ._const import GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY
|
||||
from ._executor import Executor, handler
|
||||
from ._message_utils import normalize_messages_input
|
||||
from ._request_info_mixin import response_handler
|
||||
@@ -335,15 +335,17 @@ class AgentExecutor(Executor):
|
||||
Returns:
|
||||
The complete AgentResponse, or None if waiting for user input.
|
||||
"""
|
||||
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}))
|
||||
function_invocation_kwargs, client_kwargs = self._prepare_agent_run_args(
|
||||
ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})
|
||||
)
|
||||
|
||||
run_agent = cast(Callable[..., Awaitable[AgentResponse[Any]]], self._agent.run)
|
||||
response = await run_agent(
|
||||
self._cache,
|
||||
stream=False,
|
||||
session=self._session,
|
||||
options=options,
|
||||
**run_kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
await ctx.yield_output(response)
|
||||
|
||||
@@ -365,7 +367,9 @@ class AgentExecutor(Executor):
|
||||
Returns:
|
||||
The complete AgentResponse, or None if waiting for user input.
|
||||
"""
|
||||
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}))
|
||||
function_invocation_kwargs, client_kwargs = self._prepare_agent_run_args(
|
||||
ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})
|
||||
)
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
streamed_user_input_requests: list[Content] = []
|
||||
@@ -374,8 +378,8 @@ class AgentExecutor(Executor):
|
||||
self._cache,
|
||||
stream=True,
|
||||
session=self._session,
|
||||
options=options,
|
||||
**run_kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
async for update in stream:
|
||||
updates.append(update)
|
||||
@@ -421,74 +425,58 @@ class AgentExecutor(Executor):
|
||||
|
||||
return response
|
||||
|
||||
# Parameters that are explicitly passed to agent.run() by AgentExecutor
|
||||
# and must not appear in **run_kwargs to avoid TypeError from duplicate values.
|
||||
_RESERVED_RUN_PARAMS: frozenset[str] = frozenset({"session", "stream", "messages"})
|
||||
def _prepare_agent_run_args(
|
||||
self,
|
||||
raw_run_kwargs: dict[str, Any],
|
||||
) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
|
||||
"""Prepare function_invocation_kwargs and client_kwargs for agent.run().
|
||||
|
||||
@staticmethod
|
||||
def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]:
|
||||
"""Prepare kwargs and options for agent.run(), avoiding duplicate option passing.
|
||||
Extracts ``function_invocation_kwargs`` and ``client_kwargs`` from the
|
||||
workflow state dict, resolving per-executor entries using ``self.id``. The
|
||||
``__global__`` sentinel key (set by ``Workflow._resolve_invocation_kwargs``) denotes
|
||||
global kwargs that apply to all executors. Per-executor dicts use executor IDs as
|
||||
keys; this executor extracts only its own entry.
|
||||
|
||||
Workflow-level kwargs are propagated to tool calls through
|
||||
`options.additional_function_arguments`. If workflow kwargs include an
|
||||
`options` key, merge it into the final options object and remove it from
|
||||
kwargs before spreading `**run_kwargs`.
|
||||
|
||||
Reserved parameters (session, stream, messages) that are explicitly
|
||||
managed by AgentExecutor are stripped from run_kwargs to prevent
|
||||
``TypeError: got multiple values for keyword argument`` collisions.
|
||||
Returns:
|
||||
A 2-tuple of (function_invocation_kwargs, client_kwargs).
|
||||
"""
|
||||
run_kwargs = dict(raw_run_kwargs)
|
||||
fi_resolved = raw_run_kwargs.get("function_invocation_kwargs")
|
||||
ci_resolved = raw_run_kwargs.get("client_kwargs")
|
||||
|
||||
# Strip reserved params that AgentExecutor passes explicitly to agent.run().
|
||||
for key in AgentExecutor._RESERVED_RUN_PARAMS:
|
||||
if key in run_kwargs:
|
||||
logger.warning(
|
||||
"Workflow kwarg '%s' is reserved by AgentExecutor and will be ignored. "
|
||||
"Remove it from workflow.run() kwargs to silence this warning.",
|
||||
key,
|
||||
)
|
||||
run_kwargs.pop(key)
|
||||
function_invocation_kwargs = self._resolve_executor_kwargs(fi_resolved)
|
||||
client_kwargs = self._resolve_executor_kwargs(ci_resolved)
|
||||
|
||||
options_from_workflow = run_kwargs.pop("options", None)
|
||||
workflow_additional_args = run_kwargs.pop("additional_function_arguments", None)
|
||||
return function_invocation_kwargs, client_kwargs
|
||||
|
||||
options: dict[str, Any] = {}
|
||||
if options_from_workflow is not None:
|
||||
if isinstance(options_from_workflow, Mapping):
|
||||
options_from_workflow_map = cast(Mapping[str, Any], options_from_workflow)
|
||||
for key, value in options_from_workflow_map.items():
|
||||
options[key] = value
|
||||
else:
|
||||
logger.warning(
|
||||
"Ignoring non-mapping workflow 'options' kwarg of type %s for AgentExecutor %s.",
|
||||
type(options_from_workflow).__name__,
|
||||
AgentExecutor.__name__,
|
||||
)
|
||||
def _resolve_executor_kwargs(self, resolved: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
"""Extract this executor's kwargs from a resolved invocation kwargs dict.
|
||||
|
||||
existing_additional_args = options.get("additional_function_arguments")
|
||||
additional_args: dict[str, Any]
|
||||
if isinstance(existing_additional_args, Mapping):
|
||||
existing_additional_args_map = cast(Mapping[str, Any], existing_additional_args)
|
||||
additional_args = {key: value for key, value in existing_additional_args_map.items()}
|
||||
Args:
|
||||
resolved: The resolved dict produced by ``Workflow._resolve_invocation_kwargs``,
|
||||
containing either a ``__global__`` key (global kwargs) or executor-ID keys
|
||||
(per-executor kwargs). May also be ``None``.
|
||||
|
||||
Returns:
|
||||
The kwargs for this executor, or ``None`` if not applicable.
|
||||
"""
|
||||
if not isinstance(resolved, dict):
|
||||
return None
|
||||
# Use explicit key-presence checks so that an empty per-executor dict is
|
||||
# honoured (e.g. to clear kwargs) instead of falling through to global.
|
||||
if self.id in resolved:
|
||||
executor_kwargs = resolved[self.id]
|
||||
elif GLOBAL_KWARGS_KEY in resolved:
|
||||
executor_kwargs = resolved[GLOBAL_KWARGS_KEY]
|
||||
else:
|
||||
additional_args = {}
|
||||
return None
|
||||
|
||||
if workflow_additional_args is not None:
|
||||
if isinstance(workflow_additional_args, Mapping):
|
||||
workflow_additional_args_map = cast(Mapping[str, Any], workflow_additional_args)
|
||||
additional_args.update({key: value for key, value in workflow_additional_args_map.items()})
|
||||
else:
|
||||
logger.warning(
|
||||
"Ignoring non-mapping workflow 'additional_function_arguments' kwarg of type %s for AgentExecutor %s.", # noqa: E501
|
||||
type(workflow_additional_args).__name__,
|
||||
AgentExecutor.__name__,
|
||||
)
|
||||
if not isinstance(executor_kwargs, dict):
|
||||
logger.warning(
|
||||
"Executor %s expected a dict for its kwargs, but got %s. Ignoring.",
|
||||
self.id,
|
||||
type(executor_kwargs), # type: ignore
|
||||
)
|
||||
|
||||
if run_kwargs:
|
||||
additional_args.update(run_kwargs)
|
||||
return None
|
||||
|
||||
if additional_args:
|
||||
options["additional_function_arguments"] = additional_args
|
||||
|
||||
return run_kwargs, options or None
|
||||
return executor_kwargs # type: ignore
|
||||
|
||||
@@ -14,6 +14,10 @@ INTERNAL_SOURCE_PREFIX = "internal"
|
||||
# to pass kwargs from workflow.run() through to agent.run() and @tool functions.
|
||||
WORKFLOW_RUN_KWARGS_KEY = "_workflow_run_kwargs"
|
||||
|
||||
# Sentinel key used in resolved invocation kwargs dicts to denote global kwargs
|
||||
# that apply to all executors (as opposed to per-executor keyed entries).
|
||||
GLOBAL_KWARGS_KEY = "__global__"
|
||||
|
||||
|
||||
def INTERNAL_SOURCE_ID(executor_id: str) -> str:
|
||||
"""Generate an internal source ID for a given executor."""
|
||||
|
||||
@@ -10,14 +10,14 @@ import json
|
||||
import logging
|
||||
import types
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Sequence
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
from .._types import ResponseStream
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._agent import WorkflowAgent
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._const import DEFAULT_MAX_ITERATIONS, WORKFLOW_RUN_KWARGS_KEY
|
||||
from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY
|
||||
from ._edge import (
|
||||
EdgeGroup,
|
||||
FanOutEdgeGroup,
|
||||
@@ -180,7 +180,6 @@ class Workflow(DictConvertible):
|
||||
description: str | None = None,
|
||||
max_iterations: int = DEFAULT_MAX_ITERATIONS,
|
||||
output_executors: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize the workflow with a list of edges.
|
||||
|
||||
@@ -198,7 +197,6 @@ class Workflow(DictConvertible):
|
||||
WorkflowBuilder, this will be the description of the builder.
|
||||
output_executors: Optional list of executor IDs whose outputs will be considered workflow outputs.
|
||||
If None or empty, all executor outputs are treated as workflow outputs.
|
||||
kwargs: Additional keyword arguments. Unused in this implementation.
|
||||
"""
|
||||
self.edge_groups = list(edge_groups)
|
||||
self.executors = dict(executors)
|
||||
@@ -300,7 +298,8 @@ class Workflow(DictConvertible):
|
||||
initial_executor_fn: Callable[[], Awaitable[None]] | None = None,
|
||||
reset_context: bool = True,
|
||||
streaming: bool = False,
|
||||
run_kwargs: dict[str, Any] | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
) -> AsyncIterable[WorkflowEvent]:
|
||||
"""Private method to run workflow with proper tracing.
|
||||
|
||||
@@ -311,7 +310,10 @@ class Workflow(DictConvertible):
|
||||
initial_executor_fn: Optional function to execute initial executor
|
||||
reset_context: Whether to reset the context for a new run
|
||||
streaming: Whether to enable streaming mode for agents
|
||||
run_kwargs: Optional kwargs to store in State for agent invocations
|
||||
function_invocation_kwargs: Optional kwargs to store in State for function
|
||||
invocations in subagents
|
||||
client_kwargs: Optional kwargs to store in State for chat client
|
||||
invocations in subagents
|
||||
|
||||
Yields:
|
||||
WorkflowEvent: The events generated during the workflow execution.
|
||||
@@ -350,8 +352,17 @@ class Workflow(DictConvertible):
|
||||
# Only overwrite when new kwargs are explicitly provided or state was
|
||||
# just cleared (fresh run). On continuation (reset_context=False) with
|
||||
# no new kwargs, preserve the kwargs from the original run.
|
||||
if run_kwargs is not None:
|
||||
self._state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs)
|
||||
if function_invocation_kwargs is not None or client_kwargs is not None:
|
||||
combined_kwargs: dict[str, Any] = {}
|
||||
if function_invocation_kwargs is not None:
|
||||
combined_kwargs["function_invocation_kwargs"] = self._resolve_invocation_kwargs(
|
||||
function_invocation_kwargs, "function_invocation_kwargs"
|
||||
)
|
||||
if client_kwargs is not None:
|
||||
combined_kwargs["client_kwargs"] = self._resolve_invocation_kwargs(
|
||||
client_kwargs, "client_kwargs"
|
||||
)
|
||||
self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs)
|
||||
elif reset_context:
|
||||
self._state.set(WORKFLOW_RUN_KWARGS_KEY, {})
|
||||
self._state.commit() # Commit immediately so kwargs are available
|
||||
@@ -459,10 +470,11 @@ class Workflow(DictConvertible):
|
||||
message: Any | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
responses: dict[str, Any] | None = None,
|
||||
responses: Mapping[str, Any] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> ResponseStream[WorkflowEvent, WorkflowRunResult]: ...
|
||||
|
||||
@overload
|
||||
@@ -471,11 +483,12 @@ class Workflow(DictConvertible):
|
||||
message: Any | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
responses: dict[str, Any] | None = None,
|
||||
responses: Mapping[str, Any] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
include_status_events: bool = False,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[WorkflowRunResult]: ...
|
||||
|
||||
def run(
|
||||
@@ -483,11 +496,12 @@ class Workflow(DictConvertible):
|
||||
message: Any | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
responses: dict[str, Any] | None = None,
|
||||
responses: Mapping[str, Any] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
include_status_events: bool = False,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
) -> ResponseStream[WorkflowEvent, WorkflowRunResult] | Awaitable[WorkflowRunResult]:
|
||||
"""Run the workflow, optionally streaming events.
|
||||
|
||||
@@ -509,7 +523,12 @@ class Workflow(DictConvertible):
|
||||
(restore then send responses).
|
||||
checkpoint_storage: Runtime checkpoint storage.
|
||||
include_status_events: Whether to include status events (non-streaming only).
|
||||
**kwargs: Additional keyword arguments to pass through to agent invocations.
|
||||
function_invocation_kwargs: Keyword arguments forwarded to tool invocations in
|
||||
subagents. Either a mapping for agent name or agent executor id to kwargs,
|
||||
or a flat mapping of kwargs for all tool invocations.
|
||||
client_kwargs: Keyword arguments forwarded to chat client calls in
|
||||
subagents. Either a mapping for agent name or agent executor id to kwargs,
|
||||
or a flat mapping of kwargs for all chat client calls.
|
||||
|
||||
Returns:
|
||||
When stream=True: A ResponseStream[WorkflowEvent, WorkflowRunResult] for
|
||||
@@ -530,7 +549,8 @@ class Workflow(DictConvertible):
|
||||
checkpoint_id=checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
streaming=stream,
|
||||
**kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
),
|
||||
finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events),
|
||||
cleanup_hooks=[
|
||||
@@ -546,11 +566,12 @@ class Workflow(DictConvertible):
|
||||
self,
|
||||
message: Any | None = None,
|
||||
*,
|
||||
responses: dict[str, Any] | None = None,
|
||||
responses: Mapping[str, Any] | None = None,
|
||||
checkpoint_id: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
|
||||
) -> AsyncIterable[WorkflowEvent]:
|
||||
"""Single core execution path for both streaming and non-streaming modes.
|
||||
|
||||
@@ -569,11 +590,8 @@ class Workflow(DictConvertible):
|
||||
initial_executor_fn=initial_executor_fn,
|
||||
reset_context=reset_context,
|
||||
streaming=streaming,
|
||||
# Empty **kwargs (no caller-provided kwargs) is collapsed to None so that
|
||||
# continuation calls without explicit kwargs preserve the original run's kwargs.
|
||||
# A non-empty kwargs dict (even one with empty values like {"key": {}})
|
||||
# is passed through and will overwrite stored kwargs.
|
||||
run_kwargs=kwargs if kwargs else None,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
if event.type == "output" and not self._should_yield_output_event(event):
|
||||
continue
|
||||
@@ -624,7 +642,7 @@ class Workflow(DictConvertible):
|
||||
@staticmethod
|
||||
def _validate_run_params(
|
||||
message: Any | None,
|
||||
responses: dict[str, Any] | None,
|
||||
responses: Mapping[str, Any] | None,
|
||||
checkpoint_id: str | None,
|
||||
) -> None:
|
||||
"""Validate parameter combinations for run().
|
||||
@@ -650,7 +668,7 @@ class Workflow(DictConvertible):
|
||||
def _resolve_execution_mode(
|
||||
self,
|
||||
message: Any | None,
|
||||
responses: dict[str, Any] | None,
|
||||
responses: Mapping[str, Any] | None,
|
||||
checkpoint_id: str | None,
|
||||
checkpoint_storage: CheckpointStorage | None,
|
||||
) -> tuple[Callable[[], Awaitable[None]], bool]:
|
||||
@@ -680,7 +698,7 @@ class Workflow(DictConvertible):
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
checkpoint_storage: CheckpointStorage | None,
|
||||
responses: dict[str, Any],
|
||||
responses: Mapping[str, Any],
|
||||
) -> None:
|
||||
"""Restore from a checkpoint then send responses to pending requests.
|
||||
|
||||
@@ -700,7 +718,7 @@ class Workflow(DictConvertible):
|
||||
await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage)
|
||||
await self._send_responses_internal(responses)
|
||||
|
||||
async def _send_responses_internal(self, responses: dict[str, Any]) -> None:
|
||||
async def _send_responses_internal(self, responses: Mapping[str, Any]) -> None:
|
||||
"""Internal method to validate and send responses to the executors."""
|
||||
pending_requests = await self._runner_context.get_pending_request_info_events()
|
||||
if not pending_requests:
|
||||
@@ -739,6 +757,44 @@ class Workflow(DictConvertible):
|
||||
raise ValueError(f"Executor with ID {executor_id} not found.")
|
||||
return self.executors[executor_id]
|
||||
|
||||
def _resolve_invocation_kwargs(
|
||||
self,
|
||||
kwargs: Mapping[str, Any],
|
||||
param_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Resolve invocation kwargs into a normalized per-executor or global format.
|
||||
|
||||
Detects whether the provided kwargs dict uses per-executor targeting by checking
|
||||
if any top-level key matches a known executor ID in the workflow. If at least one
|
||||
key matches, all entries are treated as per-executor. Otherwise the dict is treated
|
||||
as global kwargs that apply to every executor.
|
||||
|
||||
Args:
|
||||
kwargs: The raw invocation kwargs from the caller.
|
||||
param_name: The parameter name (for logging), e.g. ``"function_invocation_kwargs"``.
|
||||
|
||||
Returns:
|
||||
A dict with either:
|
||||
- ``{"__global__": <original dict>}`` for global kwargs, or
|
||||
- The original dict unchanged for per-executor kwargs.
|
||||
"""
|
||||
executor_ids = set(self.executors.keys())
|
||||
matched_ids = kwargs.keys() & executor_ids
|
||||
if matched_ids:
|
||||
logger.info(
|
||||
"Detected per-executor %s: executor ID(s) %s found in keys. "
|
||||
"All entries will be treated as per-executor.",
|
||||
param_name,
|
||||
matched_ids,
|
||||
)
|
||||
return dict(kwargs)
|
||||
|
||||
logger.info(
|
||||
"No executor IDs found in %s keys; treating as global kwargs for all executors.",
|
||||
param_name,
|
||||
)
|
||||
return {GLOBAL_KWARGS_KEY: dict(kwargs)}
|
||||
|
||||
def _should_yield_output_event(self, event: WorkflowEvent[Any]) -> bool:
|
||||
"""Determine if an output event should be yielded as a workflow output.
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
||||
from ._workflow import Workflow
|
||||
|
||||
from ._checkpoint_encoding import decode_checkpoint_value
|
||||
from ._const import WORKFLOW_RUN_KWARGS_KEY
|
||||
from ._const import GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY
|
||||
from ._events import (
|
||||
WorkflowEvent,
|
||||
WorkflowRunState,
|
||||
@@ -387,8 +387,28 @@ class WorkflowExecutor(Executor):
|
||||
# Get kwargs from parent workflow's State to propagate to subworkflow
|
||||
parent_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})
|
||||
|
||||
# Extract invocation kwargs recognised by Workflow.run()
|
||||
# The state stores resolved format (with __global__ wrapper for global kwargs).
|
||||
# Unwrap __global__ before passing to the subworkflow so it gets re-resolved
|
||||
# against the subworkflow's own executor IDs.
|
||||
fi_kwargs: dict[str, Any] | None = None
|
||||
ci_kwargs: dict[str, Any] | None = None
|
||||
for key in ("function_invocation_kwargs", "client_kwargs"):
|
||||
resolved = parent_kwargs.get(key)
|
||||
if isinstance(resolved, dict):
|
||||
# Unwrap global sentinel; pass per-executor dicts as-is
|
||||
unwrapped: dict[str, Any] = resolved.get(GLOBAL_KWARGS_KEY, resolved) # type: ignore
|
||||
if key == "function_invocation_kwargs":
|
||||
fi_kwargs = unwrapped # type: ignore
|
||||
else:
|
||||
ci_kwargs = unwrapped # type: ignore
|
||||
|
||||
# Run the sub-workflow and collect all events, passing parent kwargs
|
||||
result = await self.workflow.run(input_data, **parent_kwargs)
|
||||
result = await self.workflow.run(
|
||||
input_data,
|
||||
function_invocation_kwargs=fi_kwargs, # type: ignore
|
||||
client_kwargs=ci_kwargs, # type: ignore
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"WorkflowExecutor {self.id} sub-workflow {self.workflow.id} "
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterable, Awaitable
|
||||
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -22,9 +21,7 @@ from agent_framework import (
|
||||
)
|
||||
from agent_framework._workflows._agent_executor import AgentExecutorResponse
|
||||
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _pytest.logging import LogCaptureFixture
|
||||
from agent_framework._workflows._const import GLOBAL_KWARGS_KEY
|
||||
|
||||
|
||||
class _CountingAgent(BaseAgent):
|
||||
@@ -309,87 +306,28 @@ async def test_agent_executor_save_and_restore_state_directly() -> None:
|
||||
assert restored_session.session_id == session.session_id
|
||||
|
||||
|
||||
async def test_agent_executor_run_with_session_kwarg_does_not_raise() -> None:
|
||||
"""Passing session= via workflow.run() should not cause a duplicate-keyword TypeError (#4295)."""
|
||||
agent = _CountingAgent(id="session_kwarg_agent", name="SessionKwargAgent")
|
||||
executor = AgentExecutor(agent, id="session_kwarg_exec")
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
async def test_prepare_agent_run_args_extracts_invocation_kwargs() -> None:
|
||||
"""_prepare_agent_run_args extracts function_invocation_kwargs and client_kwargs."""
|
||||
agent = _CountingAgent(id="test_agent", name="TestAgent")
|
||||
executor = AgentExecutor(agent, id="test_exec")
|
||||
|
||||
# This previously raised: TypeError: run() got multiple values for keyword argument 'session'
|
||||
result = await workflow.run("hello", session="user-supplied-value")
|
||||
assert result is not None
|
||||
assert agent.call_count == 1
|
||||
|
||||
|
||||
async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -> None:
|
||||
"""Passing stream= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
|
||||
agent = _CountingAgent(id="stream_kwarg_agent", name="StreamKwargAgent")
|
||||
executor = AgentExecutor(agent, id="stream_kwarg_exec")
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
# stream=True at workflow level triggers streaming mode (returns async iterable)
|
||||
events: list[WorkflowEvent] = []
|
||||
async for event in workflow.run("hello", stream=True):
|
||||
events.append(event)
|
||||
assert len(events) > 0
|
||||
assert agent.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"])
|
||||
async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str, caplog: "LogCaptureFixture") -> None:
|
||||
"""_prepare_agent_run_args must remove reserved kwargs and log a warning."""
|
||||
raw: dict[str, Any] = {
|
||||
reserved_kwarg: "should-be-stripped",
|
||||
"custom_key": "keep-me",
|
||||
"function_invocation_kwargs": {"__global__": {"key": "fi_val"}},
|
||||
"client_kwargs": {"__global__": {"key": "ci_val"}},
|
||||
}
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert reserved_kwarg not in run_kwargs
|
||||
assert "custom_key" in run_kwargs
|
||||
assert options is not None
|
||||
assert options["additional_function_arguments"]["custom_key"] == "keep-me"
|
||||
assert any(reserved_kwarg in record.message for record in caplog.records)
|
||||
fi_kwargs, ci_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
|
||||
assert fi_kwargs == {"key": "fi_val"}
|
||||
assert ci_kwargs == {"key": "ci_val"}
|
||||
|
||||
|
||||
async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None:
|
||||
"""Non-reserved workflow kwargs should pass through unchanged."""
|
||||
raw: dict[str, Any] = {"custom_param": "value", "another": 42}
|
||||
run_kwargs, _options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
|
||||
assert run_kwargs["custom_param"] == "value"
|
||||
assert run_kwargs["another"] == 42
|
||||
async def test_prepare_agent_run_args_returns_none_when_no_kwargs() -> None:
|
||||
"""_prepare_agent_run_args returns None for both when raw dict has no invocation kwargs."""
|
||||
agent = _CountingAgent(id="test_agent", name="TestAgent")
|
||||
executor = AgentExecutor(agent, id="test_exec")
|
||||
|
||||
|
||||
async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once(
|
||||
caplog: "LogCaptureFixture",
|
||||
) -> None:
|
||||
"""All reserved kwargs should be stripped when supplied together, each emitting a warning."""
|
||||
raw: dict[str, Any] = {"session": "x", "stream": True, "messages": [], "custom": 1}
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert "session" not in run_kwargs
|
||||
assert "stream" not in run_kwargs
|
||||
assert "messages" not in run_kwargs
|
||||
assert run_kwargs["custom"] == 1
|
||||
assert options is not None
|
||||
assert options["additional_function_arguments"]["custom"] == 1
|
||||
|
||||
warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()}
|
||||
assert warned_keys == {"session", "stream", "messages"}
|
||||
|
||||
|
||||
async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None:
|
||||
"""Passing messages= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
|
||||
agent = _CountingAgent(id="messages_kwarg_agent", name="MessagesKwargAgent")
|
||||
executor = AgentExecutor(agent, id="messages_kwarg_exec")
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
result = await workflow.run("hello", messages=["stale"])
|
||||
assert result is not None
|
||||
assert agent.call_count == 1
|
||||
fi_kwargs, ci_kwargs = executor._prepare_agent_run_args({}) # pyright: ignore[reportPrivateUsage]
|
||||
assert fi_kwargs is None
|
||||
assert ci_kwargs is None
|
||||
|
||||
|
||||
class _NonCopyableRaw:
|
||||
@@ -638,3 +576,126 @@ async def test_checkpoint_restore_works_without_context_mode_in_state() -> None:
|
||||
assert cache[0].text == "cached msg"
|
||||
# context_mode should remain as configured in the constructor, not changed by restore
|
||||
assert executor._context_mode == "last_agent" # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-executor kwargs resolution tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_resolve_executor_kwargs_returns_global_kwargs() -> None:
|
||||
"""_resolve_executor_kwargs with the global kwargs key returns the global kwargs."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, id="exec_a")
|
||||
|
||||
resolved = {GLOBAL_KWARGS_KEY: {"tool_param": "value"}}
|
||||
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
|
||||
assert result == {"tool_param": "value"}
|
||||
|
||||
|
||||
async def test_resolve_executor_kwargs_returns_per_executor_kwargs() -> None:
|
||||
"""_resolve_executor_kwargs with matching executor ID returns that executor's kwargs."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, id="exec_a")
|
||||
|
||||
resolved = {"exec_a": {"my_param": 42}, "exec_b": {"other_param": 99}}
|
||||
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
|
||||
assert result == {"my_param": 42}
|
||||
|
||||
|
||||
async def test_resolve_executor_kwargs_returns_none_for_unmatched_per_executor() -> None:
|
||||
"""_resolve_executor_kwargs returns None when per-executor dict has no matching ID."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, id="exec_c")
|
||||
|
||||
resolved = {"exec_a": {"my_param": 42}, "exec_b": {"other_param": 99}}
|
||||
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_resolve_executor_kwargs_returns_none_for_none_input() -> None:
|
||||
"""_resolve_executor_kwargs returns None when input is None."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, id="exec_a")
|
||||
|
||||
result = executor._resolve_executor_kwargs(None) # pyright: ignore[reportPrivateUsage]
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_resolve_executor_kwargs_prefers_executor_id_over_global() -> None:
|
||||
"""_resolve_executor_kwargs prefers executor-specific entry over __global__."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, id="exec_a")
|
||||
|
||||
# Dict has both a per-executor entry and a global entry
|
||||
resolved = {"exec_a": {"specific": True}, GLOBAL_KWARGS_KEY: {"global": True}}
|
||||
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
|
||||
assert result == {"specific": True}
|
||||
|
||||
|
||||
async def test_prepare_agent_run_args_extracts_function_invocation_kwargs() -> None:
|
||||
"""_prepare_agent_run_args extracts function_invocation_kwargs from the state dict."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, id="exec_a")
|
||||
|
||||
raw: dict[str, Any] = {
|
||||
"function_invocation_kwargs": {GLOBAL_KWARGS_KEY: {"tool_key": "tool_val"}},
|
||||
}
|
||||
fi_kwargs, client_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
|
||||
assert fi_kwargs == {"tool_key": "tool_val"}
|
||||
assert client_kwargs is None
|
||||
|
||||
|
||||
async def test_prepare_agent_run_args_extracts_client_kwargs() -> None:
|
||||
"""_prepare_agent_run_args extracts client_kwargs from the state dict."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, id="exec_a")
|
||||
|
||||
raw: dict[str, Any] = {
|
||||
"client_kwargs": {GLOBAL_KWARGS_KEY: {"model": "gpt-4"}},
|
||||
}
|
||||
fi_kwargs, client_kwargs = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
|
||||
assert fi_kwargs is None
|
||||
assert client_kwargs == {"model": "gpt-4"}
|
||||
|
||||
|
||||
async def test_prepare_agent_run_args_per_executor_resolution() -> None:
|
||||
"""_prepare_agent_run_args resolves per-executor function_invocation_kwargs using self.id."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, id="exec_a")
|
||||
|
||||
raw: dict[str, Any] = {
|
||||
"function_invocation_kwargs": {
|
||||
"exec_a": {"my_tool_key": "my_val"},
|
||||
"exec_b": {"other_tool_key": "other_val"},
|
||||
},
|
||||
}
|
||||
fi_kwargs, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
|
||||
assert fi_kwargs == {"my_tool_key": "my_val"}
|
||||
|
||||
|
||||
async def test_prepare_agent_run_args_per_executor_no_match() -> None:
|
||||
"""_prepare_agent_run_args returns None for function_invocation_kwargs when executor ID not found."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, id="exec_c")
|
||||
|
||||
raw: dict[str, Any] = {
|
||||
"function_invocation_kwargs": {
|
||||
"exec_a": {"my_tool_key": "my_val"},
|
||||
"exec_b": {"other_tool_key": "other_val"},
|
||||
},
|
||||
}
|
||||
fi_kwargs, _ = executor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
|
||||
assert fi_kwargs is None
|
||||
|
||||
|
||||
async def test_resolve_executor_kwargs_empty_per_executor_does_not_fallback_to_global() -> None:
|
||||
"""An explicit empty per-executor dict should not fall through to global kwargs."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, id="exec_a")
|
||||
|
||||
# Per-executor entry for exec_a is empty, but global has values.
|
||||
# The empty dict should be honoured (no fallback to global).
|
||||
resolved = {"exec_a": {}, GLOBAL_KWARGS_KEY: {"global_key": "global_val"}}
|
||||
result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage]
|
||||
assert result == {}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,148 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Annotated, Any, cast
|
||||
|
||||
from agent_framework import Agent, Message, tool
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
from azure.identity import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
"""
|
||||
Sample: Workflow kwargs Flow to @tool Tools
|
||||
|
||||
This sample demonstrates how to flow custom context (skill data, user tokens, etc.)
|
||||
through any workflow pattern to @tool functions using the **kwargs pattern.
|
||||
|
||||
Key Concepts:
|
||||
- Pass custom context as kwargs when invoking workflow.run()
|
||||
- kwargs are stored in State and passed to all agent invocations
|
||||
- @tool functions receive kwargs via **kwargs parameter
|
||||
- Works with Sequential, Concurrent, GroupChat, Handoff, and Magentic patterns
|
||||
|
||||
Prerequisites:
|
||||
- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint.
|
||||
- FOUNDRY_MODEL must be set to your Azure OpenAI model deployment name.
|
||||
"""
|
||||
|
||||
|
||||
# Define tools that accept custom context via **kwargs
|
||||
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
|
||||
# see samples/02-agents/tools/function_tool_with_approval.py
|
||||
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
|
||||
@tool(approval_mode="never_require")
|
||||
def get_user_data(
|
||||
query: Annotated[str, Field(description="What user data to retrieve")],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Retrieve user-specific data based on the authenticated context."""
|
||||
user_token = kwargs.get("user_token", {})
|
||||
user_name = user_token.get("user_name", "anonymous")
|
||||
access_level = user_token.get("access_level", "none")
|
||||
|
||||
print(f"\n[get_user_data] Received kwargs keys: {list(kwargs.keys())}")
|
||||
print(f"[get_user_data] User: {user_name}")
|
||||
print(f"[get_user_data] Access level: {access_level}")
|
||||
|
||||
return f"Retrieved data for user {user_name} with {access_level} access: {query}"
|
||||
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
def call_api(
|
||||
endpoint_name: Annotated[str, Field(description="Name of the API endpoint to call")],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call an API using the configured endpoints from custom_data."""
|
||||
custom_data = kwargs.get("custom_data", {})
|
||||
api_config = custom_data.get("api_config", {})
|
||||
|
||||
base_url = api_config.get("base_url", "unknown")
|
||||
endpoints = api_config.get("endpoints", {})
|
||||
|
||||
print(f"\n[call_api] Received kwargs keys: {list(kwargs.keys())}")
|
||||
print(f"[call_api] Base URL: {base_url}")
|
||||
print(f"[call_api] Available endpoints: {list(endpoints.keys())}")
|
||||
|
||||
if endpoint_name in endpoints:
|
||||
return f"Called {base_url}{endpoints[endpoint_name]} successfully"
|
||||
return f"Endpoint '{endpoint_name}' not found in configuration"
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("=" * 70)
|
||||
print("Workflow kwargs Flow Demo (SequentialBuilder)")
|
||||
print("=" * 70)
|
||||
|
||||
# Create chat client
|
||||
client = FoundryChatClient(
|
||||
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
|
||||
model=os.environ["FOUNDRY_MODEL"],
|
||||
credential=AzureCliCredential(),
|
||||
)
|
||||
|
||||
# Create agent with tools that use kwargs
|
||||
agent = Agent(
|
||||
client=client,
|
||||
name="assistant",
|
||||
instructions=(
|
||||
"You are a helpful assistant. Use the available tools to help users. "
|
||||
"When asked about user data, use get_user_data. "
|
||||
"When asked to call an API, use call_api."
|
||||
),
|
||||
tools=[get_user_data, call_api],
|
||||
)
|
||||
|
||||
# Build a simple sequential workflow
|
||||
workflow = SequentialBuilder(participants=[agent]).build()
|
||||
|
||||
# Define custom context that will flow to tools via kwargs
|
||||
custom_data = {
|
||||
"api_config": {
|
||||
"base_url": "https://api.example.com",
|
||||
"endpoints": {
|
||||
"users": "/v1/users",
|
||||
"orders": "/v1/orders",
|
||||
"products": "/v1/products",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
user_token = {
|
||||
"user_name": "bob@contoso.com",
|
||||
"access_level": "admin",
|
||||
}
|
||||
|
||||
print("\nCustom Data being passed:")
|
||||
print(json.dumps(custom_data, indent=2))
|
||||
print(f"\nUser: {user_token['user_name']}")
|
||||
print("\n" + "-" * 70)
|
||||
print("Workflow Execution (watch for [tool_name] logs showing kwargs received):")
|
||||
print("-" * 70)
|
||||
|
||||
# Run workflow with kwargs - these will flow through to tools
|
||||
async for event in workflow.run(
|
||||
"Please get my user data and then call the users API endpoint.",
|
||||
additional_function_arguments={"custom_data": custom_data, "user_token": user_token},
|
||||
stream=True,
|
||||
):
|
||||
if event.type == "output":
|
||||
output_data = cast(list[Message], event.data)
|
||||
if isinstance(output_data, list):
|
||||
for item in output_data:
|
||||
if isinstance(item, Message) and item.text:
|
||||
print(f"\n[Final Answer]: {item.text}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Sample Complete")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,170 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Annotated, Any, cast
|
||||
|
||||
from agent_framework import Agent, Message, tool
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
from azure.identity import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
"""
|
||||
Sample: Global Workflow kwargs
|
||||
|
||||
This sample demonstrates how to pass the same kwargs to every agent in a
|
||||
workflow using global targeting. When keys in function_invocation_kwargs do NOT
|
||||
match any executor ID (agent name), the framework treats them as global and
|
||||
delivers them to all agents.
|
||||
|
||||
Compare with workflow_kwargs_per_agent.py which targets kwargs to specific agents.
|
||||
|
||||
Key Concepts:
|
||||
- Global function_invocation_kwargs are delivered to every agent in the workflow
|
||||
- Useful when all agents share the same credentials, config, or context
|
||||
- @tool functions receive kwargs via the **kwargs parameter
|
||||
|
||||
Prerequisites:
|
||||
- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint.
|
||||
- Environment variables configured
|
||||
"""
|
||||
|
||||
|
||||
# 1. Define a tool for the research agent — queries a company's internal
|
||||
# database using credentials passed via global kwargs.
|
||||
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
|
||||
# see samples/02-agents/tools/function_tool_with_approval.py
|
||||
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
|
||||
@tool(approval_mode="never_require")
|
||||
def query_company_database(
|
||||
query: Annotated[
|
||||
str, Field(description="The database query to run, e.g. 'Q3 revenue' or 'headcount by department'")
|
||||
],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Query the company's internal database for business metrics and data."""
|
||||
db_config = kwargs.get("db_config", {})
|
||||
connection_string = db_config.get("connection_string", "")
|
||||
database = db_config.get("database", "")
|
||||
|
||||
if not connection_string or not database:
|
||||
return f"ERROR: missing db_config — cannot run query '{query}'"
|
||||
|
||||
print(f"\n [query_company_database] Connecting to {database} at {connection_string[:30]}...")
|
||||
|
||||
# Simulated company data that the LLM would not know on its own
|
||||
return (
|
||||
f"Query results from {database}:\n"
|
||||
f"- Contoso Q3 2025 revenue: $47.2M (up 12% YoY)\n"
|
||||
f"- Top product line: CloudSync Pro ($18.6M)\n"
|
||||
f"- Engineering headcount: 342 (up from 298 in Q2)\n"
|
||||
f"- Customer churn rate: 4.1% (down from 5.3% in Q2)\n"
|
||||
f"- Net new enterprise customers: 28"
|
||||
)
|
||||
|
||||
|
||||
# 2. Define a tool for the writer agent — retrieves the formatting style
|
||||
# from user preferences passed via global kwargs.
|
||||
@tool(approval_mode="never_require")
|
||||
def get_formatting_instructions(
|
||||
section_title: Annotated[str, Field(description="The title of the section or report to format")],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Get the formatting instructions based on user preferences."""
|
||||
user_prefs = kwargs.get("user_preferences", {})
|
||||
output_format = user_prefs.get("format", "plain")
|
||||
language = user_prefs.get("language", "en")
|
||||
|
||||
print(f"\n [get_formatting_instructions] Format: {output_format}, Language: {language}")
|
||||
|
||||
return (
|
||||
f"Formatting rules for '{section_title}':\n"
|
||||
f"- Output format: {output_format}\n"
|
||||
f"- Language/locale: {language}\n"
|
||||
f"- Include a footer: 'Generated in {output_format} for locale {language}'"
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("=" * 70)
|
||||
print("Global Workflow kwargs Demo")
|
||||
print("=" * 70)
|
||||
|
||||
# 3. Create a shared chat client.
|
||||
client = FoundryChatClient(
|
||||
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
|
||||
model=os.environ["FOUNDRY_MODEL"],
|
||||
credential=AzureCliCredential(),
|
||||
)
|
||||
|
||||
# 4. Create two agents with different tools and responsibilities.
|
||||
researcher = Agent(
|
||||
client=client,
|
||||
name="researcher",
|
||||
instructions=(
|
||||
"You are a data analyst. Call query_company_database exactly once "
|
||||
"with the user's request as the query. Return the raw results."
|
||||
),
|
||||
tools=[query_company_database],
|
||||
)
|
||||
|
||||
writer = Agent(
|
||||
client=client,
|
||||
name="writer",
|
||||
instructions=(
|
||||
"You are a report writer. Call get_formatting_instructions exactly once, "
|
||||
"then rewrite the data you receive into a polished report following those rules."
|
||||
),
|
||||
tools=[get_formatting_instructions],
|
||||
)
|
||||
|
||||
# 5. Build a sequential workflow: researcher -> writer.
|
||||
workflow = SequentialBuilder(participants=[researcher, writer]).build()
|
||||
|
||||
# 6. Define global kwargs — every agent receives all of these.
|
||||
# Because the keys ("db_config", "user_preferences") do NOT match any
|
||||
# executor ID ("researcher", "writer"), the framework treats them as
|
||||
# global and delivers the full dict to every agent.
|
||||
global_fi_kwargs = {
|
||||
"db_config": {
|
||||
"connection_string": "Server=contoso-sql.database.windows.net;Database=metrics",
|
||||
"database": "contoso_metrics_prod",
|
||||
},
|
||||
"user_preferences": {
|
||||
"format": "markdown",
|
||||
"language": "en-US",
|
||||
},
|
||||
}
|
||||
|
||||
print("\nGlobal function_invocation_kwargs (sent to all agents):")
|
||||
print(json.dumps(global_fi_kwargs, indent=2))
|
||||
print("\n" + "-" * 70)
|
||||
print("Workflow Execution:")
|
||||
print("-" * 70)
|
||||
|
||||
# 7. Run the workflow — every agent receives the same global kwargs.
|
||||
async for event in workflow.run(
|
||||
"Pull Contoso's Q3 2025 performance data and write an executive summary.",
|
||||
function_invocation_kwargs=global_fi_kwargs,
|
||||
stream=True,
|
||||
):
|
||||
if event.type == "output":
|
||||
output_data = cast(list[Message], event.data)
|
||||
if isinstance(output_data, list):
|
||||
for item in output_data:
|
||||
if isinstance(item, Message) and item.text:
|
||||
print(f"\n[{item.author_name}]: {item.text}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Sample Complete")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,222 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Annotated, Any, cast
|
||||
|
||||
from agent_framework import Agent, Message, tool
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
from azure.identity import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
"""
|
||||
Sample: Per-Agent Workflow kwargs
|
||||
|
||||
This sample demonstrates how to pass different kwargs to different agents in a
|
||||
workflow using per-agent targeting. When keys in function_invocation_kwargs (or
|
||||
client_kwargs) match executor IDs (agent names by default), each agent
|
||||
receives only its own slice of the kwargs.
|
||||
|
||||
Key Concepts:
|
||||
- Per-agent function_invocation_kwargs target specific agents by executor ID
|
||||
- Agents only receive the kwargs assigned to them (not other agents' kwargs)
|
||||
- Useful when different agents need different credentials, configs, or context
|
||||
|
||||
Prerequisites:
|
||||
- FOUNDRY_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint.
|
||||
- Environment variables configured
|
||||
"""
|
||||
|
||||
|
||||
# 1. Define a tool for the research agent — queries a company's internal
|
||||
# database using credentials passed via per-agent kwargs.
|
||||
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
|
||||
# see samples/02-agents/tools/function_tool_with_approval.py
|
||||
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
|
||||
@tool(approval_mode="never_require")
|
||||
def query_company_database(
|
||||
query: Annotated[
|
||||
str, Field(description="The database query to run, e.g. 'Q3 revenue' or 'headcount by department'")
|
||||
],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Query the company's internal database for business metrics and data."""
|
||||
db_config = kwargs.get("db_config", {})
|
||||
connection_string = db_config.get("connection_string", "")
|
||||
database = db_config.get("database", "")
|
||||
|
||||
if not connection_string or not database:
|
||||
return f"ERROR: missing db_config — cannot run query '{query}'"
|
||||
|
||||
print(f"\n [query_company_database] Connecting to {database} at {connection_string[:30]}...")
|
||||
|
||||
# Simulated company data that the LLM would not know on its own
|
||||
return (
|
||||
f"Query results from {database}:\n"
|
||||
f"- Contoso Q3 2025 revenue: $47.2M (up 12% YoY)\n"
|
||||
f"- Top product line: CloudSync Pro ($18.6M)\n"
|
||||
f"- Engineering headcount: 342 (up from 298 in Q2)\n"
|
||||
f"- Customer churn rate: 4.1% (down from 5.3% in Q2)\n"
|
||||
f"- Net new enterprise customers: 28"
|
||||
)
|
||||
|
||||
|
||||
# 2. Define a tool for the writer agent — retrieves the formatting style
|
||||
# from user preferences passed via per-agent kwargs.
|
||||
@tool(approval_mode="never_require")
|
||||
def get_formatting_instructions(
|
||||
section_title: Annotated[str, Field(description="The title of the section or report to format")],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Get the formatting instructions based on user preferences."""
|
||||
user_prefs = kwargs.get("user_preferences", {})
|
||||
output_format = user_prefs.get("format", "plain")
|
||||
language = user_prefs.get("language", "en")
|
||||
|
||||
print(f"\n [get_formatting_instructions] Format: {output_format}, Language: {language}")
|
||||
|
||||
return (
|
||||
f"Formatting rules for '{section_title}':\n"
|
||||
f"- Output format: {output_format}\n"
|
||||
f"- Language/locale: {language}\n"
|
||||
f"- Include a footer: 'Generated in {output_format} for locale {language}'"
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("=" * 70)
|
||||
print("Per-Agent Workflow kwargs Demo")
|
||||
print("=" * 70)
|
||||
|
||||
# 3. Create a shared chat client.
|
||||
client = FoundryChatClient(
|
||||
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
|
||||
model=os.environ["FOUNDRY_MODEL"],
|
||||
credential=AzureCliCredential(),
|
||||
)
|
||||
|
||||
# 4. Create two agents with different tools and responsibilities.
|
||||
researcher = Agent(
|
||||
client=client,
|
||||
name="researcher",
|
||||
instructions=(
|
||||
"You are a data analyst. Call query_company_database exactly once "
|
||||
"with the user's request as the query. Return the raw results."
|
||||
),
|
||||
tools=[query_company_database],
|
||||
)
|
||||
|
||||
writer = Agent(
|
||||
client=client,
|
||||
name="writer",
|
||||
instructions=(
|
||||
"You are a report writer. Call get_formatting_instructions exactly once, "
|
||||
"then rewrite the data you receive into a polished report following those rules."
|
||||
),
|
||||
tools=[get_formatting_instructions],
|
||||
)
|
||||
|
||||
# 5. Build a sequential workflow: researcher -> writer.
|
||||
workflow = SequentialBuilder(participants=[researcher, writer]).build()
|
||||
|
||||
# 6. Define per-agent kwargs — each agent gets only its own config.
|
||||
# The keys ("researcher", "writer") match the agent names, which are
|
||||
# used as executor IDs by default.
|
||||
per_agent_fi_kwargs = {
|
||||
"researcher": {
|
||||
"db_config": {
|
||||
"connection_string": "Server=contoso-sql.database.windows.net;Database=metrics",
|
||||
"database": "contoso_metrics_prod",
|
||||
},
|
||||
},
|
||||
"writer": {
|
||||
"user_preferences": {
|
||||
"format": "markdown",
|
||||
"language": "en-US",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
print("\nPer-agent function_invocation_kwargs:")
|
||||
print(json.dumps(per_agent_fi_kwargs, indent=2))
|
||||
print("\n" + "-" * 70)
|
||||
print("Workflow Execution:")
|
||||
print("-" * 70)
|
||||
|
||||
# 7. Run the workflow — each agent receives only its targeted kwargs.
|
||||
async for event in workflow.run(
|
||||
"Pull Contoso's Q3 2025 performance data and write an executive summary.",
|
||||
function_invocation_kwargs=per_agent_fi_kwargs,
|
||||
stream=True,
|
||||
):
|
||||
if event.type == "output":
|
||||
output_data = cast(list[Message], event.data)
|
||||
if isinstance(output_data, list):
|
||||
for item in output_data:
|
||||
if isinstance(item, Message) and item.text:
|
||||
print(f"\n[{item.author_name}]: {item.text}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Sample Complete")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
|
||||
Per-agent function_invocation_kwargs:
|
||||
{
|
||||
"researcher": {
|
||||
"db_config": {
|
||||
"connection_string": "Server=contoso-sql.database.windows.net;Database=metrics",
|
||||
"database": "contoso_metrics_prod"
|
||||
}
|
||||
},
|
||||
"writer": {
|
||||
"user_preferences": {
|
||||
"format": "markdown",
|
||||
"language": "en-US"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
----------------------------------------------------------------------
|
||||
Workflow Execution:
|
||||
----------------------------------------------------------------------
|
||||
|
||||
[query_company_database] Connecting to contoso_metrics_prod at Server=contoso-sql.database.wi...
|
||||
|
||||
[researcher]: Here is Contoso's Q3 2025 data:
|
||||
- Revenue: $47.2M (up 12% YoY)
|
||||
- Top product: CloudSync Pro ($18.6M)
|
||||
- Engineering headcount: 342
|
||||
- Churn rate: 4.1%
|
||||
- Net new enterprise customers: 28
|
||||
|
||||
[get_formatting_instructions] Format: markdown, Language: en-US
|
||||
|
||||
[writer]: # Contoso Q3 2025 Executive Summary
|
||||
|
||||
| Metric | Value |
|
||||
|---|---|
|
||||
| Revenue | $47.2M (+12% YoY) |
|
||||
| Top Product | CloudSync Pro ($18.6M) |
|
||||
| Engineering Headcount | 342 |
|
||||
| Customer Churn | 4.1% |
|
||||
| New Enterprise Customers | 28 |
|
||||
|
||||
Generated in markdown for locale en-US
|
||||
|
||||
======================================================================
|
||||
Sample Complete
|
||||
======================================================================
|
||||
"""
|
||||
Reference in New Issue
Block a user