Fix cache and PR comments

This commit is contained in:
Peter Ibekwe
2026-05-04 15:08:04 -07:00
Unverified
parent 92cd194122
commit ea0b3c1210
4 changed files with 271 additions and 33 deletions
@@ -163,15 +163,18 @@ def _get_output_path(action_def: Mapping[str, Any], key: str) -> str | None:
def _format_outputs_for_send(parsed_results: list[Any]) -> str:
"""Render parsed MCP outputs to a string for ``ctx.yield_output(...)``.
- Empty list → ``""``.
- All-string list → newline-joined.
- Single dict / list → JSON.
- Empty / mixed → JSON-dump the whole list.
- Single element (any type — scalar, dict, list) → JSON-dumped element.
This avoids surprising ``"[42]"`` / ``"[true]"`` / ``"[null]"`` when
an MCP tool returns a single scalar JSON value.
- Multi-element non-string list → JSON-dump the whole list.
"""
if not parsed_results:
return ""
if all(isinstance(item, str) for item in parsed_results):
return "\n".join(parsed_results) # type: ignore[arg-type]
if len(parsed_results) == 1 and isinstance(parsed_results[0], (dict, list)):
if len(parsed_results) == 1:
return json.dumps(parsed_results[0], ensure_ascii=False)
return json.dumps(parsed_results, ensure_ascii=False)
@@ -158,10 +158,17 @@ class DefaultMCPToolHandler:
"""Default :class:`MCPToolHandler` backed by :class:`agent_framework.MCPStreamableHTTPTool`.
Caches one :class:`agent_framework.MCPStreamableHTTPTool` instance per
``(server_url, server_label, connection_name, headers_hash)`` in a
bounded LRU. The cache prevents re-establishing an MCP session for every
invocation while ensuring different header sets (auth tokens) cannot
share a session — matches the .NET design intent while bounding
cardinality. ``server_label`` and ``connection_name`` participate in
the key so that callers using ``client_provider`` to dispatch on those
fields receive a fresh client per logical connection (see below).
Header *names* are lower-cased inside the hash payload only — the
headers passed on the wire keep the caller's original casing — so two
YAML actions that spell ``Authorization`` differently still share a
cache entry.
Construction modes:
@@ -197,14 +204,17 @@ class DefaultMCPToolHandler:
raise ValueError(f"cache_max_size must be positive, got {cache_max_size}")
self._client_provider = client_provider
self._cache_max_size = cache_max_size
self._cache: OrderedDict[tuple[str, str], _CacheEntry] = OrderedDict()
self._cache: OrderedDict[tuple[str, str, str, str], _CacheEntry] = OrderedDict()
# Outer lock guards the cache + in-flight-future map only — never
# held across network I/O.
self._cache_lock = asyncio.Lock()
# Per-key in-flight futures: while one task is connecting, other
# tasks awaiting the same key will await the same future and share
# the resulting cache entry.
self._inflight: dict[tuple[str, str], asyncio.Future[_CacheEntry]] = {}
self._inflight: dict[tuple[str, str, str, str], asyncio.Future[_CacheEntry]] = {}
# Set by ``aclose`` to prevent post-close cache insertions and to
# reject new ``invoke_tool`` calls. Once set, never cleared.
self._closed = False
async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
"""Invoke ``invocation.tool_name`` on the cached MCP client for the server."""
@@ -279,10 +289,32 @@ class DefaultMCPToolHandler:
Caller-supplied :class:`httpx.AsyncClient` instances (returned by the
``client_provider`` callback) are NOT closed.
Idempotent — a second call returns immediately. Drains any in-flight
``_create_entry`` tasks before returning so their resources are
cleaned up; the in-flight tasks see ``self._closed`` in phase 3 of
:meth:`_get_or_create_entry`, close their own entry, and resolve
their future with ``RuntimeError("DefaultMCPToolHandler is closed")``.
"""
async with self._cache_lock:
if self._closed:
return
self._closed = True
entries = list(self._cache.values())
self._cache.clear()
inflight_futures = list(self._inflight.values())
# Wait for in-flight creations to finish their self-cleanup. Each
# in-flight task self-closes its entry under the closed-flag branch
# in phase 3 and resolves its future with ``RuntimeError``; we
# swallow it here because the failure is expected at shutdown.
for fut in inflight_futures:
try:
await fut
except BaseException:
logger.debug("DefaultMCPToolHandler: in-flight future raised during aclose", exc_info=True)
continue
for entry in entries:
await self._close_entry(entry)
@@ -298,12 +330,19 @@ class DefaultMCPToolHandler:
async def _get_or_create_entry(self, invocation: MCPToolInvocation) -> _CacheEntry:
"""Look up (or create) the cached MCP client for this invocation."""
key = self._cache_key(invocation.server_url, invocation.headers)
key = self._cache_key(
invocation.server_url,
invocation.server_label,
invocation.connection_name,
invocation.headers,
)
# Phase 1: check the cache and either claim creation or wait for an
# already in-flight creation.
creating = False
async with self._cache_lock:
if self._closed:
raise RuntimeError("DefaultMCPToolHandler is closed")
existing = self._cache.get(key)
if existing is not None:
self._cache.move_to_end(key)
@@ -332,25 +371,44 @@ class DefaultMCPToolHandler:
raise
# Phase 3: insert with LRU eviction; resolve the in-flight future.
# If ``aclose`` ran while we were connecting, ``_closed`` is now
# True; don't insert into the cache (it has been drained), close
# the just-built entry, and surface the closed-handler error to
# all awaiters of the future.
evicted: _CacheEntry | None = None
duplicate: _CacheEntry | None = None
handler_closed = False
async with self._cache_lock:
self._inflight.pop(key, None)
existing = self._cache.get(key)
if existing is not None:
# Another writer beat us; prefer the existing entry and
# discard ours after the lock is released.
self._cache.move_to_end(key)
duplicate = entry
entry = existing
if self._closed:
handler_closed = True
else:
self._cache[key] = entry
self._cache.move_to_end(key)
if len(self._cache) > self._cache_max_size:
_evicted_key, evicted = self._cache.popitem(last=False)
if not inflight.done():
inflight.set_result(entry)
existing = self._cache.get(key)
if existing is not None:
# Another writer beat us; prefer the existing entry and
# discard ours after the lock is released.
self._cache.move_to_end(key)
duplicate = entry
entry = existing
else:
self._cache[key] = entry
self._cache.move_to_end(key)
if len(self._cache) > self._cache_max_size:
_evicted_key, evicted = self._cache.popitem(last=False)
if not inflight.done():
inflight.set_result(entry)
if handler_closed:
# Close our orphaned entry; resolve the future with a clear
# error so the caller (and any other awaiters) surface a
# consistent "handler is closed" failure rather than receiving
# an entry we are about to close behind their back.
await self._close_entry(entry)
err = RuntimeError("DefaultMCPToolHandler is closed")
if not inflight.done():
inflight.set_exception(err)
inflight.exception()
raise err
if duplicate is not None:
await self._close_entry(duplicate)
if evicted is not None:
@@ -410,11 +468,27 @@ class DefaultMCPToolHandler:
logger.debug("DefaultMCPToolHandler: error closing owned httpx client", exc_info=True)
@staticmethod
def _cache_key(server_url: str, headers: dict[str, str] | None) -> tuple[str, str]:
"""Build an order-independent cache key for ``(server_url, headers)``."""
def _cache_key(
server_url: str,
server_label: str | None,
connection_name: str | None,
headers: dict[str, str] | None,
) -> tuple[str, str, str, str]:
"""Build an order-independent cache key for the invocation identity.
The key includes ``server_label`` and ``connection_name`` so that
callers using ``client_provider`` to dispatch on those fields
receive a fresh client per logical connection (matches the
documented dispatch contract).
Header *names* are lower-cased inside the hash payload only so
that ``Authorization`` and ``authorization`` map to the same
cache entry. Header values remain case-sensitive (per RFC 7235).
"""
if not headers:
headers_hash = "0"
else:
payload = json.dumps(sorted(headers.items()), ensure_ascii=False)
normalized = sorted((k.lower(), v) for k, v in headers.items())
payload = json.dumps(normalized, ensure_ascii=False)
headers_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest()
return (server_url, headers_hash)
return (server_url, server_label or "", connection_name or "", headers_hash)
@@ -237,6 +237,54 @@ class TestCache:
assert len(FakeTool.instances) == 1
assert FakeTool.instances[0].connect_count == 1
@pytest.mark.asyncio
async def test_different_connection_names_create_separate_entries(self) -> None:
    """Invocations that differ only in ``connection_name`` must each get their own tool."""
    mcp_handler = DefaultMCPToolHandler()
    # NOTE(review): source indentation was unavailable; the assertion is kept
    # inside the patch context — confirm against the original layout.
    with _patch_tool():
        for name in ("conn-A", "conn-B"):
            await mcp_handler.invoke_tool(_invocation(connection_name=name))
        assert len(FakeTool.instances) == 2
