mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Merge branch 'main' into local-branch-fix-workflow-as-agent-pending-request-handling
This commit is contained in:
@@ -71,6 +71,7 @@ from ._evaluation import (
|
||||
Evaluator,
|
||||
ExpectedToolCall,
|
||||
LocalEvaluator,
|
||||
RubricScore,
|
||||
evaluate_agent,
|
||||
evaluate_workflow,
|
||||
evaluator,
|
||||
@@ -460,6 +461,7 @@ __all__ = [
|
||||
"ResponseStream",
|
||||
"Role",
|
||||
"RoleLiteral",
|
||||
"RubricScore",
|
||||
"RunContext",
|
||||
"Runner",
|
||||
"RunnerContext",
|
||||
|
||||
@@ -311,12 +311,15 @@ class EvalScoreResult:
|
||||
score: Numeric score from the evaluator.
|
||||
passed: Whether the item passed this evaluator's threshold.
|
||||
sample: Optional raw evaluator output (rationale, metadata).
|
||||
dimensions: Per-dimension scores when this evaluator is a rubric
|
||||
evaluator. ``None`` for non-rubric (e.g. built-in) evaluators.
|
||||
"""
|
||||
|
||||
name: str
|
||||
score: float
|
||||
passed: bool | None = None
|
||||
sample: dict[str, Any] | None = None
|
||||
dimensions: list[RubricScore] | None = None
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.EVALS)
|
||||
@@ -496,6 +499,179 @@ class EvalResults:
|
||||
detail += f" Errored items: {', '.join(summaries)}."
|
||||
raise EvalNotPassedError(detail)
|
||||
|
||||
def assert_score_at_least(
|
||||
self,
|
||||
min_score: float,
|
||||
*,
|
||||
evaluator: str | None = None,
|
||||
msg: str | None = None,
|
||||
) -> None:
|
||||
"""Assert every item's score (optionally filtered by evaluator) is ``>= min_score``.
|
||||
|
||||
Designed for CI gates on generated rubric evaluators (e.g.
|
||||
``results.assert_score_at_least(0.80)``). Includes any
|
||||
sub-results from workflow evaluations.
|
||||
|
||||
Args:
|
||||
min_score: Minimum acceptable score (inclusive).
|
||||
evaluator: When set, only check scores from the evaluator
|
||||
whose ``EvalScoreResult.name`` matches.
|
||||
msg: Optional custom failure message.
|
||||
|
||||
Raises:
|
||||
EvalNotPassedError: When any matching score is below the threshold.
|
||||
"""
|
||||
offenders: list[str] = []
|
||||
|
||||
def _check(results: EvalResults) -> None:
|
||||
for item in results.items:
|
||||
for score in item.scores:
|
||||
if evaluator is not None and score.name != evaluator:
|
||||
continue
|
||||
if score.score < min_score:
|
||||
offenders.append(f"{item.item_id}/{score.name}={score.score:.3f}")
|
||||
for sub in results.sub_results.values():
|
||||
_check(sub)
|
||||
|
||||
_check(self)
|
||||
if offenders:
|
||||
detail = msg or (
|
||||
f"{len(offenders)} score(s) below threshold {min_score}"
|
||||
f"{' for ' + evaluator if evaluator else ''}: {', '.join(offenders[:5])}"
|
||||
+ (f" (+{len(offenders) - 5} more)" if len(offenders) > 5 else "")
|
||||
)
|
||||
raise EvalNotPassedError(detail)
|
||||
|
||||
def assert_dimension_score_at_least(
|
||||
self,
|
||||
dimension_id: str,
|
||||
min_score: float,
|
||||
*,
|
||||
evaluator: str | None = None,
|
||||
require_applicable: bool = False,
|
||||
msg: str | None = None,
|
||||
) -> None:
|
||||
"""Assert every item's score for a rubric *dimension* is ``>= min_score``.
|
||||
|
||||
Walks ``EvalScoreResult.dimensions`` looking for the named
|
||||
dimension across all items (and sub-results). Non-applicable
|
||||
dimensions are skipped by default; pass
|
||||
``require_applicable=True`` to fail when no applicable score is
|
||||
produced.
|
||||
|
||||
Args:
|
||||
dimension_id: Dimension id (matches the rubric definition).
|
||||
min_score: Minimum acceptable dimension score (inclusive).
|
||||
evaluator: When set, only consider scores from the evaluator
|
||||
whose ``EvalScoreResult.name`` matches.
|
||||
require_applicable: When ``True``, missing or non-applicable
|
||||
dimension scores raise. Defaults to ``False`` (skip).
|
||||
msg: Optional custom failure message.
|
||||
|
||||
Raises:
|
||||
EvalNotPassedError: When the dimension fails the threshold.
|
||||
"""
|
||||
offenders: list[str] = []
|
||||
missing_items: list[str] = []
|
||||
|
||||
def _check(results: EvalResults) -> None:
|
||||
for item in results.items:
|
||||
found_applicable = False
|
||||
for score in item.scores:
|
||||
if evaluator is not None and score.name != evaluator:
|
||||
continue
|
||||
if not score.dimensions:
|
||||
continue
|
||||
for rs in score.dimensions:
|
||||
if rs.id != dimension_id:
|
||||
continue
|
||||
if not rs.applicable:
|
||||
continue
|
||||
found_applicable = True
|
||||
if rs.score is None or rs.score < min_score:
|
||||
offenders.append(
|
||||
f"{item.item_id}/{score.name}/{dimension_id}="
|
||||
f"{rs.score if rs.score is not None else 'None'}"
|
||||
)
|
||||
if require_applicable and not found_applicable:
|
||||
missing_items.append(item.item_id)
|
||||
for sub in results.sub_results.values():
|
||||
_check(sub)
|
||||
|
||||
_check(self)
|
||||
problems: list[str] = []
|
||||
if offenders:
|
||||
problems.append(
|
||||
f"{len(offenders)} dimension score(s) for '{dimension_id}' below {min_score}: "
|
||||
f"{', '.join(offenders[:5])}" + (f" (+{len(offenders) - 5} more)" if len(offenders) > 5 else "")
|
||||
)
|
||||
if missing_items:
|
||||
problems.append(
|
||||
f"Dimension '{dimension_id}' not applicable on {len(missing_items)} item(s): "
|
||||
f"{', '.join(missing_items[:5])}"
|
||||
)
|
||||
if problems:
|
||||
raise EvalNotPassedError(msg or "; ".join(problems))
|
||||
|
||||
def assert_no_failed_items(self, msg: str | None = None) -> None:
|
||||
"""Assert no item ended in ``fail`` or ``error`` status.
|
||||
|
||||
Includes any sub-results from workflow evaluations.
|
||||
|
||||
Args:
|
||||
msg: Optional custom failure message.
|
||||
|
||||
Raises:
|
||||
EvalNotPassedError: When any item failed or errored.
|
||||
"""
|
||||
bad: list[str] = []
|
||||
|
||||
def _check(results: EvalResults) -> None:
|
||||
for item in results.items:
|
||||
if item.is_failed or item.is_error:
|
||||
bad.append(f"{item.item_id}:{item.status}")
|
||||
for sub in results.sub_results.values():
|
||||
_check(sub)
|
||||
|
||||
_check(self)
|
||||
if bad:
|
||||
detail = msg or (
|
||||
f"{len(bad)} item(s) failed or errored: {', '.join(bad[:5])}"
|
||||
+ (f" (+{len(bad) - 5} more)" if len(bad) > 5 else "")
|
||||
)
|
||||
raise EvalNotPassedError(detail)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Generated rubric evaluators
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.EVALS)
|
||||
@dataclass(frozen=True)
|
||||
class RubricScore:
|
||||
"""A single dimension's score from a rubric-based evaluator run.
|
||||
|
||||
Rubric evaluators emit one ``RubricScore`` per dimension per item.
|
||||
Attached to :class:`EvalScoreResult` as a typed view of the raw
|
||||
``properties.rubric_scores`` payload returned by providers such as
|
||||
Foundry's generated rubric evaluators.
|
||||
|
||||
Attributes:
|
||||
id: Dimension id (matches the rubric definition).
|
||||
score: Numeric score, or ``None`` when the dimension was marked
|
||||
non-applicable for this item.
|
||||
applicable: Whether the dimension applied to this item.
|
||||
weight: Dimension weight (mirrors the rubric definition).
|
||||
reason: Short rationale produced by the evaluator.
|
||||
"""
|
||||
|
||||
id: str
|
||||
score: int | None
|
||||
applicable: bool
|
||||
weight: int
|
||||
reason: str
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -14,12 +14,13 @@ import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .._agents import Agent
|
||||
from .._agents import Agent, SupportsAgentRun
|
||||
from .._clients import SupportsWebSearchTool
|
||||
from .._compaction import CompactionProvider, ContextWindowCompactionStrategy, ToolResultCompactionStrategy
|
||||
from .._feature_stage import ExperimentalFeature, experimental
|
||||
from .._sessions import ContextProvider, HistoryProvider, InMemoryHistoryProvider
|
||||
from .._skills import SkillsProvider
|
||||
from ._background_agents import BackgroundAgentsProvider
|
||||
from ._memory import MemoryContextProvider, MemoryStore
|
||||
from ._mode import AgentModeProvider
|
||||
from ._todo import TodoProvider
|
||||
@@ -103,6 +104,8 @@ def _assemble_context_providers(
|
||||
memory_store: MemoryStore | None,
|
||||
skills_provider: SkillsProvider | None,
|
||||
skills_paths: Sequence[str] | None,
|
||||
background_agents: Sequence[SupportsAgentRun] | None,
|
||||
background_agents_instructions: str | None,
|
||||
extra_context_providers: Sequence[ContextProvider] | None,
|
||||
) -> list[ContextProvider]:
|
||||
"""Assemble the ordered list of context providers."""
|
||||
@@ -130,6 +133,10 @@ def _assemble_context_providers(
|
||||
if skills_paths:
|
||||
providers.append(SkillsProvider.from_paths(*skills_paths))
|
||||
|
||||
# Background agents are opt-in: only added when agents are provided.
|
||||
if background_agents:
|
||||
providers.append(BackgroundAgentsProvider(background_agents, instructions=background_agents_instructions))
|
||||
|
||||
# Append any user-supplied additional providers.
|
||||
if extra_context_providers:
|
||||
providers.extend(extra_context_providers)
|
||||
@@ -165,6 +172,8 @@ def create_harness_agent(
|
||||
memory_store: MemoryStore | None = None,
|
||||
skills_provider: SkillsProvider | None = None,
|
||||
skills_paths: Sequence[str] | None = None,
|
||||
background_agents: Sequence[SupportsAgentRun] | None = None,
|
||||
background_agents_instructions: str | None = None,
|
||||
disable_web_search: bool = False,
|
||||
otel_provider_name: str | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
@@ -182,6 +191,7 @@ def create_harness_agent(
|
||||
- **AgentModeProvider** — plan/execute mode tracking
|
||||
- **MemoryContextProvider** — file-based durable memory (when ``memory_store`` provided)
|
||||
- **SkillsProvider** — skill discovery and progressive loading
|
||||
- **BackgroundAgentsProvider** — delegate work to background sub-agents
|
||||
- **OpenTelemetry** — observability via ``AgentTelemetryLayer``
|
||||
|
||||
Each feature can be disabled or customized via keyword arguments.
|
||||
@@ -253,6 +263,13 @@ def create_harness_agent(
|
||||
skills_paths: Paths for file-based skill discovery (looks for SKILL.md files).
|
||||
Can be combined with ``skills_provider``. When neither ``skills_provider``
|
||||
nor ``skills_paths`` is provided, no SkillsProvider is added.
|
||||
background_agents: Collection of agents available for background task delegation.
|
||||
When provided, a ``BackgroundAgentsProvider`` is automatically included,
|
||||
enabling the agent to start, monitor, and retrieve results from background tasks.
|
||||
Each agent must have a non-empty, unique name (case-insensitive).
|
||||
background_agents_instructions: Optional instruction override for the
|
||||
``BackgroundAgentsProvider``. May include ``{background_agents}`` placeholder
|
||||
which will be replaced with the agent listing.
|
||||
disable_web_search: When True, skip automatic web search tool inclusion.
|
||||
When False (default), the web search tool is automatically added if the
|
||||
client implements SupportsWebSearchTool. A warning is logged if the client
|
||||
@@ -302,6 +319,8 @@ def create_harness_agent(
|
||||
memory_store=memory_store,
|
||||
skills_provider=skills_provider,
|
||||
skills_paths=skills_paths,
|
||||
background_agents=background_agents,
|
||||
background_agents_instructions=background_agents_instructions,
|
||||
extra_context_providers=context_providers,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import logging
|
||||
import re
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Collection, Coroutine, Sequence
|
||||
from collections.abc import Callable, Collection, Coroutine, Mapping, Sequence
|
||||
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
@@ -142,6 +142,13 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str,
|
||||
return meta
|
||||
|
||||
|
||||
def _url_origin(url: Any) -> tuple[str, str, int | None]:
|
||||
port = url.port
|
||||
if port is None:
|
||||
port = 443 if url.scheme == "https" else 80 if url.scheme == "http" else None
|
||||
return (url.scheme, url.host or "", port)
|
||||
|
||||
|
||||
def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
"""Lazily import the MCP streamable HTTP transport."""
|
||||
try:
|
||||
@@ -255,6 +262,7 @@ class MCPTool:
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self._lifecycle_lock = asyncio.Lock()
|
||||
self._lifecycle_request_lock = asyncio.Lock()
|
||||
self._function_load_lock = asyncio.Lock()
|
||||
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None
|
||||
self._lifecycle_owner_task: asyncio.Task[None] | None = None
|
||||
self.session = session
|
||||
@@ -655,6 +663,11 @@ class MCPTool:
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.")
|
||||
except Exception as e:
|
||||
if type(e).__name__ == "ExceptionGroup":
|
||||
logger.warning("Could not cleanly close MCP exit stack due to cleanup error group. Error: %s", e)
|
||||
else:
|
||||
raise
|
||||
|
||||
async def _close_and_check_cancelled(self, ex: BaseException) -> bool:
|
||||
"""Close the exit stack and return True if *ex* is a genuine task cancellation.
|
||||
@@ -1018,6 +1031,10 @@ class MCPTool:
|
||||
Raises:
|
||||
ToolExecutionException: If the MCP server is not connected.
|
||||
"""
|
||||
async with self._function_load_lock:
|
||||
await self._load_prompts_locked()
|
||||
|
||||
async def _load_prompts_locked(self) -> None:
|
||||
from anyio import ClosedResourceError
|
||||
from mcp import types
|
||||
|
||||
@@ -1100,6 +1117,10 @@ class MCPTool:
|
||||
Raises:
|
||||
ToolExecutionException: If the MCP server is not connected.
|
||||
"""
|
||||
async with self._function_load_lock:
|
||||
await self._load_tools_locked()
|
||||
|
||||
async def _load_tools_locked(self) -> None:
|
||||
from anyio import ClosedResourceError
|
||||
from mcp import types
|
||||
|
||||
@@ -1109,7 +1130,7 @@ class MCPTool:
|
||||
|
||||
# Track existing function names to prevent duplicates
|
||||
existing_names = {func.name for func in self._functions}
|
||||
self._tool_call_meta_by_name.clear()
|
||||
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
|
||||
|
||||
params: types.PaginatedRequestParams | None = None
|
||||
while True:
|
||||
@@ -1145,7 +1166,7 @@ class MCPTool:
|
||||
|
||||
for tool in tool_list.tools:
|
||||
if tool.meta is not None:
|
||||
self._tool_call_meta_by_name[tool.name] = dict(tool.meta)
|
||||
tool_call_meta_by_name[tool.name] = dict(tool.meta)
|
||||
|
||||
normalized_name = _normalize_mcp_name(tool.name)
|
||||
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
@@ -1194,6 +1215,8 @@ class MCPTool:
|
||||
break
|
||||
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
|
||||
|
||||
self._tool_call_meta_by_name = tool_call_meta_by_name
|
||||
|
||||
async def _close_on_owner(self) -> None:
|
||||
# Cancel any pending reload tasks before tearing down the session.
|
||||
tasks = list(self._pending_reload_tasks)
|
||||
@@ -1276,7 +1299,11 @@ class MCPTool:
|
||||
tool_name: The name of the tool to call.
|
||||
|
||||
Keyword Args:
|
||||
kwargs: Arguments to pass to the tool.
|
||||
_meta: Optional ``dict[str, Any]`` of MCP request metadata. This reserved key is passed as the
|
||||
``meta`` parameter of the underlying ``session.call_tool`` call rather than as a tool argument.
|
||||
User-supplied keys override metadata from ``tools/list``; OpenTelemetry propagation fills in
|
||||
non-conflicting keys.
|
||||
kwargs: Remaining arguments to pass to the tool.
|
||||
|
||||
Returns:
|
||||
A list of Content items representing the tool output. The default
|
||||
@@ -1294,6 +1321,19 @@ class MCPTool:
|
||||
raise ToolExecutionException(
|
||||
"Tools are not loaded for this server, please set load_tools=True in the constructor."
|
||||
)
|
||||
|
||||
raw_user_meta: object | None = kwargs.get("_meta")
|
||||
user_meta: dict[str, Any] | None = None
|
||||
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
|
||||
if isinstance(raw_user_meta, dict):
|
||||
raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
|
||||
user_meta = {}
|
||||
for key, value in raw_user_meta_dict.items():
|
||||
if not isinstance(key, str):
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
|
||||
user_meta[key] = value
|
||||
|
||||
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
|
||||
# These are internal objects passed through the function invocation pipeline
|
||||
# that should not be forwarded to external MCP servers.
|
||||
@@ -1313,12 +1353,16 @@ class MCPTool:
|
||||
"conversation_id",
|
||||
"options",
|
||||
"response_format",
|
||||
"_meta",
|
||||
}
|
||||
}
|
||||
|
||||
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
|
||||
tool_meta = self._tool_call_meta_by_name.get(tool_name)
|
||||
meta = _inject_otel_into_mcp_meta(dict(tool_meta) if tool_meta is not None else None)
|
||||
request_meta = dict(tool_meta) if tool_meta is not None else None
|
||||
if user_meta is not None:
|
||||
request_meta = {**(request_meta or {}), **user_meta}
|
||||
meta = _inject_otel_into_mcp_meta(request_meta)
|
||||
|
||||
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
|
||||
# Try the operation, reconnecting once if the connection is closed
|
||||
@@ -1336,28 +1380,33 @@ class MCPTool:
|
||||
return parser(result)
|
||||
except ToolExecutionException:
|
||||
raise
|
||||
except ClosedResourceError as cl_ex:
|
||||
except (ClosedResourceError, McpError) as call_ex:
|
||||
is_session_terminated = (
|
||||
isinstance(call_ex, McpError) and "session terminated" in call_ex.error.message.lower()
|
||||
)
|
||||
is_connection_lost = isinstance(call_ex, ClosedResourceError) or is_session_terminated
|
||||
if not is_connection_lost:
|
||||
error_message = call_ex.error.message if isinstance(call_ex, McpError) else str(call_ex)
|
||||
raise ToolExecutionException(error_message, inner_exception=call_ex) from call_ex
|
||||
|
||||
if attempt == 0:
|
||||
# First attempt failed, try reconnecting
|
||||
logger.info("MCP connection closed unexpectedly. Reconnecting...")
|
||||
# First attempt failed, try reconnecting.
|
||||
logger.info("MCP connection closed or terminated unexpectedly. Reconnecting...")
|
||||
try:
|
||||
await self.connect(reset=True)
|
||||
continue # Retry the operation
|
||||
continue
|
||||
except Exception as reconn_ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to reconnect to MCP server.",
|
||||
inner_exception=reconn_ex,
|
||||
) from reconn_ex
|
||||
else:
|
||||
# Second attempt also failed, give up
|
||||
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
|
||||
raise ToolExecutionException(
|
||||
f"Failed to call tool '{tool_name}' - connection lost.",
|
||||
inner_exception=cl_ex,
|
||||
) from cl_ex
|
||||
except McpError as mcp_exc:
|
||||
error_message = mcp_exc.error.message
|
||||
raise ToolExecutionException(error_message, inner_exception=mcp_exc) from mcp_exc
|
||||
|
||||
# Second attempt also failed, give up.
|
||||
logger.error("MCP connection closed unexpectedly after reconnection: %s", call_ex)
|
||||
raise ToolExecutionException(
|
||||
f"Failed to call tool '{tool_name}' - connection lost.",
|
||||
inner_exception=call_ex,
|
||||
) from call_ex
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
|
||||
@@ -1718,10 +1767,11 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
Returns:
|
||||
An async context manager for the streamable HTTP client transport.
|
||||
"""
|
||||
from httpx import AsyncClient, Request, Timeout
|
||||
from httpx import URL, AsyncClient, Request, Timeout
|
||||
|
||||
http_client = self._httpx_client
|
||||
if self._header_provider is not None:
|
||||
target_origin = _url_origin(URL(self.url))
|
||||
if http_client is None:
|
||||
http_client = AsyncClient(
|
||||
follow_redirects=True,
|
||||
@@ -1732,6 +1782,8 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
if not hasattr(self, "_inject_headers_hook"):
|
||||
|
||||
async def _inject_headers(request: Request) -> None: # noqa: RUF029
|
||||
if _url_origin(request.url) != target_origin:
|
||||
return
|
||||
headers = _mcp_call_headers.get({})
|
||||
for key, value in headers.items():
|
||||
request.headers[key] = value
|
||||
|
||||
@@ -36,11 +36,10 @@ if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._agents import SupportsAgentRun
|
||||
from ._clients import SupportsChatGetResponse
|
||||
from ._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from ._sessions import AgentSession
|
||||
from ._tools import FunctionTool, ToolTypes
|
||||
from ._types import ChatOptions, ChatResponse, ChatResponseUpdate
|
||||
from ._types import ChatOptions
|
||||
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from datetime import date, datetime
|
||||
from typing import Any, ClassVar, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
logger = logging.getLogger("agent_framework")
|
||||
@@ -614,3 +616,46 @@ class SerializationMixin:
|
||||
# Fallback and default
|
||||
# Convert class name to snake_case
|
||||
return _CAMEL_TO_SNAKE_PATTERN.sub("_", cls.__name__).lower()
|
||||
|
||||
|
||||
def make_json_safe(obj: Any) -> Any:
|
||||
"""Recursively convert an object to a JSON-serializable form.
|
||||
|
||||
Handles dataclasses, Pydantic models, objects with ``to_dict``/``dict``/``__dict__``,
|
||||
datetimes, lists, dicts, and primitives. Falls back to ``str()`` for any remaining
|
||||
non-serializable value so that ``json.dumps`` never raises a ``TypeError``.
|
||||
|
||||
Args:
|
||||
obj: Object to make JSON safe.
|
||||
|
||||
Returns:
|
||||
A JSON-serializable version of the object.
|
||||
"""
|
||||
if obj is None or isinstance(obj, (str, int, float, bool)):
|
||||
return obj
|
||||
if isinstance(obj, (datetime, date)):
|
||||
return obj.isoformat()
|
||||
if is_dataclass(obj) and not isinstance(obj, type):
|
||||
return make_json_safe(asdict(obj)) # type: ignore[arg-type]
|
||||
if callable(getattr(obj, "model_dump", None)):
|
||||
try:
|
||||
return make_json_safe(obj.model_dump()) # type: ignore[no-any-return]
|
||||
except TypeError:
|
||||
pass
|
||||
if callable(getattr(obj, "to_dict", None)):
|
||||
try:
|
||||
return make_json_safe(obj.to_dict()) # type: ignore[no-any-return]
|
||||
except TypeError:
|
||||
pass
|
||||
if callable(getattr(obj, "dict", None)):
|
||||
try:
|
||||
return make_json_safe(obj.dict()) # type: ignore[no-any-return]
|
||||
except TypeError:
|
||||
pass
|
||||
if isinstance(obj, dict):
|
||||
return {str(key): make_json_safe(value) for key, value in obj.items()} # type: ignore[misc]
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return [make_json_safe(item) for item in obj] # type: ignore[misc]
|
||||
if hasattr(obj, "__dict__"):
|
||||
return {key: make_json_safe(value) for key, value in vars(obj).items()} # type: ignore[misc]
|
||||
return str(obj)
|
||||
|
||||
@@ -1973,11 +1973,97 @@ def _coalesce_text_content(contents: list[Content], type_str: Literal["text", "t
|
||||
contents.extend(coalesced_contents)
|
||||
|
||||
|
||||
def _content_items_text(items: Any) -> str | None:
|
||||
"""Return concatenated text when a content item list only contains text."""
|
||||
if not isinstance(items, list):
|
||||
return None
|
||||
text_parts: list[str] = []
|
||||
content_items = cast(list[object], items)
|
||||
for item in content_items:
|
||||
if not isinstance(item, Content) or item.type != "text":
|
||||
return None
|
||||
text_parts.append(item.text or "")
|
||||
return "".join(text_parts)
|
||||
|
||||
|
||||
def _merge_content_item_lists(existing: Any, incoming: Any) -> Any:
|
||||
"""Merge streamed nested content lists, replacing deltas with a later full value when present."""
|
||||
if incoming is None:
|
||||
return existing
|
||||
if existing is None:
|
||||
return deepcopy(incoming)
|
||||
|
||||
existing_text = _content_items_text(existing)
|
||||
incoming_text = _content_items_text(incoming)
|
||||
if existing_text is not None and incoming_text is not None:
|
||||
if incoming_text.startswith(existing_text):
|
||||
return deepcopy(incoming)
|
||||
if existing_text.startswith(incoming_text):
|
||||
return existing
|
||||
|
||||
existing_items = cast(list[Content], existing)
|
||||
merged = deepcopy(existing_items[0])
|
||||
merged.text = existing_text + incoming_text
|
||||
return [merged]
|
||||
|
||||
if isinstance(existing, list) and isinstance(incoming, list):
|
||||
existing_list = cast(list[object], existing)
|
||||
incoming_list = cast(list[object], incoming)
|
||||
return [*existing_list, *deepcopy(incoming_list)]
|
||||
return deepcopy(incoming)
|
||||
|
||||
|
||||
def _merge_code_interpreter_content(existing: Content, incoming: Content) -> None:
|
||||
"""Merge two code interpreter content items for the same logical call."""
|
||||
existing.inputs = _merge_content_item_lists(existing.inputs, incoming.inputs)
|
||||
existing.outputs = _merge_content_item_lists(existing.outputs, incoming.outputs)
|
||||
existing.annotations = _combine_annotations(existing.annotations, incoming.annotations)
|
||||
existing.additional_properties = {**existing.additional_properties, **incoming.additional_properties}
|
||||
existing.raw_representation = _combine_raw_representations(existing.raw_representation, incoming.raw_representation)
|
||||
|
||||
|
||||
def _code_interpreter_key(content: Content) -> tuple[str, str] | None:
|
||||
"""Return the aggregation key for code interpreter call/result content."""
|
||||
if content.type not in {"code_interpreter_tool_call", "code_interpreter_tool_result"}:
|
||||
return None
|
||||
call_id = content.call_id or content.additional_properties.get("item_id")
|
||||
if not isinstance(call_id, str) or not call_id:
|
||||
return None
|
||||
return content.type, call_id
|
||||
|
||||
|
||||
def _coalesce_code_interpreter_content(contents: list[Content]) -> None:
|
||||
"""Coalesce streaming code interpreter chunks by call id."""
|
||||
if not contents:
|
||||
return
|
||||
|
||||
coalesced_contents: list[Content] = []
|
||||
seen: dict[tuple[str, str], Content] = {}
|
||||
for content in contents:
|
||||
key = _code_interpreter_key(content)
|
||||
if key is None:
|
||||
coalesced_contents.append(content)
|
||||
continue
|
||||
|
||||
existing = seen.get(key)
|
||||
if existing is None:
|
||||
copied = deepcopy(content)
|
||||
seen[key] = copied
|
||||
coalesced_contents.append(copied)
|
||||
continue
|
||||
|
||||
_merge_code_interpreter_content(existing, content)
|
||||
|
||||
contents.clear()
|
||||
contents.extend(coalesced_contents)
|
||||
|
||||
|
||||
def _finalize_response(response: ChatResponse | AgentResponse) -> None:
|
||||
"""Finalizes the response by performing any necessary post-processing."""
|
||||
for msg in response.messages:
|
||||
_coalesce_text_content(msg.contents, "text")
|
||||
_coalesce_text_content(msg.contents, "text_reasoning")
|
||||
_coalesce_code_interpreter_content(msg.contents)
|
||||
|
||||
|
||||
# region ContinuationToken
|
||||
|
||||
@@ -12,6 +12,7 @@ from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
|
||||
|
||||
from .._agents import BaseAgent
|
||||
from .._serialization import make_json_safe
|
||||
from .._sessions import (
|
||||
AgentSession,
|
||||
ContextProvider,
|
||||
@@ -62,7 +63,7 @@ class WorkflowAgent(BaseAgent):
|
||||
data: Any
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {"request_id": self.request_id, "data": self.data}
|
||||
return {"request_id": self.request_id, "data": make_json_safe(self.data)}
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
@@ -47,6 +47,7 @@ from copy import deepcopy
|
||||
from typing import Any, Generic, Literal, TypeVar, overload
|
||||
|
||||
from .._feature_stage import ExperimentalFeature, experimental
|
||||
from .._serialization import make_json_safe
|
||||
from .._types import AgentResponse, AgentResponseUpdate, ResponseStream
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
|
||||
@@ -1515,7 +1516,7 @@ class FunctionalWorkflowAgent:
|
||||
function_call = Content.from_function_call(
|
||||
call_id=request_id,
|
||||
name=self.REQUEST_INFO_FUNCTION_NAME,
|
||||
arguments={"request_id": request_id, "data": event.data},
|
||||
arguments={"request_id": request_id, "data": make_json_safe(event.data)},
|
||||
)
|
||||
return Content.from_function_approval_request(
|
||||
id=request_id,
|
||||
|
||||
@@ -34,6 +34,7 @@ _IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"FoundryLocalChatOptions": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
|
||||
"FoundryLocalClient": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
|
||||
"FoundryLocalSettings": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
|
||||
"GeneratedEvaluatorRef": ("agent_framework_foundry", "agent-framework-foundry"),
|
||||
"RawAnthropicFoundryClient": ("agent_framework_anthropic", "agent-framework-anthropic"),
|
||||
"RawFoundryAgent": ("agent_framework_foundry", "agent-framework-foundry"),
|
||||
"RawFoundryAgentChatClient": ("agent_framework_foundry", "agent-framework-foundry"),
|
||||
|
||||
@@ -20,6 +20,7 @@ from agent_framework_foundry import (
|
||||
FoundryEmbeddingSettings,
|
||||
FoundryEvals,
|
||||
FoundryMemoryProvider,
|
||||
GeneratedEvaluatorRef,
|
||||
RawFoundryAgent,
|
||||
RawFoundryAgentChatClient,
|
||||
RawFoundryChatClient,
|
||||
@@ -52,6 +53,7 @@ __all__ = [
|
||||
"FoundryLocalClient",
|
||||
"FoundryLocalSettings",
|
||||
"FoundryMemoryProvider",
|
||||
"GeneratedEvaluatorRef",
|
||||
"RawAnthropicFoundryClient",
|
||||
"RawFoundryAgent",
|
||||
"RawFoundryAgentChatClient",
|
||||
|
||||
@@ -394,3 +394,94 @@ def test_create_harness_agent_logs_warning_when_no_web_search(caplog: pytest.Log
|
||||
max_output_tokens=16_384,
|
||||
)
|
||||
assert any("SupportsWebSearchTool" in msg for msg in caplog.messages)
|
||||
|
||||
|
||||
# --- Background Agents Tests ---
|
||||
|
||||
|
||||
class _FakeBackgroundAgent:
|
||||
"""Minimal agent stub satisfying SupportsAgentRun for background agents tests."""
|
||||
|
||||
def __init__(self, name: str, description: str | None = None):
|
||||
self.id = f"agent-{name}"
|
||||
self.name = name
|
||||
self.description = description
|
||||
|
||||
def create_session(self, *, session_id: str | None = None) -> AgentSession:
|
||||
return AgentSession(session_id=session_id)
|
||||
|
||||
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
|
||||
return AgentSession(service_session_id=service_session_id, session_id=session_id)
|
||||
|
||||
async def run(self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any) -> Any:
|
||||
from agent_framework import AgentResponse
|
||||
|
||||
return AgentResponse(messages=[], response_id="fake-bg-response")
|
||||
|
||||
|
||||
def test_create_harness_agent_no_background_agents_by_default() -> None:
|
||||
"""No BackgroundAgentsProvider should be included when background_agents is not provided."""
|
||||
from agent_framework._harness._background_agents import BackgroundAgentsProvider
|
||||
|
||||
agent = create_harness_agent(
|
||||
client=_FakeChatClient(), # type: ignore[arg-type]
|
||||
max_context_window_tokens=128_000,
|
||||
max_output_tokens=16_384,
|
||||
disable_web_search=True,
|
||||
)
|
||||
providers = agent.context_providers or []
|
||||
assert not any(isinstance(p, BackgroundAgentsProvider) for p in providers)
|
||||
|
||||
|
||||
def test_create_harness_agent_adds_background_agents_provider() -> None:
|
||||
"""BackgroundAgentsProvider should be included when background_agents are provided."""
|
||||
from agent_framework._harness._background_agents import BackgroundAgentsProvider
|
||||
|
||||
bg_agent = _FakeBackgroundAgent("WebSearcher", "Searches the web")
|
||||
agent = create_harness_agent(
|
||||
client=_FakeChatClient(), # type: ignore[arg-type]
|
||||
max_context_window_tokens=128_000,
|
||||
max_output_tokens=16_384,
|
||||
disable_web_search=True,
|
||||
background_agents=[bg_agent],
|
||||
)
|
||||
providers = agent.context_providers or []
|
||||
bg_providers = [p for p in providers if isinstance(p, BackgroundAgentsProvider)]
|
||||
assert len(bg_providers) == 1
|
||||
|
||||
|
||||
def test_create_harness_agent_background_agents_custom_instructions() -> None:
|
||||
"""Custom instructions should be passed to BackgroundAgentsProvider."""
|
||||
from agent_framework._harness._background_agents import BackgroundAgentsProvider
|
||||
|
||||
custom_instructions = "## Custom\n\nUse agents wisely.\n\n{background_agents}"
|
||||
bg_agent = _FakeBackgroundAgent("Helper", "A helper agent")
|
||||
agent = create_harness_agent(
|
||||
client=_FakeChatClient(), # type: ignore[arg-type]
|
||||
max_context_window_tokens=128_000,
|
||||
max_output_tokens=16_384,
|
||||
disable_web_search=True,
|
||||
background_agents=[bg_agent],
|
||||
background_agents_instructions=custom_instructions,
|
||||
)
|
||||
providers = agent.context_providers or []
|
||||
bg_providers = [p for p in providers if isinstance(p, BackgroundAgentsProvider)]
|
||||
assert len(bg_providers) == 1
|
||||
# Verify the custom instructions were used (placeholder replaced with agent list).
|
||||
assert "Custom" in bg_providers[0]._instructions
|
||||
assert "Helper" in bg_providers[0]._instructions
|
||||
|
||||
|
||||
def test_create_harness_agent_empty_background_agents_list() -> None:
|
||||
"""An empty background_agents list should NOT add a BackgroundAgentsProvider."""
|
||||
from agent_framework._harness._background_agents import BackgroundAgentsProvider
|
||||
|
||||
agent = create_harness_agent(
|
||||
client=_FakeChatClient(), # type: ignore[arg-type]
|
||||
max_context_window_tokens=128_000,
|
||||
max_output_tokens=16_384,
|
||||
disable_web_search=True,
|
||||
background_agents=[],
|
||||
)
|
||||
providers = agent.context_providers or []
|
||||
assert not any(isinstance(p, BackgroundAgentsProvider) for p in providers)
|
||||
|
||||
@@ -11,8 +11,13 @@ import pytest
|
||||
from agent_framework._evaluation import (
|
||||
CheckResult,
|
||||
EvalItem,
|
||||
EvalItemResult,
|
||||
EvalNotPassedError,
|
||||
EvalResults,
|
||||
EvalScoreResult,
|
||||
ExpectedToolCall,
|
||||
LocalEvaluator,
|
||||
RubricScore,
|
||||
_coerce_result,
|
||||
evaluator,
|
||||
keyword_check,
|
||||
@@ -1010,19 +1015,101 @@ class TestAllPassedSubResults:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# r5 review: _build_overall_item with empty outputs
|
||||
# Rubric assertions (EvalResults.assert_*)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildOverallItemEmpty:
|
||||
"""Test _build_overall_item returns None for empty workflow outputs."""
|
||||
def _rubric_results(*scores_per_item: list[EvalScoreResult]) -> EvalResults:
|
||||
items = [
|
||||
EvalItemResult(item_id=f"item-{i}", status="pass", scores=scores) for i, scores in enumerate(scores_per_item)
|
||||
]
|
||||
return EvalResults(
|
||||
provider="test",
|
||||
eval_id="ev1",
|
||||
run_id="run1",
|
||||
result_counts={"passed": len(items), "failed": 0, "errored": 0, "total": len(items)},
|
||||
items=items,
|
||||
)
|
||||
|
||||
def test_returns_none_for_empty_outputs(self):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent_framework._evaluation import _build_overall_item
|
||||
class TestRubricAssertions:
|
||||
"""Tests for EvalResults.assert_dimension_score_at_least."""
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.get_outputs.return_value = []
|
||||
item = _build_overall_item("Hello", mock_result)
|
||||
assert item is None
|
||||
def test_dimension_at_or_above_threshold_passes(self) -> None:
|
||||
results = _rubric_results(
|
||||
[
|
||||
EvalScoreResult(
|
||||
name="policy",
|
||||
score=0.9,
|
||||
dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
|
||||
)
|
||||
],
|
||||
)
|
||||
# Should not raise.
|
||||
results.assert_dimension_score_at_least("clarity", 3)
|
||||
|
||||
def test_dimension_below_threshold_raises(self) -> None:
|
||||
results = _rubric_results(
|
||||
[
|
||||
EvalScoreResult(
|
||||
name="policy",
|
||||
score=0.5,
|
||||
dimensions=[RubricScore(id="clarity", score=2, applicable=True, weight=1, reason="")],
|
||||
)
|
||||
],
|
||||
)
|
||||
with pytest.raises(EvalNotPassedError):
|
||||
results.assert_dimension_score_at_least("clarity", 3)
|
||||
|
||||
def test_non_applicable_skipped_by_default(self) -> None:
|
||||
results = _rubric_results(
|
||||
[
|
||||
EvalScoreResult(
|
||||
name="policy",
|
||||
score=1.0,
|
||||
dimensions=[RubricScore(id="clarity", score=None, applicable=False, weight=1, reason="n/a")],
|
||||
)
|
||||
],
|
||||
)
|
||||
# No applicable scores; default behaviour is to skip silently.
|
||||
results.assert_dimension_score_at_least("clarity", 3)
|
||||
|
||||
def test_require_applicable_raises_when_dimension_absent(self) -> None:
|
||||
results = _rubric_results(
|
||||
[EvalScoreResult(name="policy", score=1.0, dimensions=[])],
|
||||
)
|
||||
with pytest.raises(EvalNotPassedError, match="not applicable"):
|
||||
results.assert_dimension_score_at_least("clarity", 3, require_applicable=True)
|
||||
|
||||
def test_require_applicable_raises_when_filtered_evaluator_missing(self) -> None:
|
||||
# Regression: previously the (not evaluator or found_any) guard caused
|
||||
# this case to silently pass even with require_applicable=True.
|
||||
results = _rubric_results(
|
||||
[
|
||||
EvalScoreResult(
|
||||
name="other",
|
||||
score=0.9,
|
||||
dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
|
||||
)
|
||||
],
|
||||
)
|
||||
with pytest.raises(EvalNotPassedError, match="not applicable"):
|
||||
results.assert_dimension_score_at_least("clarity", 3, evaluator="policy", require_applicable=True)
|
||||
|
||||
def test_evaluator_filter_isolates_offenders(self) -> None:
|
||||
results = _rubric_results(
|
||||
[
|
||||
EvalScoreResult(
|
||||
name="other",
|
||||
score=0.1,
|
||||
dimensions=[RubricScore(id="clarity", score=1, applicable=True, weight=1, reason="")],
|
||||
),
|
||||
EvalScoreResult(
|
||||
name="policy",
|
||||
score=0.9,
|
||||
dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
|
||||
),
|
||||
],
|
||||
)
|
||||
# The low-scoring "other" evaluator is filtered out; "policy" passes.
|
||||
results.assert_dimension_score_at_least("clarity", 3, evaluator="policy")
|
||||
|
||||
@@ -1161,6 +1161,43 @@ async def test_local_mcp_server_function_execution_error():
|
||||
await func.invoke(param="test_value")
|
||||
|
||||
|
||||
async def test_mcp_tool_reconnects_after_session_terminated_error():
|
||||
"""Session termination errors should reconnect once and retry the tool call."""
|
||||
|
||||
class TestServer(MCPTool):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.connect_count = 0
|
||||
self.sessions: list[Any] = []
|
||||
|
||||
async def connect(self, *, reset: bool = False) -> None:
|
||||
self.connect_count += 1
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.sessions.append(self.session)
|
||||
if self.connect_count == 1:
|
||||
self.session.call_tool = AsyncMock(
|
||||
side_effect=McpError(types.ErrorData(code=-32000, message="Session terminated"))
|
||||
)
|
||||
else:
|
||||
self.session.call_tool = AsyncMock(
|
||||
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="recovered")])
|
||||
)
|
||||
self.is_connected = True
|
||||
|
||||
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
return None
|
||||
|
||||
server = TestServer(name="test_server")
|
||||
await server.connect()
|
||||
|
||||
result = await server.call_tool("test_tool", param="test_value")
|
||||
|
||||
assert _mcp_result_to_text(result) == "recovered"
|
||||
assert server.connect_count == 2
|
||||
assert server.sessions[0].call_tool.await_count == 1
|
||||
assert server.sessions[1].call_tool.await_count == 1
|
||||
|
||||
|
||||
async def test_mcp_tool_call_tool_raises_on_is_error():
|
||||
"""Test that call_tool raises ToolExecutionException when MCP returns isError=True."""
|
||||
|
||||
@@ -3260,6 +3297,68 @@ async def test_load_prompts_pagination_with_duplicates():
|
||||
assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2"]
|
||||
|
||||
|
||||
async def test_load_tools_concurrent_reload_does_not_duplicate_tools_and_preserves_meta():
|
||||
"""Concurrent tool reloads should not duplicate functions or lose tools/list metadata."""
|
||||
tool = MCPTool(name="test_tool")
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_tools_flag = True
|
||||
|
||||
page = Mock()
|
||||
page.tools = [
|
||||
types.Tool(
|
||||
name="tool_1",
|
||||
description="First tool",
|
||||
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
|
||||
_meta={"echo": "tool_1"},
|
||||
),
|
||||
]
|
||||
page.nextCursor = None
|
||||
|
||||
async def mock_list_tools(params: Any = None) -> Any:
|
||||
assert params is None
|
||||
await asyncio.sleep(0)
|
||||
return page
|
||||
|
||||
mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)
|
||||
|
||||
await asyncio.wait_for(asyncio.gather(tool.load_tools(), tool.load_tools()), timeout=1)
|
||||
|
||||
assert mock_session.list_tools.call_count == 2
|
||||
assert [f.name for f in tool._functions] == ["tool_1"]
|
||||
assert tool._tool_call_meta_by_name == {"tool_1": {"echo": "tool_1"}}
|
||||
|
||||
|
||||
async def test_load_prompts_concurrent_reload_does_not_duplicate_prompts():
|
||||
"""Concurrent prompt reloads should not duplicate functions."""
|
||||
tool = MCPTool(name="test_tool")
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_prompts_flag = True
|
||||
|
||||
page = Mock()
|
||||
page.prompts = [
|
||||
types.Prompt(
|
||||
name="prompt_1",
|
||||
description="First prompt",
|
||||
arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)],
|
||||
),
|
||||
]
|
||||
page.nextCursor = None
|
||||
|
||||
async def mock_list_prompts(params: Any = None) -> Any:
|
||||
assert params is None
|
||||
await asyncio.sleep(0)
|
||||
return page
|
||||
|
||||
mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts)
|
||||
|
||||
await asyncio.wait_for(asyncio.gather(tool.load_prompts(), tool.load_prompts()), timeout=1)
|
||||
|
||||
assert mock_session.list_prompts.call_count == 2
|
||||
assert [f.name for f in tool._functions] == ["prompt_1"]
|
||||
|
||||
|
||||
async def test_load_tools_pagination_exception_handling():
|
||||
"""Test that load_tools handles exceptions during pagination gracefully."""
|
||||
from unittest.mock import AsyncMock
|
||||
@@ -3891,6 +3990,31 @@ async def test_mcp_tool_safe_close_handles_cancelled_error():
|
||||
mock_exit_stack.aclose.assert_called_once()
|
||||
|
||||
|
||||
async def test_mcp_tool_safe_close_handles_cleanup_exception_group():
|
||||
"""Cleanup task groups should not hide the original connect failure."""
|
||||
import builtins
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
exception_group_type = getattr(builtins, "ExceptionGroup", None)
|
||||
if exception_group_type is None:
|
||||
pytest.skip("ExceptionGroup is not available on this Python version")
|
||||
|
||||
tool = MCPStreamableHTTPTool(
|
||||
name="test",
|
||||
url="http://example.com/mcp",
|
||||
load_tools=False,
|
||||
load_prompts=False,
|
||||
)
|
||||
|
||||
mock_exit_stack = AsyncMock(spec=AsyncExitStack)
|
||||
mock_exit_stack.aclose = AsyncMock(side_effect=exception_group_type("cleanup failed", [RuntimeError("reader")]))
|
||||
tool._exit_stack = mock_exit_stack
|
||||
|
||||
await tool._safe_close_exit_stack()
|
||||
|
||||
mock_exit_stack.aclose.assert_called_once()
|
||||
|
||||
|
||||
async def test_connect_sets_logging_level_when_logger_level_is_set():
|
||||
"""Test that connect() sets the MCP server logging level when the logger level is not NOTSET."""
|
||||
|
||||
@@ -4389,6 +4513,52 @@ async def test_mcp_tool_call_tool_forwards_tool_list_meta():
|
||||
assert server.session.call_tool.call_args.kwargs["meta"] == tool_meta
|
||||
|
||||
|
||||
async def test_mcp_tool_call_tool_user_meta_merges_with_tool_list_meta():
|
||||
"""User-provided _meta should be sent as MCP request metadata, not tool arguments."""
|
||||
from opentelemetry import trace
|
||||
|
||||
tool_meta = {"from_tool": "tool-value", "shared": "tool-value"}
|
||||
user_meta = {"from_user": "user-value", "shared": "user-value"}
|
||||
|
||||
class TestServer(MCPTool):
|
||||
async def connect(self) -> None:
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
|
||||
_meta=tool_meta,
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.session.call_tool = AsyncMock(
|
||||
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="result")])
|
||||
)
|
||||
|
||||
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
return None
|
||||
|
||||
server = TestServer(name="test_server")
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
|
||||
with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)):
|
||||
await server.call_tool("test_tool", param="test_value", _meta=user_meta)
|
||||
|
||||
call_kwargs = server.session.call_tool.call_args.kwargs
|
||||
assert call_kwargs["arguments"] == {"param": "test_value"}
|
||||
assert call_kwargs["meta"] == {
|
||||
"from_tool": "tool-value",
|
||||
"from_user": "user-value",
|
||||
"shared": "user-value",
|
||||
}
|
||||
assert user_meta == {"from_user": "user-value", "shared": "user-value"}
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client():
|
||||
"""Test that calling get_mcp_client multiple times does not accumulate duplicate hooks."""
|
||||
tool = MCPStreamableHTTPTool(
|
||||
@@ -4641,6 +4811,42 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook():
|
||||
await tool._httpx_client.aclose()
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redirect():
|
||||
"""The request hook must not re-add caller headers after a cross-origin redirect."""
|
||||
import httpx
|
||||
|
||||
from agent_framework._mcp import _mcp_call_headers
|
||||
|
||||
tool = MCPStreamableHTTPTool(
|
||||
name="test",
|
||||
url="http://example.com/mcp",
|
||||
header_provider=lambda kw: {"Authorization": f"Bearer {kw.get('token', '')}"},
|
||||
)
|
||||
|
||||
try:
|
||||
with patch("agent_framework._mcp.streamable_http_client"):
|
||||
tool.get_mcp_client()
|
||||
|
||||
assert tool._httpx_client is not None
|
||||
hooks = tool._httpx_client.event_hooks.get("request", [])
|
||||
assert len(hooks) == 1
|
||||
|
||||
token = _mcp_call_headers.set({"Authorization": "Bearer secret"})
|
||||
try:
|
||||
same_origin = httpx.Request("POST", "http://example.com/redirected")
|
||||
await hooks[0](same_origin)
|
||||
assert same_origin.headers.get("Authorization") == "Bearer secret"
|
||||
|
||||
cross_origin = httpx.Request("POST", "http://attacker.example/capture")
|
||||
await hooks[0](cross_origin)
|
||||
assert "Authorization" not in cross_origin.headers
|
||||
finally:
|
||||
_mcp_call_headers.reset(token)
|
||||
finally:
|
||||
if getattr(tool, "_httpx_client", None) is not None:
|
||||
await tool._httpx_client.aclose()
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client():
|
||||
"""Test that header_provider works when the user provides their own httpx client."""
|
||||
import httpx
|
||||
|
||||
@@ -1691,6 +1691,65 @@ def test_to_otel_part_function_call():
|
||||
}
|
||||
|
||||
|
||||
def test_to_otel_part_function_call_reuses_prepared_arguments():
|
||||
"""Test _to_otel_part does not re-serialize function-call arguments in the observability hot path."""
|
||||
from agent_framework import Content
|
||||
from agent_framework.observability import _to_otel_part
|
||||
|
||||
arguments = {"payload": object()}
|
||||
content = Content(type="function_call", call_id="call_789", name="handoff", arguments=arguments)
|
||||
result = _to_otel_part(content)
|
||||
|
||||
assert result is not None
|
||||
assert result["arguments"] is arguments
|
||||
|
||||
|
||||
def test_make_json_safe_non_callable_method_attribute():
|
||||
"""Test make_json_safe handles objects where model_dump/to_dict/dict are non-callable attributes."""
|
||||
from agent_framework._serialization import make_json_safe
|
||||
|
||||
class ObjWithNonCallableModelDump:
|
||||
model_dump = 42 # not callable
|
||||
|
||||
obj = ObjWithNonCallableModelDump()
|
||||
result = make_json_safe(obj)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_make_json_safe_callable_method_type_error_falls_through():
|
||||
"""Test make_json_safe falls through when serializer-like methods require arguments."""
|
||||
from agent_framework._serialization import make_json_safe
|
||||
|
||||
class ObjWithRequiredArgModelDump:
|
||||
def __init__(self) -> None:
|
||||
self.value = "fallback"
|
||||
|
||||
def model_dump(self, required: str) -> dict[str, str]:
|
||||
return {"required": required}
|
||||
|
||||
obj = ObjWithRequiredArgModelDump()
|
||||
result = make_json_safe(obj)
|
||||
assert result == {"value": "fallback"}
|
||||
|
||||
|
||||
def test_make_json_safe_dict_with_non_string_keys():
|
||||
"""Test make_json_safe converts non-primitive dict keys to strings."""
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from agent_framework._serialization import make_json_safe
|
||||
|
||||
dt_key = datetime(2024, 1, 1)
|
||||
obj = {dt_key: "value", 42: "num_value", "str_key": "normal"}
|
||||
result = make_json_safe(obj)
|
||||
# json.dumps must not raise TypeError
|
||||
serialized = json.dumps(result)
|
||||
parsed = json.loads(serialized)
|
||||
assert parsed[str(dt_key)] == "value"
|
||||
assert parsed["42"] == "num_value"
|
||||
assert parsed["str_key"] == "normal"
|
||||
|
||||
|
||||
def test_to_otel_part_function_result():
|
||||
"""Test _to_otel_part with function_result content."""
|
||||
from agent_framework import Content
|
||||
@@ -3019,6 +3078,49 @@ async def test_system_instructions_preserves_non_ascii_characters(span_exporter:
|
||||
assert [msg.get("role") for msg in input_messages] == ["user"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True)
|
||||
def test_capture_messages_with_prepared_request_info_function_call_arguments(span_exporter: InMemorySpanExporter):
|
||||
"""Test _capture_messages handles request-info function-call arguments prepared at Content creation."""
|
||||
import dataclasses
|
||||
import json
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
from agent_framework import WorkflowAgent
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HandoffRequest:
|
||||
target_agent: str
|
||||
reason: str
|
||||
|
||||
arguments = WorkflowAgent.RequestInfoFunctionArgs(
|
||||
request_id="call_dc",
|
||||
data=HandoffRequest(target_agent="helper", reason="overflow"),
|
||||
).to_dict()
|
||||
msg = Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content(
|
||||
type="function_call",
|
||||
call_id="call_dc",
|
||||
name="request_info",
|
||||
arguments=arguments,
|
||||
)
|
||||
],
|
||||
)
|
||||
span_exporter.clear()
|
||||
tracer = trace.get_tracer("test")
|
||||
with tracer.start_as_current_span("test_span") as span:
|
||||
_capture_messages(span=span, provider_name="test_provider", messages=[msg])
|
||||
|
||||
spans = span_exporter.get_finished_spans()
|
||||
span = spans[0]
|
||||
input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES])
|
||||
tool_part = input_messages[0]["parts"][0]
|
||||
assert tool_part["type"] == "tool_call"
|
||||
assert tool_part["arguments"]["data"] == {"target_agent": "helper", "reason": "overflow"}
|
||||
|
||||
|
||||
def test_capture_messages_keeps_framework_instructions_out_of_logs_and_span_messages(
|
||||
span_exporter: InMemorySpanExporter,
|
||||
):
|
||||
|
||||
@@ -307,6 +307,63 @@ class TestHistoryProviderBase:
|
||||
assert provider.stored[0].text == "hello"
|
||||
assert provider.stored[1].text == "hi"
|
||||
|
||||
async def test_after_run_stores_coalesced_code_interpreter_chunks(self) -> None:
|
||||
from agent_framework import AgentResponse, AgentResponseUpdate, Content
|
||||
|
||||
provider = ConcreteHistoryProvider("mem", store_inputs=False)
|
||||
updates = [
|
||||
AgentResponseUpdate(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_code_interpreter_tool_result(
|
||||
call_id="ci_123",
|
||||
outputs=[],
|
||||
)
|
||||
],
|
||||
),
|
||||
AgentResponseUpdate(
|
||||
contents=[
|
||||
Content.from_code_interpreter_tool_call(
|
||||
call_id="ci_123",
|
||||
inputs=[Content.from_text(text="import")],
|
||||
additional_properties={"sequence_number": 1},
|
||||
)
|
||||
],
|
||||
),
|
||||
AgentResponseUpdate(
|
||||
contents=[
|
||||
Content.from_code_interpreter_tool_call(
|
||||
call_id="ci_123",
|
||||
inputs=[Content.from_text(text=" pandas")],
|
||||
additional_properties={"sequence_number": 2},
|
||||
)
|
||||
],
|
||||
),
|
||||
AgentResponseUpdate(
|
||||
contents=[
|
||||
Content.from_code_interpreter_tool_call(
|
||||
call_id="ci_123",
|
||||
inputs=[Content.from_text(text="import pandas as pd")],
|
||||
additional_properties={"sequence_number": 3},
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
ctx = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["make a sheet"])])
|
||||
ctx._response = AgentResponse.from_updates(updates)
|
||||
|
||||
await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type]
|
||||
|
||||
assert len(provider.stored) == 1
|
||||
stored_contents = provider.stored[0].contents
|
||||
calls = [content for content in stored_contents if content.type == "code_interpreter_tool_call"]
|
||||
results = [content for content in stored_contents if content.type == "code_interpreter_tool_result"]
|
||||
assert len(calls) == 1
|
||||
assert len(results) == 1
|
||||
assert calls[0].inputs is not None
|
||||
assert len(calls[0].inputs) == 1
|
||||
assert calls[0].inputs[0].text == "import pandas as pd"
|
||||
|
||||
async def test_after_run_skips_inputs_when_disabled(self) -> None:
|
||||
from agent_framework import AgentResponse
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
@@ -1642,6 +1643,37 @@ class TestFunctionalWorkflowAgentHITL:
|
||||
break
|
||||
assert approval_found, "expected FunctionApprovalRequestContent in agent response"
|
||||
|
||||
async def test_request_info_dataclass_arguments_are_serialized_for_agent(self):
|
||||
@dataclass
|
||||
class HandoffRequest:
|
||||
target_agent: str
|
||||
reason: str
|
||||
|
||||
@workflow
|
||||
async def wf(x: str, ctx: RunContext) -> str:
|
||||
answer = await ctx.request_info(
|
||||
HandoffRequest(target_agent=x, reason="overflow"),
|
||||
response_type=str,
|
||||
request_id="rid-1",
|
||||
)
|
||||
return f"got:{answer}"
|
||||
|
||||
agent = wf.as_agent()
|
||||
response = await agent.run("helper")
|
||||
|
||||
function_call_arguments = None
|
||||
for message in response.messages:
|
||||
for content in message.contents:
|
||||
if getattr(content, "type", None) == "function_approval_request" and content.function_call is not None:
|
||||
function_call_arguments = content.function_call.arguments
|
||||
break
|
||||
|
||||
assert function_call_arguments == {
|
||||
"request_id": "rid-1",
|
||||
"data": {"target_agent": "helper", "reason": "overflow"},
|
||||
}
|
||||
assert json.loads(json.dumps(function_call_arguments)) == function_call_arguments
|
||||
|
||||
async def test_resume_via_agent_responses_kwarg(self):
|
||||
@workflow
|
||||
async def wf(x: str, ctx: RunContext) -> str:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
import pytest
|
||||
@@ -23,6 +25,7 @@ from agent_framework import (
|
||||
WorkflowAgent,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
WorkflowEvent,
|
||||
executor,
|
||||
handler,
|
||||
response_handler,
|
||||
@@ -292,6 +295,33 @@ class TestWorkflowAgent:
|
||||
pending_requests = await workflow._runner_context.get_pending_request_info_events()
|
||||
assert len(pending_requests) == 0
|
||||
|
||||
def test_request_info_dataclass_arguments_are_serialized_when_content_is_created(self) -> None:
|
||||
"""Test WorkflowAgent prepares request_info arguments before observability captures messages."""
|
||||
|
||||
@dataclass
|
||||
class HandoffRequest:
|
||||
target_agent: str
|
||||
reason: str
|
||||
|
||||
executor = SimpleExecutor(id="executor1", response_text="Response")
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
|
||||
event = WorkflowEvent.request_info(
|
||||
request_id="request_123",
|
||||
source_executor_id="executor1",
|
||||
request_data=HandoffRequest(target_agent="helper", reason="overflow"),
|
||||
response_type=str,
|
||||
)
|
||||
|
||||
function_call, approval_request = agent._process_request_info_event(event) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert function_call.arguments == {
|
||||
"request_id": "request_123",
|
||||
"data": {"target_agent": "helper", "reason": "overflow"},
|
||||
}
|
||||
assert approval_request.function_call is function_call
|
||||
assert json.loads(json.dumps(function_call.arguments)) == function_call.arguments
|
||||
|
||||
def test_workflow_as_agent_method(self) -> None:
|
||||
"""Test that Workflow.as_agent() creates a properly configured WorkflowAgent."""
|
||||
# Create a simple workflow
|
||||
|
||||
Reference in New Issue
Block a user