From f29bae8fbc9e3057f9d14cd223ddcc5a8fae3fd4 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Thu, 4 Jun 2026 12:42:08 +0800 Subject: [PATCH] Python: run sync tools off the event loop (#5773) * fix: run sync tools off event loop * chore: silence harness tool marker type check --- .../_harness/_background_agents.py | 4 ++ .../packages/core/agent_framework/_tools.py | 18 ++++++-- python/packages/core/tests/core/test_tools.py | 41 +++++++++++++++++++ 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/python/packages/core/agent_framework/_harness/_background_agents.py b/python/packages/core/agent_framework/_harness/_background_agents.py index c329af4aa9..c5efa1a6fb 100644 --- a/python/packages/core/agent_framework/_harness/_background_agents.py +++ b/python/packages/core/agent_framework/_harness/_background_agents.py @@ -349,6 +349,8 @@ class BackgroundAgentsProvider(ContextProvider): _save_provider_state(session, provider_state, source_id=source_id) return f"Background task {task_id} started on agent '{agent_name}'." + background_agents_start_task._invoke_sync_on_event_loop = True # pyright: ignore[reportPrivateUsage] + @tool(name="background_agents_wait_for_first_completion", approval_mode="never_require") async def background_agents_wait_for_first_completion(task_ids: list[int]) -> str: """Block until the first of the specified background tasks completes. Returns the completed task's ID.""" @@ -471,6 +473,8 @@ class BackgroundAgentsProvider(ContextProvider): _save_provider_state(session, provider_state, source_id=source_id) return f"Task {task_id} continued with new input." + background_agents_continue_task._invoke_sync_on_event_loop = True # pyright: ignore[reportPrivateUsage] + @tool(name="background_agents_clear_completed_task", approval_mode="never_require") def background_agents_clear_completed_task(task_id: int) -> str: """Remove a completed or failed task and release its session to free memory.""" diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 5237cf62ba..7bb54ee2c9 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -292,6 +292,7 @@ class FunctionTool(SerializationMixin): "_cached_parameters", "_input_schema", "_schema_supplied", + "_invoke_sync_on_event_loop", } def __init__( @@ -366,6 +367,7 @@ class FunctionTool(SerializationMixin): self.description = description self.kind = kind self.additional_properties = additional_properties + self._invoke_sync_on_event_loop = False for key, value in kwargs.items(): setattr(self, key, value) @@ -537,6 +539,16 @@ class FunctionTool(SerializationMixin): self.invocation_exception_count += 1 raise + async def _invoke_function(self, call_kwargs: Mapping[str, Any]) -> Any: + """Run sync tools off the event loop during async invocation.""" + func = self.func.func if isinstance(self.func, FunctionTool) else self.func + if inspect.iscoroutinefunction(func) or getattr(self, "_invoke_sync_on_event_loop", False): + res = self.__call__(**call_kwargs) + return await res if inspect.isawaitable(res) else res + + res = await asyncio.to_thread(self.__call__, **call_kwargs) + return await res if inspect.isawaitable(res) else res + @overload async def invoke( self, @@ -679,8 +691,7 @@ class FunctionTool(SerializationMixin): if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined] logger.info(f"Function name: {self.name}") logger.debug(f"Function arguments: {observable_kwargs}") - res = self.__call__(**call_kwargs) - result = await res if inspect.isawaitable(res) else res + result = await self._invoke_function(call_kwargs) if skip_parsing: logger.info(f"Function {self.name} succeeded.") logger.debug(f"Function result: {type(result).__name__}") @@ -730,8 +741,7 @@ class FunctionTool(SerializationMixin): start_time_stamp = perf_counter() end_time_stamp: float | None = None try: - res = self.__call__(**call_kwargs) - result = await res if inspect.isawaitable(res) else res + result = await self._invoke_function(call_kwargs) end_time_stamp = perf_counter() except Exception as exception: end_time_stamp = perf_counter() diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index b3762bf4ef..f44cbc267a 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -1,4 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import asyncio +import threading from typing import Annotated, Any, Literal, get_args, get_origin from unittest.mock import Mock @@ -1346,6 +1348,45 @@ async def test_invoke_skip_parsing_awaits_async_functions() -> None: assert raw == 42 +async def test_invoke_sync_tool_does_not_block_event_loop() -> None: + release_tool = threading.Event() + tool_thread_ids: list[int] = [] + event_loop_thread_id = threading.get_ident() + + @tool + def wait_for_release() -> str: + tool_thread_ids.append(threading.get_ident()) + return "released" if release_tool.wait(timeout=0.2) else "timed out" + + async def release_soon() -> None: + await asyncio.sleep(0.01) + release_tool.set() + + tool_task = asyncio.create_task(wait_for_release.invoke(skip_parsing=True)) + release_task = asyncio.create_task(release_soon()) + + assert await asyncio.wait_for(tool_task, timeout=1) == "released" + await release_task + assert tool_thread_ids + assert tool_thread_ids[0] != event_loop_thread_id + + +async def test_invoke_sync_tool_can_stay_on_event_loop() -> None: + event_loop_thread_id = threading.get_ident() + tool_thread_ids: list[int] = [] + + @tool + def needs_event_loop() -> str: + tool_thread_ids.append(threading.get_ident()) + asyncio.get_running_loop() + return "ok" + + needs_event_loop._invoke_sync_on_event_loop = True + + assert await needs_event_loop.invoke(skip_parsing=True) == "ok" + assert tool_thread_ids == [event_loop_thread_id] + + async def test_invoke_skip_parsing_bypasses_configured_result_parser() -> None: """The tool's own result_parser is bypassed when skip_parsing=True is requested.""" parser_calls: list[Any] = []