@pytest.mark.asyncio
async def test_different_server_labels_create_separate_entries(self) -> None:
    """Invocations that differ only in ``server_label`` must each get their own tool."""
    mcp_handler = DefaultMCPToolHandler()
    # NOTE(review): source indentation was unavailable; the assertion is kept
    # inside the patch context — confirm against the original layout.
    with _patch_tool():
        for label in ("LabelA", "LabelB"):
            await mcp_handler.invoke_tool(_invocation(server_label=label))
        assert len(FakeTool.instances) == 2
@pytest.mark.asyncio
async def test_full_identity_match_hits_cache(self) -> None:
    """Two invocations with identical URL, label, connection name and headers reuse one tool."""
    mcp_handler = DefaultMCPToolHandler()
    # NOTE(review): source indentation was unavailable; the assertions are kept
    # inside the patch context — confirm against the original layout.
    with _patch_tool():
        for _ in range(2):
            await mcp_handler.invoke_tool(
                _invocation(server_label="Lbl", connection_name="C", headers={"X": "1"})
            )
        assert len(FakeTool.instances) == 1
        assert FakeTool.instances[0].connect_count == 1
@pytest.mark.asyncio
async def test_header_name_case_collapses_to_one_cache_entry(self) -> None:
    """Header names that differ only by letter case resolve to the same cached tool."""
    mcp_handler = DefaultMCPToolHandler()
    # NOTE(review): source indentation was unavailable; the assertions are kept
    # inside the patch context — confirm against the original layout.
    with _patch_tool():
        for header_name in ("Authorization", "authorization", "AUTHORIZATION"):
            await mcp_handler.invoke_tool(_invocation(headers={header_name: "tk"}))
        assert len(FakeTool.instances) == 1
        assert FakeTool.instances[0].connect_count == 1
