Merge branch 'main' into local-branch-fix-workflow-as-agent-pending-request-handling

This commit is contained in:
Tao Chen
2026-06-01 22:15:14 -07:00
Unverified
116 changed files with 9188 additions and 4649 deletions
@@ -71,6 +71,7 @@ from ._evaluation import (
Evaluator,
ExpectedToolCall,
LocalEvaluator,
RubricScore,
evaluate_agent,
evaluate_workflow,
evaluator,
@@ -460,6 +461,7 @@ __all__ = [
"ResponseStream",
"Role",
"RoleLiteral",
"RubricScore",
"RunContext",
"Runner",
"RunnerContext",
@@ -311,12 +311,15 @@ class EvalScoreResult:
score: Numeric score from the evaluator.
passed: Whether the item passed this evaluator's threshold.
sample: Optional raw evaluator output (rationale, metadata).
dimensions: Per-dimension scores when this evaluator is a rubric
evaluator. ``None`` for non-rubric (e.g. built-in) evaluators.
"""
name: str
score: float
passed: bool | None = None
sample: dict[str, Any] | None = None
dimensions: list[RubricScore] | None = None
@experimental(feature_id=ExperimentalFeature.EVALS)
@@ -496,6 +499,179 @@ class EvalResults:
detail += f" Errored items: {', '.join(summaries)}."
raise EvalNotPassedError(detail)
def assert_score_at_least(
self,
min_score: float,
*,
evaluator: str | None = None,
msg: str | None = None,
) -> None:
"""Assert every item's score (optionally filtered by evaluator) is ``>= min_score``.
Designed for CI gates on generated rubric evaluators (e.g.
``results.assert_score_at_least(0.80)``). Includes any
sub-results from workflow evaluations.
Args:
min_score: Minimum acceptable score (inclusive).
evaluator: When set, only check scores from the evaluator
whose ``EvalScoreResult.name`` matches.
msg: Optional custom failure message.
Raises:
EvalNotPassedError: When any matching score is below the threshold.
"""
offenders: list[str] = []
def _check(results: EvalResults) -> None:
for item in results.items:
for score in item.scores:
if evaluator is not None and score.name != evaluator:
continue
if score.score < min_score:
offenders.append(f"{item.item_id}/{score.name}={score.score:.3f}")
for sub in results.sub_results.values():
_check(sub)
_check(self)
if offenders:
detail = msg or (
f"{len(offenders)} score(s) below threshold {min_score}"
f"{' for ' + evaluator if evaluator else ''}: {', '.join(offenders[:5])}"
+ (f" (+{len(offenders) - 5} more)" if len(offenders) > 5 else "")
)
raise EvalNotPassedError(detail)
def assert_dimension_score_at_least(
self,
dimension_id: str,
min_score: float,
*,
evaluator: str | None = None,
require_applicable: bool = False,
msg: str | None = None,
) -> None:
"""Assert every item's score for a rubric *dimension* is ``>= min_score``.
Walks ``EvalScoreResult.dimensions`` looking for the named
dimension across all items (and sub-results). Non-applicable
dimensions are skipped by default; pass
``require_applicable=True`` to fail when no applicable score is
produced.
Args:
dimension_id: Dimension id (matches the rubric definition).
min_score: Minimum acceptable dimension score (inclusive).
evaluator: When set, only consider scores from the evaluator
whose ``EvalScoreResult.name`` matches.
require_applicable: When ``True``, missing or non-applicable
dimension scores raise. Defaults to ``False`` (skip).
msg: Optional custom failure message.
Raises:
EvalNotPassedError: When the dimension fails the threshold.
"""
offenders: list[str] = []
missing_items: list[str] = []
def _check(results: EvalResults) -> None:
for item in results.items:
found_applicable = False
for score in item.scores:
if evaluator is not None and score.name != evaluator:
continue
if not score.dimensions:
continue
for rs in score.dimensions:
if rs.id != dimension_id:
continue
if not rs.applicable:
continue
found_applicable = True
if rs.score is None or rs.score < min_score:
offenders.append(
f"{item.item_id}/{score.name}/{dimension_id}="
f"{rs.score if rs.score is not None else 'None'}"
)
if require_applicable and not found_applicable:
missing_items.append(item.item_id)
for sub in results.sub_results.values():
_check(sub)
_check(self)
problems: list[str] = []
if offenders:
problems.append(
f"{len(offenders)} dimension score(s) for '{dimension_id}' below {min_score}: "
f"{', '.join(offenders[:5])}" + (f" (+{len(offenders) - 5} more)" if len(offenders) > 5 else "")
)
if missing_items:
problems.append(
f"Dimension '{dimension_id}' not applicable on {len(missing_items)} item(s): "
f"{', '.join(missing_items[:5])}"
)
if problems:
raise EvalNotPassedError(msg or "; ".join(problems))
def assert_no_failed_items(self, msg: str | None = None) -> None:
"""Assert no item ended in ``fail`` or ``error`` status.
Includes any sub-results from workflow evaluations.
Args:
msg: Optional custom failure message.
Raises:
EvalNotPassedError: When any item failed or errored.
"""
bad: list[str] = []
def _check(results: EvalResults) -> None:
for item in results.items:
if item.is_failed or item.is_error:
bad.append(f"{item.item_id}:{item.status}")
for sub in results.sub_results.values():
_check(sub)
_check(self)
if bad:
detail = msg or (
f"{len(bad)} item(s) failed or errored: {', '.join(bad[:5])}"
+ (f" (+{len(bad) - 5} more)" if len(bad) > 5 else "")
)
raise EvalNotPassedError(detail)
# endregion
# region Generated rubric evaluators
@experimental(feature_id=ExperimentalFeature.EVALS)
@dataclass(frozen=True)
class RubricScore:
"""A single dimension's score from a rubric-based evaluator run.
Rubric evaluators emit one ``RubricScore`` per dimension per item.
Attached to :class:`EvalScoreResult` as a typed view of the raw
``properties.rubric_scores`` payload returned by providers such as
Foundry's generated rubric evaluators.
Attributes:
id: Dimension id (matches the rubric definition).
score: Numeric score, or ``None`` when the dimension was marked
non-applicable for this item.
applicable: Whether the dimension applied to this item.
weight: Dimension weight (mirrors the rubric definition).
reason: Short rationale produced by the evaluator.
"""
id: str
score: int | None
applicable: bool
weight: int
reason: str
# endregion
@@ -14,12 +14,13 @@ import logging
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any
from .._agents import Agent
from .._agents import Agent, SupportsAgentRun
from .._clients import SupportsWebSearchTool
from .._compaction import CompactionProvider, ContextWindowCompactionStrategy, ToolResultCompactionStrategy
from .._feature_stage import ExperimentalFeature, experimental
from .._sessions import ContextProvider, HistoryProvider, InMemoryHistoryProvider
from .._skills import SkillsProvider
from ._background_agents import BackgroundAgentsProvider
from ._memory import MemoryContextProvider, MemoryStore
from ._mode import AgentModeProvider
from ._todo import TodoProvider
@@ -103,6 +104,8 @@ def _assemble_context_providers(
memory_store: MemoryStore | None,
skills_provider: SkillsProvider | None,
skills_paths: Sequence[str] | None,
background_agents: Sequence[SupportsAgentRun] | None,
background_agents_instructions: str | None,
extra_context_providers: Sequence[ContextProvider] | None,
) -> list[ContextProvider]:
"""Assemble the ordered list of context providers."""
@@ -130,6 +133,10 @@ def _assemble_context_providers(
if skills_paths:
providers.append(SkillsProvider.from_paths(*skills_paths))
# Background agents are opt-in: only added when agents are provided.
if background_agents:
providers.append(BackgroundAgentsProvider(background_agents, instructions=background_agents_instructions))
# Append any user-supplied additional providers.
if extra_context_providers:
providers.extend(extra_context_providers)
@@ -165,6 +172,8 @@ def create_harness_agent(
memory_store: MemoryStore | None = None,
skills_provider: SkillsProvider | None = None,
skills_paths: Sequence[str] | None = None,
background_agents: Sequence[SupportsAgentRun] | None = None,
background_agents_instructions: str | None = None,
disable_web_search: bool = False,
otel_provider_name: str | None = None,
context_providers: Sequence[ContextProvider] | None = None,
@@ -182,6 +191,7 @@ def create_harness_agent(
- **AgentModeProvider** — plan/execute mode tracking
- **MemoryContextProvider** — file-based durable memory (when ``memory_store`` provided)
- **SkillsProvider** — skill discovery and progressive loading
- **BackgroundAgentsProvider** — delegate work to background sub-agents
- **OpenTelemetry** — observability via ``AgentTelemetryLayer``
Each feature can be disabled or customized via keyword arguments.
@@ -253,6 +263,13 @@ def create_harness_agent(
skills_paths: Paths for file-based skill discovery (looks for SKILL.md files).
Can be combined with ``skills_provider``. When neither ``skills_provider``
nor ``skills_paths`` is provided, no SkillsProvider is added.
background_agents: Collection of agents available for background task delegation.
When provided, a ``BackgroundAgentsProvider`` is automatically included,
enabling the agent to start, monitor, and retrieve results from background tasks.
Each agent must have a non-empty, unique name (case-insensitive).
background_agents_instructions: Optional instruction override for the
``BackgroundAgentsProvider``. May include ``{background_agents}`` placeholder
which will be replaced with the agent listing.
disable_web_search: When True, skip automatic web search tool inclusion.
When False (default), the web search tool is automatically added if the
client implements SupportsWebSearchTool. A warning is logged if the client
@@ -302,6 +319,8 @@ def create_harness_agent(
memory_store=memory_store,
skills_provider=skills_provider,
skills_paths=skills_paths,
background_agents=background_agents,
background_agents_instructions=background_agents_instructions,
extra_context_providers=context_providers,
)
+72 -20
View File
@@ -10,7 +10,7 @@ import logging
import re
import sys
from abc import abstractmethod
from collections.abc import Callable, Collection, Coroutine, Sequence
from collections.abc import Callable, Collection, Coroutine, Mapping, Sequence
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
from datetime import timedelta
from functools import partial
@@ -142,6 +142,13 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str,
return meta
def _url_origin(url: Any) -> tuple[str, str, int | None]:
port = url.port
if port is None:
port = 443 if url.scheme == "https" else 80 if url.scheme == "http" else None
return (url.scheme, url.host or "", port)
def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
"""Lazily import the MCP streamable HTTP transport."""
try:
@@ -255,6 +262,7 @@ class MCPTool:
self._exit_stack = AsyncExitStack()
self._lifecycle_lock = asyncio.Lock()
self._lifecycle_request_lock = asyncio.Lock()
self._function_load_lock = asyncio.Lock()
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None
self._lifecycle_owner_task: asyncio.Task[None] | None = None
self.session = session
@@ -655,6 +663,11 @@ class MCPTool:
raise
except asyncio.CancelledError:
logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.")
except Exception as e:
if type(e).__name__ == "ExceptionGroup":
logger.warning("Could not cleanly close MCP exit stack due to cleanup error group. Error: %s", e)
else:
raise
async def _close_and_check_cancelled(self, ex: BaseException) -> bool:
"""Close the exit stack and return True if *ex* is a genuine task cancellation.
@@ -1018,6 +1031,10 @@ class MCPTool:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
async with self._function_load_lock:
await self._load_prompts_locked()
async def _load_prompts_locked(self) -> None:
from anyio import ClosedResourceError
from mcp import types
@@ -1100,6 +1117,10 @@ class MCPTool:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
async with self._function_load_lock:
await self._load_tools_locked()
async def _load_tools_locked(self) -> None:
from anyio import ClosedResourceError
from mcp import types
@@ -1109,7 +1130,7 @@ class MCPTool:
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
self._tool_call_meta_by_name.clear()
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
params: types.PaginatedRequestParams | None = None
while True:
@@ -1145,7 +1166,7 @@ class MCPTool:
for tool in tool_list.tools:
if tool.meta is not None:
self._tool_call_meta_by_name[tool.name] = dict(tool.meta)
tool_call_meta_by_name[tool.name] = dict(tool.meta)
normalized_name = _normalize_mcp_name(tool.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
@@ -1194,6 +1215,8 @@ class MCPTool:
break
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
self._tool_call_meta_by_name = tool_call_meta_by_name
async def _close_on_owner(self) -> None:
# Cancel any pending reload tasks before tearing down the session.
tasks = list(self._pending_reload_tasks)
@@ -1276,7 +1299,11 @@ class MCPTool:
tool_name: The name of the tool to call.
Keyword Args:
kwargs: Arguments to pass to the tool.
_meta: Optional ``dict[str, Any]`` of MCP request metadata. This reserved key is passed as the
``meta`` parameter of the underlying ``session.call_tool`` call rather than as a tool argument.
User-supplied keys override metadata from ``tools/list``; OpenTelemetry propagation fills in
non-conflicting keys.
kwargs: Remaining arguments to pass to the tool.
Returns:
A list of Content items representing the tool output. The default
@@ -1294,6 +1321,19 @@ class MCPTool:
raise ToolExecutionException(
"Tools are not loaded for this server, please set load_tools=True in the constructor."
)
raw_user_meta: object | None = kwargs.get("_meta")
user_meta: dict[str, Any] | None = None
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
if isinstance(raw_user_meta, dict):
raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
user_meta = {}
for key, value in raw_user_meta_dict.items():
if not isinstance(key, str):
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
user_meta[key] = value
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
# These are internal objects passed through the function invocation pipeline
# that should not be forwarded to external MCP servers.
@@ -1313,12 +1353,16 @@ class MCPTool:
"conversation_id",
"options",
"response_format",
"_meta",
}
}
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
tool_meta = self._tool_call_meta_by_name.get(tool_name)
meta = _inject_otel_into_mcp_meta(dict(tool_meta) if tool_meta is not None else None)
request_meta = dict(tool_meta) if tool_meta is not None else None
if user_meta is not None:
request_meta = {**(request_meta or {}), **user_meta}
meta = _inject_otel_into_mcp_meta(request_meta)
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
# Try the operation, reconnecting once if the connection is closed
@@ -1336,28 +1380,33 @@ class MCPTool:
return parser(result)
except ToolExecutionException:
raise
except ClosedResourceError as cl_ex:
except (ClosedResourceError, McpError) as call_ex:
is_session_terminated = (
isinstance(call_ex, McpError) and "session terminated" in call_ex.error.message.lower()
)
is_connection_lost = isinstance(call_ex, ClosedResourceError) or is_session_terminated
if not is_connection_lost:
error_message = call_ex.error.message if isinstance(call_ex, McpError) else str(call_ex)
raise ToolExecutionException(error_message, inner_exception=call_ex) from call_ex
if attempt == 0:
# First attempt failed, try reconnecting
logger.info("MCP connection closed unexpectedly. Reconnecting...")
# First attempt failed, try reconnecting.
logger.info("MCP connection closed or terminated unexpectedly. Reconnecting...")
try:
await self.connect(reset=True)
continue # Retry the operation
continue
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
else:
# Second attempt also failed, give up
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
raise ToolExecutionException(
f"Failed to call tool '{tool_name}' - connection lost.",
inner_exception=cl_ex,
) from cl_ex
except McpError as mcp_exc:
error_message = mcp_exc.error.message
raise ToolExecutionException(error_message, inner_exception=mcp_exc) from mcp_exc
# Second attempt also failed, give up.
logger.error("MCP connection closed unexpectedly after reconnection: %s", call_ex)
raise ToolExecutionException(
f"Failed to call tool '{tool_name}' - connection lost.",
inner_exception=call_ex,
) from call_ex
except Exception as ex:
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
@@ -1718,10 +1767,11 @@ class MCPStreamableHTTPTool(MCPTool):
Returns:
An async context manager for the streamable HTTP client transport.
"""
from httpx import AsyncClient, Request, Timeout
from httpx import URL, AsyncClient, Request, Timeout
http_client = self._httpx_client
if self._header_provider is not None:
target_origin = _url_origin(URL(self.url))
if http_client is None:
http_client = AsyncClient(
follow_redirects=True,
@@ -1732,6 +1782,8 @@ class MCPStreamableHTTPTool(MCPTool):
if not hasattr(self, "_inject_headers_hook"):
async def _inject_headers(request: Request) -> None: # noqa: RUF029
if _url_origin(request.url) != target_origin:
return
headers = _mcp_call_headers.get({})
for key, value in headers.items():
request.headers[key] = value
@@ -36,11 +36,10 @@ if TYPE_CHECKING:
from pydantic import BaseModel
from ._agents import SupportsAgentRun
from ._clients import SupportsChatGetResponse
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._sessions import AgentSession
from ._tools import FunctionTool, ToolTypes
from ._types import ChatOptions, ChatResponse, ChatResponseUpdate
from ._types import ChatOptions
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
@@ -7,6 +7,8 @@ import json
import logging
import re
from collections.abc import Mapping, MutableMapping
from dataclasses import asdict, is_dataclass
from datetime import date, datetime
from typing import Any, ClassVar, Protocol, TypeVar, runtime_checkable
logger = logging.getLogger("agent_framework")
@@ -614,3 +616,46 @@ class SerializationMixin:
# Fallback and default
# Convert class name to snake_case
return _CAMEL_TO_SNAKE_PATTERN.sub("_", cls.__name__).lower()
def make_json_safe(obj: Any) -> Any:
"""Recursively convert an object to a JSON-serializable form.
Handles dataclasses, Pydantic models, objects with ``to_dict``/``dict``/``__dict__``,
datetimes, lists, dicts, and primitives. Falls back to ``str()`` for any remaining
non-serializable value so that ``json.dumps`` never raises a ``TypeError``.
Args:
obj: Object to make JSON safe.
Returns:
A JSON-serializable version of the object.
"""
if obj is None or isinstance(obj, (str, int, float, bool)):
return obj
if isinstance(obj, (datetime, date)):
return obj.isoformat()
if is_dataclass(obj) and not isinstance(obj, type):
return make_json_safe(asdict(obj)) # type: ignore[arg-type]
if callable(getattr(obj, "model_dump", None)):
try:
return make_json_safe(obj.model_dump()) # type: ignore[no-any-return]
except TypeError:
pass
if callable(getattr(obj, "to_dict", None)):
try:
return make_json_safe(obj.to_dict()) # type: ignore[no-any-return]
except TypeError:
pass
if callable(getattr(obj, "dict", None)):
try:
return make_json_safe(obj.dict()) # type: ignore[no-any-return]
except TypeError:
pass
if isinstance(obj, dict):
return {str(key): make_json_safe(value) for key, value in obj.items()} # type: ignore[misc]
if isinstance(obj, (list, tuple)):
return [make_json_safe(item) for item in obj] # type: ignore[misc]
if hasattr(obj, "__dict__"):
return {key: make_json_safe(value) for key, value in vars(obj).items()} # type: ignore[misc]
return str(obj)
@@ -1973,11 +1973,97 @@ def _coalesce_text_content(contents: list[Content], type_str: Literal["text", "t
contents.extend(coalesced_contents)
def _content_items_text(items: Any) -> str | None:
"""Return concatenated text when a content item list only contains text."""
if not isinstance(items, list):
return None
text_parts: list[str] = []
content_items = cast(list[object], items)
for item in content_items:
if not isinstance(item, Content) or item.type != "text":
return None
text_parts.append(item.text or "")
return "".join(text_parts)
def _merge_content_item_lists(existing: Any, incoming: Any) -> Any:
"""Merge streamed nested content lists, replacing deltas with a later full value when present."""
if incoming is None:
return existing
if existing is None:
return deepcopy(incoming)
existing_text = _content_items_text(existing)
incoming_text = _content_items_text(incoming)
if existing_text is not None and incoming_text is not None:
if incoming_text.startswith(existing_text):
return deepcopy(incoming)
if existing_text.startswith(incoming_text):
return existing
existing_items = cast(list[Content], existing)
merged = deepcopy(existing_items[0])
merged.text = existing_text + incoming_text
return [merged]
if isinstance(existing, list) and isinstance(incoming, list):
existing_list = cast(list[object], existing)
incoming_list = cast(list[object], incoming)
return [*existing_list, *deepcopy(incoming_list)]
return deepcopy(incoming)
def _merge_code_interpreter_content(existing: Content, incoming: Content) -> None:
"""Merge two code interpreter content items for the same logical call."""
existing.inputs = _merge_content_item_lists(existing.inputs, incoming.inputs)
existing.outputs = _merge_content_item_lists(existing.outputs, incoming.outputs)
existing.annotations = _combine_annotations(existing.annotations, incoming.annotations)
existing.additional_properties = {**existing.additional_properties, **incoming.additional_properties}
existing.raw_representation = _combine_raw_representations(existing.raw_representation, incoming.raw_representation)
def _code_interpreter_key(content: Content) -> tuple[str, str] | None:
"""Return the aggregation key for code interpreter call/result content."""
if content.type not in {"code_interpreter_tool_call", "code_interpreter_tool_result"}:
return None
call_id = content.call_id or content.additional_properties.get("item_id")
if not isinstance(call_id, str) or not call_id:
return None
return content.type, call_id
def _coalesce_code_interpreter_content(contents: list[Content]) -> None:
"""Coalesce streaming code interpreter chunks by call id."""
if not contents:
return
coalesced_contents: list[Content] = []
seen: dict[tuple[str, str], Content] = {}
for content in contents:
key = _code_interpreter_key(content)
if key is None:
coalesced_contents.append(content)
continue
existing = seen.get(key)
if existing is None:
copied = deepcopy(content)
seen[key] = copied
coalesced_contents.append(copied)
continue
_merge_code_interpreter_content(existing, content)
contents.clear()
contents.extend(coalesced_contents)
def _finalize_response(response: ChatResponse | AgentResponse) -> None:
"""Finalizes the response by performing any necessary post-processing."""
for msg in response.messages:
_coalesce_text_content(msg.contents, "text")
_coalesce_text_content(msg.contents, "text_reasoning")
_coalesce_code_interpreter_content(msg.contents)
# region ContinuationToken
@@ -12,6 +12,7 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
from .._agents import BaseAgent
from .._serialization import make_json_safe
from .._sessions import (
AgentSession,
ContextProvider,
@@ -62,7 +63,7 @@ class WorkflowAgent(BaseAgent):
data: Any
def to_dict(self) -> dict[str, Any]:
return {"request_id": self.request_id, "data": self.data}
return {"request_id": self.request_id, "data": make_json_safe(self.data)}
def to_json(self) -> str:
return json.dumps(self.to_dict())
@@ -47,6 +47,7 @@ from copy import deepcopy
from typing import Any, Generic, Literal, TypeVar, overload
from .._feature_stage import ExperimentalFeature, experimental
from .._serialization import make_json_safe
from .._types import AgentResponse, AgentResponseUpdate, ResponseStream
from ..observability import OtelAttr, capture_exception, create_workflow_span
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
@@ -1515,7 +1516,7 @@ class FunctionalWorkflowAgent:
function_call = Content.from_function_call(
call_id=request_id,
name=self.REQUEST_INFO_FUNCTION_NAME,
arguments={"request_id": request_id, "data": event.data},
arguments={"request_id": request_id, "data": make_json_safe(event.data)},
)
return Content.from_function_approval_request(
id=request_id,
@@ -34,6 +34,7 @@ _IMPORTS: dict[str, tuple[str, str]] = {
"FoundryLocalChatOptions": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
"FoundryLocalClient": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
"FoundryLocalSettings": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
"GeneratedEvaluatorRef": ("agent_framework_foundry", "agent-framework-foundry"),
"RawAnthropicFoundryClient": ("agent_framework_anthropic", "agent-framework-anthropic"),
"RawFoundryAgent": ("agent_framework_foundry", "agent-framework-foundry"),
"RawFoundryAgentChatClient": ("agent_framework_foundry", "agent-framework-foundry"),
@@ -20,6 +20,7 @@ from agent_framework_foundry import (
FoundryEmbeddingSettings,
FoundryEvals,
FoundryMemoryProvider,
GeneratedEvaluatorRef,
RawFoundryAgent,
RawFoundryAgentChatClient,
RawFoundryChatClient,
@@ -52,6 +53,7 @@ __all__ = [
"FoundryLocalClient",
"FoundryLocalSettings",
"FoundryMemoryProvider",
"GeneratedEvaluatorRef",
"RawAnthropicFoundryClient",
"RawFoundryAgent",
"RawFoundryAgentChatClient",
@@ -394,3 +394,94 @@ def test_create_harness_agent_logs_warning_when_no_web_search(caplog: pytest.Log
max_output_tokens=16_384,
)
assert any("SupportsWebSearchTool" in msg for msg in caplog.messages)
# --- Background Agents Tests ---
class _FakeBackgroundAgent:
"""Minimal agent stub satisfying SupportsAgentRun for background agents tests."""
def __init__(self, name: str, description: str | None = None):
self.id = f"agent-{name}"
self.name = name
self.description = description
def create_session(self, *, session_id: str | None = None) -> AgentSession:
return AgentSession(session_id=session_id)
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
return AgentSession(service_session_id=service_session_id, session_id=session_id)
async def run(self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any) -> Any:
from agent_framework import AgentResponse
return AgentResponse(messages=[], response_id="fake-bg-response")
def test_create_harness_agent_no_background_agents_by_default() -> None:
"""No BackgroundAgentsProvider should be included when background_agents is not provided."""
from agent_framework._harness._background_agents import BackgroundAgentsProvider
agent = create_harness_agent(
client=_FakeChatClient(), # type: ignore[arg-type]
max_context_window_tokens=128_000,
max_output_tokens=16_384,
disable_web_search=True,
)
providers = agent.context_providers or []
assert not any(isinstance(p, BackgroundAgentsProvider) for p in providers)
def test_create_harness_agent_adds_background_agents_provider() -> None:
"""BackgroundAgentsProvider should be included when background_agents are provided."""
from agent_framework._harness._background_agents import BackgroundAgentsProvider
bg_agent = _FakeBackgroundAgent("WebSearcher", "Searches the web")
agent = create_harness_agent(
client=_FakeChatClient(), # type: ignore[arg-type]
max_context_window_tokens=128_000,
max_output_tokens=16_384,
disable_web_search=True,
background_agents=[bg_agent],
)
providers = agent.context_providers or []
bg_providers = [p for p in providers if isinstance(p, BackgroundAgentsProvider)]
assert len(bg_providers) == 1
def test_create_harness_agent_background_agents_custom_instructions() -> None:
"""Custom instructions should be passed to BackgroundAgentsProvider."""
from agent_framework._harness._background_agents import BackgroundAgentsProvider
custom_instructions = "## Custom\n\nUse agents wisely.\n\n{background_agents}"
bg_agent = _FakeBackgroundAgent("Helper", "A helper agent")
agent = create_harness_agent(
client=_FakeChatClient(), # type: ignore[arg-type]
max_context_window_tokens=128_000,
max_output_tokens=16_384,
disable_web_search=True,
background_agents=[bg_agent],
background_agents_instructions=custom_instructions,
)
providers = agent.context_providers or []
bg_providers = [p for p in providers if isinstance(p, BackgroundAgentsProvider)]
assert len(bg_providers) == 1
# Verify the custom instructions were used (placeholder replaced with agent list).
assert "Custom" in bg_providers[0]._instructions
assert "Helper" in bg_providers[0]._instructions
def test_create_harness_agent_empty_background_agents_list() -> None:
"""An empty background_agents list should NOT add a BackgroundAgentsProvider."""
from agent_framework._harness._background_agents import BackgroundAgentsProvider
agent = create_harness_agent(
client=_FakeChatClient(), # type: ignore[arg-type]
max_context_window_tokens=128_000,
max_output_tokens=16_384,
disable_web_search=True,
background_agents=[],
)
providers = agent.context_providers or []
assert not any(isinstance(p, BackgroundAgentsProvider) for p in providers)
@@ -11,8 +11,13 @@ import pytest
from agent_framework._evaluation import (
CheckResult,
EvalItem,
EvalItemResult,
EvalNotPassedError,
EvalResults,
EvalScoreResult,
ExpectedToolCall,
LocalEvaluator,
RubricScore,
_coerce_result,
evaluator,
keyword_check,
@@ -1010,19 +1015,101 @@ class TestAllPassedSubResults:
# ---------------------------------------------------------------------------
# r5 review: _build_overall_item with empty outputs
# Rubric assertions (EvalResults.assert_*)
# ---------------------------------------------------------------------------
class TestBuildOverallItemEmpty:
"""Test _build_overall_item returns None for empty workflow outputs."""
def _rubric_results(*scores_per_item: list[EvalScoreResult]) -> EvalResults:
items = [
EvalItemResult(item_id=f"item-{i}", status="pass", scores=scores) for i, scores in enumerate(scores_per_item)
]
return EvalResults(
provider="test",
eval_id="ev1",
run_id="run1",
result_counts={"passed": len(items), "failed": 0, "errored": 0, "total": len(items)},
items=items,
)
def test_returns_none_for_empty_outputs(self):
from unittest.mock import MagicMock
from agent_framework._evaluation import _build_overall_item
class TestRubricAssertions:
"""Tests for EvalResults.assert_dimension_score_at_least."""
mock_result = MagicMock()
mock_result.get_outputs.return_value = []
item = _build_overall_item("Hello", mock_result)
assert item is None
def test_dimension_at_or_above_threshold_passes(self) -> None:
results = _rubric_results(
[
EvalScoreResult(
name="policy",
score=0.9,
dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
)
],
)
# Should not raise.
results.assert_dimension_score_at_least("clarity", 3)
def test_dimension_below_threshold_raises(self) -> None:
results = _rubric_results(
[
EvalScoreResult(
name="policy",
score=0.5,
dimensions=[RubricScore(id="clarity", score=2, applicable=True, weight=1, reason="")],
)
],
)
with pytest.raises(EvalNotPassedError):
results.assert_dimension_score_at_least("clarity", 3)
def test_non_applicable_skipped_by_default(self) -> None:
results = _rubric_results(
[
EvalScoreResult(
name="policy",
score=1.0,
dimensions=[RubricScore(id="clarity", score=None, applicable=False, weight=1, reason="n/a")],
)
],
)
# No applicable scores; default behaviour is to skip silently.
results.assert_dimension_score_at_least("clarity", 3)
def test_require_applicable_raises_when_dimension_absent(self) -> None:
results = _rubric_results(
[EvalScoreResult(name="policy", score=1.0, dimensions=[])],
)
with pytest.raises(EvalNotPassedError, match="not applicable"):
results.assert_dimension_score_at_least("clarity", 3, require_applicable=True)
def test_require_applicable_raises_when_filtered_evaluator_missing(self) -> None:
# Regression: previously the (not evaluator or found_any) guard caused
# this case to silently pass even with require_applicable=True.
results = _rubric_results(
[
EvalScoreResult(
name="other",
score=0.9,
dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
)
],
)
with pytest.raises(EvalNotPassedError, match="not applicable"):
results.assert_dimension_score_at_least("clarity", 3, evaluator="policy", require_applicable=True)
def test_evaluator_filter_isolates_offenders(self) -> None:
results = _rubric_results(
[
EvalScoreResult(
name="other",
score=0.1,
dimensions=[RubricScore(id="clarity", score=1, applicable=True, weight=1, reason="")],
),
EvalScoreResult(
name="policy",
score=0.9,
dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
),
],
)
# The low-scoring "other" evaluator is filtered out; "policy" passes.
results.assert_dimension_score_at_least("clarity", 3, evaluator="policy")
+206
View File
@@ -1161,6 +1161,43 @@ async def test_local_mcp_server_function_execution_error():
await func.invoke(param="test_value")
async def test_mcp_tool_reconnects_after_session_terminated_error():
"""Session termination errors should reconnect once and retry the tool call."""
class TestServer(MCPTool):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.connect_count = 0
self.sessions: list[Any] = []
async def connect(self, *, reset: bool = False) -> None:
self.connect_count += 1
self.session = Mock(spec=ClientSession)
self.sessions.append(self.session)
if self.connect_count == 1:
self.session.call_tool = AsyncMock(
side_effect=McpError(types.ErrorData(code=-32000, message="Session terminated"))
)
else:
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="recovered")])
)
self.is_connected = True
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None
server = TestServer(name="test_server")
await server.connect()
result = await server.call_tool("test_tool", param="test_value")
assert _mcp_result_to_text(result) == "recovered"
assert server.connect_count == 2
assert server.sessions[0].call_tool.await_count == 1
assert server.sessions[1].call_tool.await_count == 1
async def test_mcp_tool_call_tool_raises_on_is_error():
"""Test that call_tool raises ToolExecutionException when MCP returns isError=True."""
@@ -3260,6 +3297,68 @@ async def test_load_prompts_pagination_with_duplicates():
assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2"]
async def test_load_tools_concurrent_reload_does_not_duplicate_tools_and_preserves_meta():
"""Concurrent tool reloads should not duplicate functions or lose tools/list metadata."""
tool = MCPTool(name="test_tool")
mock_session = AsyncMock()
tool.session = mock_session
tool.load_tools_flag = True
page = Mock()
page.tools = [
types.Tool(
name="tool_1",
description="First tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
_meta={"echo": "tool_1"},
),
]
page.nextCursor = None
async def mock_list_tools(params: Any = None) -> Any:
assert params is None
await asyncio.sleep(0)
return page
mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)
await asyncio.wait_for(asyncio.gather(tool.load_tools(), tool.load_tools()), timeout=1)
assert mock_session.list_tools.call_count == 2
assert [f.name for f in tool._functions] == ["tool_1"]
assert tool._tool_call_meta_by_name == {"tool_1": {"echo": "tool_1"}}
async def test_load_prompts_concurrent_reload_does_not_duplicate_prompts():
"""Concurrent prompt reloads should not duplicate functions."""
tool = MCPTool(name="test_tool")
mock_session = AsyncMock()
tool.session = mock_session
tool.load_prompts_flag = True
page = Mock()
page.prompts = [
types.Prompt(
name="prompt_1",
description="First prompt",
arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)],
),
]
page.nextCursor = None
async def mock_list_prompts(params: Any = None) -> Any:
assert params is None
await asyncio.sleep(0)
return page
mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts)
await asyncio.wait_for(asyncio.gather(tool.load_prompts(), tool.load_prompts()), timeout=1)
assert mock_session.list_prompts.call_count == 2
assert [f.name for f in tool._functions] == ["prompt_1"]
async def test_load_tools_pagination_exception_handling():
"""Test that load_tools handles exceptions during pagination gracefully."""
from unittest.mock import AsyncMock
@@ -3891,6 +3990,31 @@ async def test_mcp_tool_safe_close_handles_cancelled_error():
mock_exit_stack.aclose.assert_called_once()
async def test_mcp_tool_safe_close_handles_cleanup_exception_group():
"""Cleanup task groups should not hide the original connect failure."""
import builtins
from contextlib import AsyncExitStack
exception_group_type = getattr(builtins, "ExceptionGroup", None)
if exception_group_type is None:
pytest.skip("ExceptionGroup is not available on this Python version")
tool = MCPStreamableHTTPTool(
name="test",
url="http://example.com/mcp",
load_tools=False,
load_prompts=False,
)
mock_exit_stack = AsyncMock(spec=AsyncExitStack)
mock_exit_stack.aclose = AsyncMock(side_effect=exception_group_type("cleanup failed", [RuntimeError("reader")]))
tool._exit_stack = mock_exit_stack
await tool._safe_close_exit_stack()
mock_exit_stack.aclose.assert_called_once()
async def test_connect_sets_logging_level_when_logger_level_is_set():
"""Test that connect() sets the MCP server logging level when the logger level is not NOTSET."""
@@ -4389,6 +4513,52 @@ async def test_mcp_tool_call_tool_forwards_tool_list_meta():
assert server.session.call_tool.call_args.kwargs["meta"] == tool_meta
async def test_mcp_tool_call_tool_user_meta_merges_with_tool_list_meta():
"""User-provided _meta should be sent as MCP request metadata, not tool arguments."""
from opentelemetry import trace
tool_meta = {"from_tool": "tool-value", "shared": "tool-value"}
user_meta = {"from_user": "user-value", "shared": "user-value"}
class TestServer(MCPTool):
async def connect(self) -> None:
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="test_tool",
description="Test tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
_meta=tool_meta,
)
]
)
)
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="result")])
)
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None
server = TestServer(name="test_server")
async with server:
await server.load_tools()
with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)):
await server.call_tool("test_tool", param="test_value", _meta=user_meta)
call_kwargs = server.session.call_tool.call_args.kwargs
assert call_kwargs["arguments"] == {"param": "test_value"}
assert call_kwargs["meta"] == {
"from_tool": "tool-value",
"from_user": "user-value",
"shared": "user-value",
}
assert user_meta == {"from_user": "user-value", "shared": "user-value"}
async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client():
"""Test that calling get_mcp_client multiple times does not accumulate duplicate hooks."""
tool = MCPStreamableHTTPTool(
@@ -4641,6 +4811,42 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook():
await tool._httpx_client.aclose()
async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redirect():
"""The request hook must not re-add caller headers after a cross-origin redirect."""
import httpx
from agent_framework._mcp import _mcp_call_headers
tool = MCPStreamableHTTPTool(
name="test",
url="http://example.com/mcp",
header_provider=lambda kw: {"Authorization": f"Bearer {kw.get('token', '')}"},
)
try:
with patch("agent_framework._mcp.streamable_http_client"):
tool.get_mcp_client()
assert tool._httpx_client is not None
hooks = tool._httpx_client.event_hooks.get("request", [])
assert len(hooks) == 1
token = _mcp_call_headers.set({"Authorization": "Bearer secret"})
try:
same_origin = httpx.Request("POST", "http://example.com/redirected")
await hooks[0](same_origin)
assert same_origin.headers.get("Authorization") == "Bearer secret"
cross_origin = httpx.Request("POST", "http://attacker.example/capture")
await hooks[0](cross_origin)
assert "Authorization" not in cross_origin.headers
finally:
_mcp_call_headers.reset(token)
finally:
if getattr(tool, "_httpx_client", None) is not None:
await tool._httpx_client.aclose()
async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client():
"""Test that header_provider works when the user provides their own httpx client."""
import httpx
@@ -1691,6 +1691,65 @@ def test_to_otel_part_function_call():
}
def test_to_otel_part_function_call_reuses_prepared_arguments():
"""Test _to_otel_part does not re-serialize function-call arguments in the observability hot path."""
from agent_framework import Content
from agent_framework.observability import _to_otel_part
arguments = {"payload": object()}
content = Content(type="function_call", call_id="call_789", name="handoff", arguments=arguments)
result = _to_otel_part(content)
assert result is not None
assert result["arguments"] is arguments
def test_make_json_safe_non_callable_method_attribute():
"""Test make_json_safe handles objects where model_dump/to_dict/dict are non-callable attributes."""
from agent_framework._serialization import make_json_safe
class ObjWithNonCallableModelDump:
model_dump = 42 # not callable
obj = ObjWithNonCallableModelDump()
result = make_json_safe(obj)
assert result == {}
def test_make_json_safe_callable_method_type_error_falls_through():
"""Test make_json_safe falls through when serializer-like methods require arguments."""
from agent_framework._serialization import make_json_safe
class ObjWithRequiredArgModelDump:
def __init__(self) -> None:
self.value = "fallback"
def model_dump(self, required: str) -> dict[str, str]:
return {"required": required}
obj = ObjWithRequiredArgModelDump()
result = make_json_safe(obj)
assert result == {"value": "fallback"}
def test_make_json_safe_dict_with_non_string_keys():
"""Test make_json_safe converts non-primitive dict keys to strings."""
import json
from datetime import datetime
from agent_framework._serialization import make_json_safe
dt_key = datetime(2024, 1, 1)
obj = {dt_key: "value", 42: "num_value", "str_key": "normal"}
result = make_json_safe(obj)
# json.dumps must not raise TypeError
serialized = json.dumps(result)
parsed = json.loads(serialized)
assert parsed[str(dt_key)] == "value"
assert parsed["42"] == "num_value"
assert parsed["str_key"] == "normal"
def test_to_otel_part_function_result():
"""Test _to_otel_part with function_result content."""
from agent_framework import Content
@@ -3019,6 +3078,49 @@ async def test_system_instructions_preserves_non_ascii_characters(span_exporter:
assert [msg.get("role") for msg in input_messages] == ["user"]
@pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True)
def test_capture_messages_with_prepared_request_info_function_call_arguments(span_exporter: InMemorySpanExporter):
"""Test _capture_messages handles request-info function-call arguments prepared at Content creation."""
import dataclasses
import json
from opentelemetry import trace
from agent_framework import WorkflowAgent
@dataclasses.dataclass
class HandoffRequest:
target_agent: str
reason: str
arguments = WorkflowAgent.RequestInfoFunctionArgs(
request_id="call_dc",
data=HandoffRequest(target_agent="helper", reason="overflow"),
).to_dict()
msg = Message(
role="assistant",
contents=[
Content(
type="function_call",
call_id="call_dc",
name="request_info",
arguments=arguments,
)
],
)
span_exporter.clear()
tracer = trace.get_tracer("test")
with tracer.start_as_current_span("test_span") as span:
_capture_messages(span=span, provider_name="test_provider", messages=[msg])
spans = span_exporter.get_finished_spans()
span = spans[0]
input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES])
tool_part = input_messages[0]["parts"][0]
assert tool_part["type"] == "tool_call"
assert tool_part["arguments"]["data"] == {"target_agent": "helper", "reason": "overflow"}
def test_capture_messages_keeps_framework_instructions_out_of_logs_and_span_messages(
span_exporter: InMemorySpanExporter,
):
@@ -307,6 +307,63 @@ class TestHistoryProviderBase:
assert provider.stored[0].text == "hello"
assert provider.stored[1].text == "hi"
async def test_after_run_stores_coalesced_code_interpreter_chunks(self) -> None:
from agent_framework import AgentResponse, AgentResponseUpdate, Content
provider = ConcreteHistoryProvider("mem", store_inputs=False)
updates = [
AgentResponseUpdate(
role="assistant",
contents=[
Content.from_code_interpreter_tool_result(
call_id="ci_123",
outputs=[],
)
],
),
AgentResponseUpdate(
contents=[
Content.from_code_interpreter_tool_call(
call_id="ci_123",
inputs=[Content.from_text(text="import")],
additional_properties={"sequence_number": 1},
)
],
),
AgentResponseUpdate(
contents=[
Content.from_code_interpreter_tool_call(
call_id="ci_123",
inputs=[Content.from_text(text=" pandas")],
additional_properties={"sequence_number": 2},
)
],
),
AgentResponseUpdate(
contents=[
Content.from_code_interpreter_tool_call(
call_id="ci_123",
inputs=[Content.from_text(text="import pandas as pd")],
additional_properties={"sequence_number": 3},
)
],
),
]
ctx = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["make a sheet"])])
ctx._response = AgentResponse.from_updates(updates)
await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type]
assert len(provider.stored) == 1
stored_contents = provider.stored[0].contents
calls = [content for content in stored_contents if content.type == "code_interpreter_tool_call"]
results = [content for content in stored_contents if content.type == "code_interpreter_tool_result"]
assert len(calls) == 1
assert len(results) == 1
assert calls[0].inputs is not None
assert len(calls[0].inputs) == 1
assert calls[0].inputs[0].text == "import pandas as pd"
async def test_after_run_skips_inputs_when_disabled(self) -> None:
from agent_framework import AgentResponse
@@ -5,6 +5,7 @@
from __future__ import annotations
import asyncio
import json
import logging
from collections.abc import Iterator
from contextlib import contextmanager
@@ -1642,6 +1643,37 @@ class TestFunctionalWorkflowAgentHITL:
break
assert approval_found, "expected FunctionApprovalRequestContent in agent response"
async def test_request_info_dataclass_arguments_are_serialized_for_agent(self):
@dataclass
class HandoffRequest:
target_agent: str
reason: str
@workflow
async def wf(x: str, ctx: RunContext) -> str:
answer = await ctx.request_info(
HandoffRequest(target_agent=x, reason="overflow"),
response_type=str,
request_id="rid-1",
)
return f"got:{answer}"
agent = wf.as_agent()
response = await agent.run("helper")
function_call_arguments = None
for message in response.messages:
for content in message.contents:
if getattr(content, "type", None) == "function_approval_request" and content.function_call is not None:
function_call_arguments = content.function_call.arguments
break
assert function_call_arguments == {
"request_id": "rid-1",
"data": {"target_agent": "helper", "reason": "overflow"},
}
assert json.loads(json.dumps(function_call_arguments)) == function_call_arguments
async def test_resume_via_agent_responses_kwarg(self):
@workflow
async def wf(x: str, ctx: RunContext) -> str:
@@ -1,7 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
import json
import uuid
from collections.abc import Awaitable, Sequence
from dataclasses import dataclass
from typing import Any, Literal, overload
import pytest
@@ -23,6 +25,7 @@ from agent_framework import (
WorkflowAgent,
WorkflowBuilder,
WorkflowContext,
WorkflowEvent,
executor,
handler,
response_handler,
@@ -292,6 +295,33 @@ class TestWorkflowAgent:
pending_requests = await workflow._runner_context.get_pending_request_info_events()
assert len(pending_requests) == 0
def test_request_info_dataclass_arguments_are_serialized_when_content_is_created(self) -> None:
"""Test WorkflowAgent prepares request_info arguments before observability captures messages."""
@dataclass
class HandoffRequest:
target_agent: str
reason: str
executor = SimpleExecutor(id="executor1", response_text="Response")
workflow = WorkflowBuilder(start_executor=executor).build()
agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
event = WorkflowEvent.request_info(
request_id="request_123",
source_executor_id="executor1",
request_data=HandoffRequest(target_agent="helper", reason="overflow"),
response_type=str,
)
function_call, approval_request = agent._process_request_info_event(event) # pyright: ignore[reportPrivateUsage]
assert function_call.arguments == {
"request_id": "request_123",
"data": {"target_agent": "helper", "reason": "overflow"},
}
assert approval_request.function_call is function_call
assert json.loads(json.dumps(function_call.arguments)) == function_call.arguments
def test_workflow_as_agent_method(self) -> None:
"""Test that Workflow.as_agent() creates a properly configured WorkflowAgent."""
# Create a simple workflow