mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Fix cache and PR comments
This commit is contained in:
+6
-3
@@ -163,15 +163,18 @@ def _get_output_path(action_def: Mapping[str, Any], key: str) -> str | None:
|
||||
def _format_outputs_for_send(parsed_results: list[Any]) -> str:
    """Render parsed MCP outputs to a string for ``ctx.yield_output(...)``.

    The rules are checked in this order:

    - Empty list → ``""``.
    - All-string list → newline-joined.
    - Single element (any type — scalar, dict, list) → JSON-dumped element.
      This avoids surprising ``"[42]"`` / ``"[true]"`` / ``"[null]"`` when
      an MCP tool returns a single scalar JSON value.
    - Multi-element non-string list → JSON-dump the whole list.

    Args:
        parsed_results: The already-parsed tool outputs, in order.

    Returns:
        The rendered text; non-ASCII characters are emitted unescaped.
    """
    if not parsed_results:
        return ""
    if all(isinstance(item, str) for item in parsed_results):
        return "\n".join(parsed_results)  # type: ignore[arg-type]
    # A lone element is unwrapped regardless of its type; only genuinely
    # multi-valued (non-string) results keep the surrounding JSON array.
    payload = parsed_results[0] if len(parsed_results) == 1 else parsed_results
    return json.dumps(payload, ensure_ascii=False)