@pytest.mark.asyncio
async def test_header_value_case_does_not_collapse(self) -> None:
    """Header values are compared case-sensitively, so differently-cased tokens get separate tools."""
    mcp_handler = DefaultMCPToolHandler()
    # NOTE(review): source indentation was unavailable; the assertion is kept
    # inside the patch context — confirm against the original layout.
    with _patch_tool():
        for token in ("Bearer-A", "bearer-a"):
            await mcp_handler.invoke_tool(_invocation(headers={"Authorization": token}))
        assert len(FakeTool.instances) == 2
# ---------- Aclose semantics ----------------------------------------------
@@ -280,6 +328,66 @@ class TestAclose:
tool = FakeTool.instances[0]
assert tool.close_count == 1
@pytest.mark.asyncio
async def test_aclose_is_idempotent(self) -> None:
    """Calling ``aclose`` twice neither raises nor closes the cached tool a second time."""
    mcp_handler = DefaultMCPToolHandler()
    # NOTE(review): source indentation was unavailable; the close calls and the
    # assertion are kept inside the patch context — confirm against the original.
    with _patch_tool():
        await mcp_handler.invoke_tool(_invocation(headers={"X": "1"}))
        for _ in range(2):
            await mcp_handler.aclose()
        assert FakeTool.instances[0].close_count == 1
@pytest.mark.asyncio
async def test_invoke_after_close_returns_error_result(self) -> None:
    """``invoke_tool`` on a closed handler yields an error result mentioning "closed"."""
    mcp_handler = DefaultMCPToolHandler()
    # NOTE(review): source indentation was unavailable; the assertions are kept
    # inside the patch context — confirm against the original layout.
    with _patch_tool():
        await mcp_handler.aclose()
        outcome = await mcp_handler.invoke_tool(_invocation())
        assert outcome.is_error is True
        message = (outcome.error_message or "").lower()
        assert "closed" in message
