From ea0b3c12107e128aec7af08b87b1d8dd13762fcf Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Mon, 4 May 2026 15:08:04 -0700 Subject: [PATCH] Fix cache and PR comments --- .../_workflows/_executors_mcp.py | 9 +- .../_workflows/_mcp_handler.py | 122 ++++++++++++--- .../tests/test_default_mcp_tool_handler.py | 140 +++++++++++++++++- .../tests/test_invoke_mcp_tool_executor.py | 33 +++++ 4 files changed, 271 insertions(+), 33 deletions(-) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py index a1501a3cb1..73b66341ea 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py @@ -163,15 +163,18 @@ def _get_output_path(action_def: Mapping[str, Any], key: str) -> str | None: def _format_outputs_for_send(parsed_results: list[Any]) -> str: """Render parsed MCP outputs to a string for ``ctx.yield_output(...)``. + - Empty list → ``""``. - All-string list → newline-joined. - - Single dict / list → JSON. - - Empty / mixed → JSON-dump the whole list. + - Single element (any type — scalar, dict, list) → JSON-dumped element. + This avoids surprising ``"[42]"`` / ``"[true]"`` / ``"[null]"`` when + an MCP tool returns a single scalar JSON value. + - Multi-element non-string list → JSON-dump the whole list. """ if not parsed_results: return "" if all(isinstance(item, str) for item in parsed_results): return "\n".join(parsed_results) # type: ignore[arg-type] - if len(parsed_results) == 1 and isinstance(parsed_results[0], (dict, list)): + if len(parsed_results) == 1: return json.dumps(parsed_results[0], ensure_ascii=False) return json.dumps(parsed_results, ensure_ascii=False) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_mcp_handler.py b/python/packages/declarative/agent_framework_declarative/_workflows/_mcp_handler.py index 5ea17aabc1..658ce42c23 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_mcp_handler.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_mcp_handler.py @@ -158,10 +158,17 @@ class DefaultMCPToolHandler: """Default :class:`MCPToolHandler` backed by :class:`agent_framework.MCPStreamableHTTPTool`. Caches one :class:`agent_framework.MCPStreamableHTTPTool` instance per - ``(server_url, headers_hash)`` in a bounded LRU. The cache prevents - re-establishing an MCP session for every invocation while ensuring - different header sets (auth tokens) cannot share a session — matches the - .NET design intent while bounding cardinality. + ``(server_url, server_label, connection_name, headers_hash)`` in a + bounded LRU. The cache prevents re-establishing an MCP session for every + invocation while ensuring different header sets (auth tokens) cannot + share a session — matches the .NET design intent while bounding + cardinality. ``server_label`` and ``connection_name`` participate in + the key so that callers using ``client_provider`` to dispatch on those + fields receive a fresh client per logical connection (see below). + Header *names* are lower-cased inside the hash payload only — the + headers passed on the wire keep the caller's original casing — so two + YAML actions that spell ``Authorization`` differently still share a + cache entry. Construction modes: @@ -197,14 +204,17 @@ class DefaultMCPToolHandler: raise ValueError(f"cache_max_size must be positive, got {cache_max_size}") self._client_provider = client_provider self._cache_max_size = cache_max_size - self._cache: OrderedDict[tuple[str, str], _CacheEntry] = OrderedDict() + self._cache: OrderedDict[tuple[str, str, str, str], _CacheEntry] = OrderedDict() # Outer lock guards the cache + in-flight-future map only — never # held across network I/O. self._cache_lock = asyncio.Lock() # Per-key in-flight futures: while one task is connecting, other # tasks awaiting the same key will await the same future and share # the resulting cache entry. - self._inflight: dict[tuple[str, str], asyncio.Future[_CacheEntry]] = {} + self._inflight: dict[tuple[str, str, str, str], asyncio.Future[_CacheEntry]] = {} + # Set by ``aclose`` to prevent post-close cache insertions and to + # reject new ``invoke_tool`` calls. Once set, never cleared. + self._closed = False async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult: """Invoke ``invocation.tool_name`` on the cached MCP client for the server.""" @@ -279,10 +289,32 @@ class DefaultMCPToolHandler: Caller-supplied :class:`httpx.AsyncClient` instances (returned by the ``client_provider`` callback) are NOT closed. + + Idempotent — a second call returns immediately. Drains any in-flight + ``_create_entry`` tasks before returning so their resources are + cleaned up; the in-flight tasks see ``self._closed`` in phase 3 of + :meth:`_get_or_create_entry`, close their own entry, and resolve + their future with ``RuntimeError("DefaultMCPToolHandler is closed")``. """ async with self._cache_lock: + if self._closed: + return + self._closed = True entries = list(self._cache.values()) self._cache.clear() + inflight_futures = list(self._inflight.values()) + + # Wait for in-flight creations to finish their self-cleanup. Each + # in-flight task self-closes its entry under the closed-flag branch + # in phase 3 and resolves its future with ``RuntimeError``; we + # swallow it here because the failure is expected at shutdown. + for fut in inflight_futures: + try: + await fut + except BaseException: + logger.debug("DefaultMCPToolHandler: in-flight future raised during aclose", exc_info=True) + continue + for entry in entries: await self._close_entry(entry) @@ -298,12 +330,19 @@ class DefaultMCPToolHandler: async def _get_or_create_entry(self, invocation: MCPToolInvocation) -> _CacheEntry: """Look up (or create) the cached MCP client for this invocation.""" - key = self._cache_key(invocation.server_url, invocation.headers) + key = self._cache_key( + invocation.server_url, + invocation.server_label, + invocation.connection_name, + invocation.headers, + ) # Phase 1: check the cache and either claim creation or wait for an # already in-flight creation. creating = False async with self._cache_lock: + if self._closed: + raise RuntimeError("DefaultMCPToolHandler is closed") existing = self._cache.get(key) if existing is not None: self._cache.move_to_end(key) @@ -332,25 +371,44 @@ class DefaultMCPToolHandler: raise # Phase 3: insert with LRU eviction; resolve the in-flight future. + # If ``aclose`` ran while we were connecting, ``_closed`` is now + # True; don't insert into the cache (it has been drained), close + # the just-built entry, and surface the closed-handler error to + # all awaiters of the future. evicted: _CacheEntry | None = None duplicate: _CacheEntry | None = None + handler_closed = False async with self._cache_lock: self._inflight.pop(key, None) - existing = self._cache.get(key) - if existing is not None: - # Another writer beat us; prefer the existing entry and - # discard ours after the lock is released. - self._cache.move_to_end(key) - duplicate = entry - entry = existing + if self._closed: + handler_closed = True else: - self._cache[key] = entry - self._cache.move_to_end(key) - if len(self._cache) > self._cache_max_size: - _evicted_key, evicted = self._cache.popitem(last=False) - if not inflight.done(): - inflight.set_result(entry) + existing = self._cache.get(key) + if existing is not None: + # Another writer beat us; prefer the existing entry and + # discard ours after the lock is released. + self._cache.move_to_end(key) + duplicate = entry + entry = existing + else: + self._cache[key] = entry + self._cache.move_to_end(key) + if len(self._cache) > self._cache_max_size: + _evicted_key, evicted = self._cache.popitem(last=False) + if not inflight.done(): + inflight.set_result(entry) + if handler_closed: + # Close our orphaned entry; resolve the future with a clear + # error so the caller (and any other awaiters) surface a + # consistent "handler is closed" failure rather than receiving + # an entry we are about to close behind their back. + await self._close_entry(entry) + err = RuntimeError("DefaultMCPToolHandler is closed") + if not inflight.done(): + inflight.set_exception(err) + inflight.exception() + raise err if duplicate is not None: await self._close_entry(duplicate) if evicted is not None: @@ -410,11 +468,27 @@ class DefaultMCPToolHandler: logger.debug("DefaultMCPToolHandler: error closing owned httpx client", exc_info=True) @staticmethod - def _cache_key(server_url: str, headers: dict[str, str] | None) -> tuple[str, str]: - """Build an order-independent cache key for ``(server_url, headers)``.""" + def _cache_key( + server_url: str, + server_label: str | None, + connection_name: str | None, + headers: dict[str, str] | None, + ) -> tuple[str, str, str, str]: + """Build an order-independent cache key for the invocation identity. + + The key includes ``server_label`` and ``connection_name`` so that + callers using ``client_provider`` to dispatch on those fields + receive a fresh client per logical connection (matches the + documented dispatch contract). + + Header *names* are lower-cased inside the hash payload only so + that ``Authorization`` and ``authorization`` map to the same + cache entry. Header values remain case-sensitive (per RFC 7235). + """ if not headers: headers_hash = "0" else: - payload = json.dumps(sorted(headers.items()), ensure_ascii=False) + normalized = sorted((k.lower(), v) for k, v in headers.items()) + payload = json.dumps(normalized, ensure_ascii=False) headers_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest() - return (server_url, headers_hash) + return (server_url, server_label or "", connection_name or "", headers_hash) diff --git a/python/packages/declarative/tests/test_default_mcp_tool_handler.py b/python/packages/declarative/tests/test_default_mcp_tool_handler.py index c40d275e80..3a5c67e1d6 100644 --- a/python/packages/declarative/tests/test_default_mcp_tool_handler.py +++ b/python/packages/declarative/tests/test_default_mcp_tool_handler.py @@ -237,6 +237,54 @@ class TestCache: assert len(FakeTool.instances) == 1 assert FakeTool.instances[0].connect_count == 1 + @pytest.mark.asyncio + async def test_different_connection_names_create_separate_entries(self) -> None: + """Same URL/headers but different ``connection_name`` must dispatch separately.""" + handler = DefaultMCPToolHandler() + with _patch_tool(): + await handler.invoke_tool(_invocation(connection_name="conn-A")) + await handler.invoke_tool(_invocation(connection_name="conn-B")) + assert len(FakeTool.instances) == 2 + + @pytest.mark.asyncio + async def test_different_server_labels_create_separate_entries(self) -> None: + """Same URL/headers but different ``server_label`` must dispatch separately.""" + handler = DefaultMCPToolHandler() + with _patch_tool(): + await handler.invoke_tool(_invocation(server_label="LabelA")) + await handler.invoke_tool(_invocation(server_label="LabelB")) + assert len(FakeTool.instances) == 2 + + @pytest.mark.asyncio + async def test_full_identity_match_hits_cache(self) -> None: + """All four identity components match → single cached entry.""" + handler = DefaultMCPToolHandler() + with _patch_tool(): + await handler.invoke_tool(_invocation(server_label="Lbl", connection_name="C", headers={"X": "1"})) + await handler.invoke_tool(_invocation(server_label="Lbl", connection_name="C", headers={"X": "1"})) + assert len(FakeTool.instances) == 1 + assert FakeTool.instances[0].connect_count == 1 + + @pytest.mark.asyncio + async def test_header_name_case_collapses_to_one_cache_entry(self) -> None: + """Header name spelling differences (case-only) must share a cache entry.""" + handler = DefaultMCPToolHandler() + with _patch_tool(): + await handler.invoke_tool(_invocation(headers={"Authorization": "tk"})) + await handler.invoke_tool(_invocation(headers={"authorization": "tk"})) + await handler.invoke_tool(_invocation(headers={"AUTHORIZATION": "tk"})) + assert len(FakeTool.instances) == 1 + assert FakeTool.instances[0].connect_count == 1 + + @pytest.mark.asyncio + async def test_header_value_case_does_not_collapse(self) -> None: + """Header *values* remain case-sensitive (different tokens → different sessions).""" + handler = DefaultMCPToolHandler() + with _patch_tool(): + await handler.invoke_tool(_invocation(headers={"Authorization": "Bearer-A"})) + await handler.invoke_tool(_invocation(headers={"Authorization": "bearer-a"})) + assert len(FakeTool.instances) == 2 + # ---------- Aclose semantics ---------------------------------------------- @@ -280,6 +328,66 @@ class TestAclose: tool = FakeTool.instances[0] assert tool.close_count == 1 + @pytest.mark.asyncio + async def test_aclose_is_idempotent(self) -> None: + """A second ``aclose`` is a no-op (no exception, no double-close).""" + handler = DefaultMCPToolHandler() + with _patch_tool(): + await handler.invoke_tool(_invocation(headers={"X": "1"})) + await handler.aclose() + await handler.aclose() + assert FakeTool.instances[0].close_count == 1 + + @pytest.mark.asyncio + async def test_invoke_after_close_returns_error_result(self) -> None: + """Post-close ``invoke_tool`` surfaces a tool error rather than crashing.""" + handler = DefaultMCPToolHandler() + with _patch_tool(): + await handler.aclose() + result = await handler.invoke_tool(_invocation()) + assert result.is_error is True + assert "closed" in (result.error_message or "").lower() + + @pytest.mark.asyncio + async def test_aclose_drains_inflight_creation(self) -> None: + """An in-flight ``_create_entry`` must not leak when ``aclose`` races with it. + + Reproduces the race described in PR #5630 review-comment 3: + task A claims an inflight future and starts a slow connect; task B + runs ``aclose``; task A must self-clean (close its tool + httpx + client) and surface a closed-handler error rather than orphaning + the entry. + """ + handler = DefaultMCPToolHandler() + connect_started = asyncio.Event() + release_connect = asyncio.Event() + original_connect = FakeTool.connect + + async def gated_connect(self: FakeTool) -> None: + connect_started.set() + await release_connect.wait() + await original_connect(self) + + with _patch_tool(), patch.object(FakeTool, "connect", gated_connect): + invoke_task = asyncio.create_task(handler.invoke_tool(_invocation(headers={"X": "1"}))) + # Wait until task A is mid-connect. + await connect_started.wait() + # Race: kick off aclose. It must wait for the in-flight task. + close_task = asyncio.create_task(handler.aclose()) + # Yield once to ensure aclose has set _closed and is awaiting. + await asyncio.sleep(0) + # Allow the connect to complete; phase 3 sees _closed and self-cleans. + release_connect.set() + result = await invoke_task + await close_task + + # Entry was created and then closed by the in-flight task itself. + assert len(FakeTool.instances) == 1 + assert FakeTool.instances[0].close_count == 1 + # The originating invocation surfaces a closed-handler error. + assert result.is_error is True + assert "closed" in (result.error_message or "").lower() + # ---------- Result normalisation ------------------------------------------ @@ -400,16 +508,36 @@ class TestErrorMapping: class TestCacheKey: def test_key_order_independent(self) -> None: - k1 = DefaultMCPToolHandler._cache_key("https://x/", {"A": "1", "B": "2"}) - k2 = DefaultMCPToolHandler._cache_key("https://x/", {"B": "2", "A": "1"}) + k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "1", "B": "2"}) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"B": "2", "A": "1"}) assert k1 == k2 def test_key_distinguishes_values(self) -> None: - k1 = DefaultMCPToolHandler._cache_key("https://x/", {"A": "1"}) - k2 = DefaultMCPToolHandler._cache_key("https://x/", {"A": "2"}) + k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "1"}) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "2"}) assert k1 != k2 def test_empty_headers_use_fixed_hash(self) -> None: - k1 = DefaultMCPToolHandler._cache_key("https://x/", None) - k2 = DefaultMCPToolHandler._cache_key("https://x/", {}) + k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, None) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {}) assert k1 == k2 + + def test_key_distinguishes_connection_name(self) -> None: + k1 = DefaultMCPToolHandler._cache_key("https://x/", None, "conn-A", None) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, "conn-B", None) + assert k1 != k2 + + def test_key_distinguishes_server_label(self) -> None: + k1 = DefaultMCPToolHandler._cache_key("https://x/", "Lbl-A", None, None) + k2 = DefaultMCPToolHandler._cache_key("https://x/", "Lbl-B", None, None) + assert k1 != k2 + + def test_key_collapses_header_name_case(self) -> None: + k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"Authorization": "tk"}) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"authorization": "tk"}) + assert k1 == k2 + + def test_key_keeps_header_value_case(self) -> None: + k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "Bearer-A"}) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "bearer-a"}) + assert k1 != k2 diff --git a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py index 867b8543bb..fdee1f7df1 100644 --- a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py +++ b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py @@ -629,3 +629,36 @@ class TestProtocol: def test_stub_handler_satisfies_protocol(self) -> None: handler = StubMcpHandler(_ok()) assert isinstance(handler, MCPToolHandler) + + +# ---------- _format_outputs_for_send -------------------------------------- + + +class TestFormatOutputsForSend: + """Direct tests for the auto-send rendering helper. + + Regression for PR #5630 review-comment 4: a single scalar JSON value + must render bare (e.g. ``"42"``) rather than wrapped (``"[42]"``). + """ + + @pytest.mark.parametrize( + ("parsed", "expected"), + [ + ([], ""), + (["hello"], "hello"), + (["a", "b"], "a\nb"), + ([42], "42"), + ([3.14], "3.14"), + ([True], "true"), + ([False], "false"), + ([None], "null"), + ([{"k": "v"}], '{"k": "v"}'), + ([[1, 2]], "[1, 2]"), + (["hello", 42], '["hello", 42]'), + ([{"a": 1}, {"b": 2}], '[{"a": 1}, {"b": 2}]'), + ], + ) + def test_format_outputs_for_send(self, parsed: list[Any], expected: str) -> None: + from agent_framework_declarative._workflows._executors_mcp import _format_outputs_for_send + + assert _format_outputs_for_send(parsed) == expected