|
||||
|
||||
|
||||
+98
-24
@@ -158,10 +158,17 @@ class DefaultMCPToolHandler:
|
||||
"""Default :class:`MCPToolHandler` backed by :class:`agent_framework.MCPStreamableHTTPTool`.
|
||||
|
||||
Caches one :class:`agent_framework.MCPStreamableHTTPTool` instance per
|
||||
``(server_url, headers_hash)`` in a bounded LRU. The cache prevents
|
||||
re-establishing an MCP session for every invocation while ensuring
|
||||
different header sets (auth tokens) cannot share a session — matches the
|
||||
.NET design intent while bounding cardinality.
|
||||
``(server_url, server_label, connection_name, headers_hash)`` in a
|
||||
bounded LRU. The cache prevents re-establishing an MCP session for every
|
||||
invocation while ensuring different header sets (auth tokens) cannot
|
||||
share a session — matches the .NET design intent while bounding
|
||||
cardinality. ``server_label`` and ``connection_name`` participate in
|
||||
the key so that callers using ``client_provider`` to dispatch on those
|
||||
fields receive a fresh client per logical connection (see below).
|
||||
Header *names* are lower-cased inside the hash payload only — the
|
||||
headers passed on the wire keep the caller's original casing — so two
|
||||
YAML actions that spell ``Authorization`` differently still share a
|
||||
cache entry.
|
||||
|
||||
Construction modes:
|
||||
|
||||
@@ -197,14 +204,17 @@ class DefaultMCPToolHandler:
|
||||
raise ValueError(f"cache_max_size must be positive, got {cache_max_size}")
|
||||
self._client_provider = client_provider
|
||||
self._cache_max_size = cache_max_size
|
||||
self._cache: OrderedDict[tuple[str, str], _CacheEntry] = OrderedDict()
|
||||
self._cache: OrderedDict[tuple[str, str, str, str], _CacheEntry] = OrderedDict()
|
||||
# Outer lock guards the cache + in-flight-future map only — never
|
||||
# held across network I/O.
|
||||
self._cache_lock = asyncio.Lock()
|
||||
# Per-key in-flight futures: while one task is connecting, other
|
||||
# tasks awaiting the same key will await the same future and share
|
||||
# the resulting cache entry.
|
||||
self._inflight: dict[tuple[str, str], asyncio.Future[_CacheEntry]] = {}
|
||||
self._inflight: dict[tuple[str, str, str, str], asyncio.Future[_CacheEntry]] = {}
|
||||
# Set by ``aclose`` to prevent post-close cache insertions and to
|
||||
# reject new ``invoke_tool`` calls. Once set, never cleared.
|
||||
self._closed = False
|
||||
|
||||
async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
|
||||
"""Invoke ``invocation.tool_name`` on the cached MCP client for the server."""
|
||||
@@ -279,10 +289,32 @@ class DefaultMCPToolHandler:
|
||||
|
||||
Caller-supplied :class:`httpx.AsyncClient` instances (returned by the
|
||||
``client_provider`` callback) are NOT closed.
|
||||
|
||||
Idempotent — a second call returns immediately. Drains any in-flight
|
||||
``_create_entry`` tasks before returning so their resources are
|
||||
cleaned up; the in-flight tasks see ``self._closed`` in phase 3 of
|
||||
:meth:`_get_or_create_entry`, close their own entry, and resolve
|
||||
their future with ``RuntimeError("DefaultMCPToolHandler is closed")``.
|
||||
"""
|
||||
async with self._cache_lock:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
entries = list(self._cache.values())
|
||||
self._cache.clear()
|
||||
inflight_futures = list(self._inflight.values())
|
||||
|
||||
# Wait for in-flight creations to finish their self-cleanup. Each
|
||||
# in-flight task self-closes its entry under the closed-flag branch
|
||||
# in phase 3 and resolves its future with ``RuntimeError``; we
|
||||
# swallow it here because the failure is expected at shutdown.
|
||||
for fut in inflight_futures:
|
||||
try:
|
||||
await fut
|
||||
except BaseException:
|
||||
logger.debug("DefaultMCPToolHandler: in-flight future raised during aclose", exc_info=True)
|
||||
continue
|
||||
|
||||
for entry in entries:
|
||||
await self._close_entry(entry)
|
||||
|
||||
@@ -298,12 +330,19 @@ class DefaultMCPToolHandler:
|
||||
|
||||
async def _get_or_create_entry(self, invocation: MCPToolInvocation) -> _CacheEntry:
|
||||
"""Look up (or create) the cached MCP client for this invocation."""
|
||||
key = self._cache_key(invocation.server_url, invocation.headers)
|
||||
key = self._cache_key(
|
||||
invocation.server_url,
|
||||
invocation.server_label,
|
||||
invocation.connection_name,
|
||||
invocation.headers,
|
||||
)
|
||||
|
||||
# Phase 1: check the cache and either claim creation or wait for an
|
||||
# already in-flight creation.
|
||||
creating = False
|
||||
async with self._cache_lock:
|
||||
if self._closed:
|
||||
raise RuntimeError("DefaultMCPToolHandler is closed")
|
||||
existing = self._cache.get(key)
|
||||
if existing is not None:
|
||||
self._cache.move_to_end(key)
|
||||
@@ -332,25 +371,44 @@ class DefaultMCPToolHandler:
|
||||
raise
|
||||
|
||||
# Phase 3: insert with LRU eviction; resolve the in-flight future.
|
||||
# If ``aclose`` ran while we were connecting, ``_closed`` is now
|
||||
# True; don't insert into the cache (it has been drained), close
|
||||
# the just-built entry, and surface the closed-handler error to
|
||||
# all awaiters of the future.
|
||||
evicted: _CacheEntry | None = None
|
||||
duplicate: _CacheEntry | None = None
|
||||
handler_closed = False
|
||||
async with self._cache_lock:
|
||||
self._inflight.pop(key, None)
|
||||
existing = self._cache.get(key)
|
||||
if existing is not None:
|
||||
# Another writer beat us; prefer the existing entry and
|
||||
# discard ours after the lock is released.
|
||||
self._cache.move_to_end(key)
|
||||
duplicate = entry
|
||||
entry = existing
|
||||
if self._closed:
|
||||
handler_closed = True
|
||||
else:
|
||||
self._cache[key] = entry
|
||||
self._cache.move_to_end(key)
|
||||
if len(self._cache) > self._cache_max_size:
|
||||
_evicted_key, evicted = self._cache.popitem(last=False)
|
||||
if not inflight.done():
|
||||
inflight.set_result(entry)
|
||||
existing = self._cache.get(key)
|
||||
if existing is not None:
|
||||
# Another writer beat us; prefer the existing entry and
|
||||
# discard ours after the lock is released.
|
||||
self._cache.move_to_end(key)
|
||||
duplicate = entry
|
||||
entry = existing
|
||||
else:
|
||||
self._cache[key] = entry
|
||||
self._cache.move_to_end(key)
|
||||
if len(self._cache) > self._cache_max_size:
|
||||
_evicted_key, evicted = self._cache.popitem(last=False)
|
||||
if not inflight.done():
|
||||
inflight.set_result(entry)
|
||||
|
||||
if handler_closed:
|
||||
# Close our orphaned entry; resolve the future with a clear
|
||||
# error so the caller (and any other awaiters) surface a
|
||||
# consistent "handler is closed" failure rather than receiving
|
||||
# an entry we are about to close behind their back.
|
||||
await self._close_entry(entry)
|
||||
err = RuntimeError("DefaultMCPToolHandler is closed")
|
||||
if not inflight.done():
|
||||
inflight.set_exception(err)
|
||||
inflight.exception()
|
||||
raise err
|
||||
if duplicate is not None:
|
||||
await self._close_entry(duplicate)
|
||||
if evicted is not None:
|
||||
@@ -410,11 +468,27 @@ class DefaultMCPToolHandler:
|
||||
logger.debug("DefaultMCPToolHandler: error closing owned httpx client", exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _cache_key(server_url: str, headers: dict[str, str] | None) -> tuple[str, str]:
|
||||
"""Build an order-independent cache key for ``(server_url, headers)``."""
|
||||
def _cache_key(
|
||||
server_url: str,
|
||||
server_label: str | None,
|
||||
connection_name: str | None,
|
||||
headers: dict[str, str] | None,
|
||||
) -> tuple[str, str, str, str]:
|
||||
"""Build an order-independent cache key for the invocation identity.
|
||||
|
||||
The key includes ``server_label`` and ``connection_name`` so that
|
||||
callers using ``client_provider`` to dispatch on those fields
|
||||
receive a fresh client per logical connection (matches the
|
||||
documented dispatch contract).
|
||||
|
||||
Header *names* are lower-cased inside the hash payload only so
|
||||
that ``Authorization`` and ``authorization`` map to the same
|
||||
cache entry. Header values remain case-sensitive (per RFC 7235).
|
||||
"""
|
||||
if not headers:
|
||||
headers_hash = "0"
|
||||
else:
|
||||
payload = json.dumps(sorted(headers.items()), ensure_ascii=False)
|
||||
normalized = sorted((k.lower(), v) for k, v in headers.items())
|
||||
payload = json.dumps(normalized, ensure_ascii=False)
|
||||
headers_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
||||
return (server_url, headers_hash)
|
||||
return (server_url, server_label or "", connection_name or "", headers_hash)
|
||||
|
||||
@@ -237,6 +237,54 @@ class TestCache:
|
||||
assert len(FakeTool.instances) == 1
|
||||
assert FakeTool.instances[0].connect_count == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_different_connection_names_create_separate_entries(self) -> None:
    """Invocations that differ only in ``connection_name`` build two tools."""
    mcp_handler = DefaultMCPToolHandler()
    with _patch_tool():
        for name in ("conn-A", "conn-B"):
            await mcp_handler.invoke_tool(_invocation(connection_name=name))
    created = FakeTool.instances
    assert len(created) == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_different_server_labels_create_separate_entries(self) -> None:
    """Invocations that differ only in ``server_label`` build two tools."""
    mcp_handler = DefaultMCPToolHandler()
    with _patch_tool():
        for label in ("LabelA", "LabelB"):
            await mcp_handler.invoke_tool(_invocation(server_label=label))
    created = FakeTool.instances
    assert len(created) == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_full_identity_match_hits_cache(self) -> None:
    """Two invocations with the same label, connection and headers share one tool."""
    mcp_handler = DefaultMCPToolHandler()
    identity = {"server_label": "Lbl", "connection_name": "C"}
    with _patch_tool():
        for _ in range(2):
            # A fresh headers dict per call, as two separate actions would pass.
            await mcp_handler.invoke_tool(_invocation(headers={"X": "1"}, **identity))
    created = FakeTool.instances
    assert len(created) == 1
    assert created[0].connect_count == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_header_name_case_collapses_to_one_cache_entry(self) -> None:
    """Header names that differ only by letter case reuse a single tool."""
    mcp_handler = DefaultMCPToolHandler()
    with _patch_tool():
        for header_name in ("Authorization", "authorization", "AUTHORIZATION"):
            await mcp_handler.invoke_tool(_invocation(headers={header_name: "tk"}))
    created = FakeTool.instances
    assert len(created) == 1
    assert created[0].connect_count == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_header_value_case_does_not_collapse(self) -> None:
    """Header *values* that differ only by letter case build separate tools."""
    mcp_handler = DefaultMCPToolHandler()
    with _patch_tool():
        for token in ("Bearer-A", "bearer-a"):
            await mcp_handler.invoke_tool(_invocation(headers={"Authorization": token}))
    created = FakeTool.instances
    assert len(created) == 2
|
||||
|
||||
|
||||
# ---------- Aclose semantics ----------------------------------------------
|
||||
|
||||
@@ -280,6 +328,66 @@ class TestAclose:
|
||||
tool = FakeTool.instances[0]
|
||||
assert tool.close_count == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_aclose_is_idempotent(self) -> None:
    """Calling ``aclose`` twice raises nothing and closes the tool only once."""
    mcp_handler = DefaultMCPToolHandler()
    with _patch_tool():
        await mcp_handler.invoke_tool(_invocation(headers={"X": "1"}))
        for _ in range(2):
            await mcp_handler.aclose()
    first_tool = FakeTool.instances[0]
    assert first_tool.close_count == 1
|
||||
|
||||
@pytest.mark.asyncio
async def test_invoke_after_close_returns_error_result(self) -> None:
    """``invoke_tool`` on a closed handler returns an error result mentioning "closed"."""
    mcp_handler = DefaultMCPToolHandler()
    with _patch_tool():
        await mcp_handler.aclose()
        outcome = await mcp_handler.invoke_tool(_invocation())
    assert outcome.is_error is True
    message = outcome.error_message or ""
    assert "closed" in message.lower()
|
||||
|
||||
@pytest.mark.asyncio
async def test_aclose_drains_inflight_creation(self) -> None:
    """``aclose`` racing an in-flight ``_create_entry`` must not leak the entry.

    Scenario (the race from PR #5630 review-comment 3): one task starts an
    invocation and parks inside a slow connect, ``aclose`` is started
    concurrently, and only then is the connect released. The test expects
    the single tool that was built to be closed exactly once and the
    invocation to report a closed-handler error rather than orphaning
    the entry.
    """
    mcp_handler = DefaultMCPToolHandler()
    entered_connect = asyncio.Event()
    may_finish_connect = asyncio.Event()
    real_connect = FakeTool.connect

    async def gated_connect(self: FakeTool) -> None:
        # Announce that the connect began, then park until the test releases it.
        entered_connect.set()
        await may_finish_connect.wait()
        await real_connect(self)

    with _patch_tool(), patch.object(FakeTool, "connect", gated_connect):
        pending_invoke = asyncio.create_task(mcp_handler.invoke_tool(_invocation(headers={"X": "1"})))
        # Block until the invoking task is mid-connect.
        await entered_connect.wait()
        # Start the racing shutdown while the connect is still parked.
        pending_close = asyncio.create_task(mcp_handler.aclose())
        # One scheduler turn so the close task gets to run before the release.
        await asyncio.sleep(0)
        # Let the connect finish; the invocation should now observe the close.
        may_finish_connect.set()
        outcome = await pending_invoke
        await pending_close

    # Exactly one tool was built, and it was closed exactly once.
    created = FakeTool.instances
    assert len(created) == 1
    assert created[0].close_count == 1
    # The originating invocation surfaces a closed-handler error.
    assert outcome.is_error is True
    message = outcome.error_message or ""
    assert "closed" in message.lower()
|
||||
|
||||
|
||||
# ---------- Result normalisation ------------------------------------------
|
||||
|
||||
@@ -400,16 +508,36 @@ class TestErrorMapping:
|
||||
|
||||
class TestCacheKey:
    """Unit tests for ``DefaultMCPToolHandler._cache_key``.

    Every call passes the four identity components in order:
    ``(server_url, server_label, connection_name, headers)``.
    """

    def test_key_order_independent(self) -> None:
        """Header insertion order does not affect the key."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "1", "B": "2"})
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"B": "2", "A": "1"})
        assert k1 == k2

    def test_key_distinguishes_values(self) -> None:
        """Different header values produce different keys."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "1"})
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "2"})
        assert k1 != k2

    def test_empty_headers_use_fixed_hash(self) -> None:
        """``None`` and ``{}`` headers produce the same key."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, None)
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {})
        assert k1 == k2

    def test_key_distinguishes_connection_name(self) -> None:
        """Different connection names produce different keys."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, "conn-A", None)
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, "conn-B", None)
        assert k1 != k2

    def test_key_distinguishes_server_label(self) -> None:
        """Different server labels produce different keys."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", "Lbl-A", None, None)
        k2 = DefaultMCPToolHandler._cache_key("https://x/", "Lbl-B", None, None)
        assert k1 != k2

    def test_key_collapses_header_name_case(self) -> None:
        """Header names differing only by case produce the same key."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"Authorization": "tk"})
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"authorization": "tk"})
        assert k1 == k2

    def test_key_keeps_header_value_case(self) -> None:
        """Header values differing only by case produce different keys."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "Bearer-A"})
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "bearer-a"})
        assert k1 != k2
|
||||
|
||||
@@ -629,3 +629,36 @@ class TestProtocol:
|
||||
def test_stub_handler_satisfies_protocol(self) -> None:
    """A ``StubMcpHandler`` instance passes the ``isinstance`` check against ``MCPToolHandler``."""
    assert isinstance(StubMcpHandler(_ok()), MCPToolHandler)
|
||||
|
||||
|
||||
# ---------- _format_outputs_for_send --------------------------------------
|
||||
|
||||
|
||||
class TestFormatOutputsForSend:
    """Table-driven tests that call ``_format_outputs_for_send`` directly.

    Guards the regression from PR #5630 review-comment 4: a lone scalar
    JSON value must come out bare (e.g. ``"42"``), not wrapped in an
    array (``"[42]"``).
    """

    # (parsed results, expected rendered string)
    _CASES = [
        ([], ""),
        (["hello"], "hello"),
        (["a", "b"], "a\nb"),
        ([42], "42"),
        ([3.14], "3.14"),
        ([True], "true"),
        ([False], "false"),
        ([None], "null"),
        ([{"k": "v"}], '{"k": "v"}'),
        ([[1, 2]], "[1, 2]"),
        (["hello", 42], '["hello", 42]'),
        ([{"a": 1}, {"b": 2}], '[{"a": 1}, {"b": 2}]'),
    ]

    @pytest.mark.parametrize(("parsed", "expected"), _CASES)
    def test_format_outputs_for_send(self, parsed: list[Any], expected: str) -> None:
        """Each parsed input renders to exactly the expected string."""
        from agent_framework_declarative._workflows._executors_mcp import _format_outputs_for_send

        rendered = _format_outputs_for_send(parsed)
        assert rendered == expected
|
||||
|
||||
Reference in New Issue
Block a user