@pytest.mark.asyncio
async def test_aclose_drains_inflight_creation(self) -> None:
    """``aclose`` racing with an in-flight entry creation must not leak the entry.

    Scenario (PR #5630 review-comment 3): one task starts a slow connect,
    a second task calls ``aclose`` while the connect is still pending, and
    only then is the connect allowed to finish. The test expects the tool
    that was being created to be closed exactly once and the originating
    invocation to report a closed-handler error.
    """
    mcp_handler = DefaultMCPToolHandler()
    connect_entered = asyncio.Event()
    connect_gate = asyncio.Event()
    real_connect = FakeTool.connect

    async def gated_connect(self: FakeTool) -> None:
        # Signal that the connect has begun, then block until the test opens the gate.
        connect_entered.set()
        await connect_gate.wait()
        await real_connect(self)

    # NOTE(review): source indentation was unavailable; the assertions are kept
    # inside the patch context — confirm against the original layout.
    with _patch_tool(), patch.object(FakeTool, "connect", gated_connect):
        pending_invoke = asyncio.create_task(mcp_handler.invoke_tool(_invocation(headers={"X": "1"})))
        # Block until the invoking task is parked inside the gated connect.
        await connect_entered.wait()

        # Start the competing close while the connect is still pending.
        pending_close = asyncio.create_task(mcp_handler.aclose())
        # One scheduler turn so the close task gets a chance to run.
        await asyncio.sleep(0)

        # Let the connect finish, then collect both tasks.
        connect_gate.set()
        outcome = await pending_invoke
        await pending_close

        # Exactly one tool was built, and it was closed exactly once.
        assert len(FakeTool.instances) == 1
        assert FakeTool.instances[0].close_count == 1
        # The invocation that started the connect reports the closed handler.
        assert outcome.is_error is True
        assert "closed" in (outcome.error_message or "").lower()
# ---------- Result normalisation ------------------------------------------
@@ -400,16 +508,36 @@ class TestErrorMapping:
class TestCacheKey:
    """Unit tests for ``DefaultMCPToolHandler._cache_key``.

    Every call passes all four identity components:
    ``(server_url, server_label, connection_name, headers)``.
    """

    def test_key_order_independent(self) -> None:
        """Header insertion order does not affect the key."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "1", "B": "2"})
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"B": "2", "A": "1"})
        assert k1 == k2

    def test_key_distinguishes_values(self) -> None:
        """Different header values produce different keys."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "1"})
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "2"})
        assert k1 != k2

    def test_empty_headers_use_fixed_hash(self) -> None:
        """``None`` and ``{}`` headers produce the same key."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, None)
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {})
        assert k1 == k2

    def test_key_distinguishes_connection_name(self) -> None:
        """Different ``connection_name`` values produce different keys."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, "conn-A", None)
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, "conn-B", None)
        assert k1 != k2

    def test_key_distinguishes_server_label(self) -> None:
        """Different ``server_label`` values produce different keys."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", "Lbl-A", None, None)
        k2 = DefaultMCPToolHandler._cache_key("https://x/", "Lbl-B", None, None)
        assert k1 != k2

    def test_key_collapses_header_name_case(self) -> None:
        """Header names differing only by case produce the same key."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"Authorization": "tk"})
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"authorization": "tk"})
        assert k1 == k2

    def test_key_keeps_header_value_case(self) -> None:
        """Header values differing only by case produce different keys."""
        k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "Bearer-A"})
        k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "bearer-a"})
        assert k1 != k2
@@ -629,3 +629,36 @@ class TestProtocol:
def test_stub_handler_satisfies_protocol(self) -> None:
    """A ``StubMcpHandler`` instance passes the ``isinstance`` check against ``MCPToolHandler``."""
    stub = StubMcpHandler(_ok())
    assert isinstance(stub, MCPToolHandler)
# ---------- _format_outputs_for_send --------------------------------------
class TestFormatOutputsForSend:
    """Table-driven tests for the auto-send rendering helper.

    Regression for PR #5630 review-comment 4: a lone scalar JSON value is
    expected to render bare (e.g. ``"42"``) rather than list-wrapped
    (``"[42]"``).
    """

    # (parsed input, expected rendering) pairs fed to the parametrized test.
    _CASES = [
        ([], ""),
        (["hello"], "hello"),
        (["a", "b"], "a\nb"),
        ([42], "42"),
        ([3.14], "3.14"),
        ([True], "true"),
        ([False], "false"),
        ([None], "null"),
        ([{"k": "v"}], '{"k": "v"}'),
        ([[1, 2]], "[1, 2]"),
        (["hello", 42], '["hello", 42]'),
        ([{"a": 1}, {"b": 2}], '[{"a": 1}, {"b": 2}]'),
    ]

    @pytest.mark.parametrize(("parsed", "expected"), _CASES)
    def test_format_outputs_for_send(self, parsed: list[Any], expected: str) -> None:
        """Each parsed input renders to exactly its expected string."""
        from agent_framework_declarative._workflows._executors_mcp import _format_outputs_for_send

        rendered = _format_outputs_for_send(parsed)
        assert rendered == expected