Python: run sync tools off the event loop (#5773)

* fix: run sync tools off event loop

* chore: silence harness tool marker type check
This commit is contained in:
Yufeng He
2026-06-04 12:42:08 +08:00
committed by GitHub
Unverified
parent c3901a4ddd
commit f29bae8fbc
3 changed files with 59 additions and 4 deletions
@@ -349,6 +349,8 @@ class BackgroundAgentsProvider(ContextProvider):
_save_provider_state(session, provider_state, source_id=source_id)
return f"Background task {task_id} started on agent '{agent_name}'."
background_agents_start_task._invoke_sync_on_event_loop = True # pyright: ignore[reportPrivateUsage]
@tool(name="background_agents_wait_for_first_completion", approval_mode="never_require")
async def background_agents_wait_for_first_completion(task_ids: list[int]) -> str:
"""Block until the first of the specified background tasks completes. Returns the completed task's ID."""
@@ -471,6 +473,8 @@ class BackgroundAgentsProvider(ContextProvider):
_save_provider_state(session, provider_state, source_id=source_id)
return f"Task {task_id} continued with new input."
background_agents_continue_task._invoke_sync_on_event_loop = True # pyright: ignore[reportPrivateUsage]
@tool(name="background_agents_clear_completed_task", approval_mode="never_require")
def background_agents_clear_completed_task(task_id: int) -> str:
"""Remove a completed or failed task and release its session to free memory."""
+14 -4
View File
@@ -292,6 +292,7 @@ class FunctionTool(SerializationMixin):
"_cached_parameters",
"_input_schema",
"_schema_supplied",
"_invoke_sync_on_event_loop",
}
def __init__(
@@ -366,6 +367,7 @@ class FunctionTool(SerializationMixin):
self.description = description
self.kind = kind
self.additional_properties = additional_properties
self._invoke_sync_on_event_loop = False
for key, value in kwargs.items():
setattr(self, key, value)
@@ -537,6 +539,16 @@ class FunctionTool(SerializationMixin):
self.invocation_exception_count += 1
raise
async def _invoke_function(self, call_kwargs: Mapping[str, Any]) -> Any:
"""Run sync tools off the event loop during async invocation."""
func = self.func.func if isinstance(self.func, FunctionTool) else self.func
if inspect.iscoroutinefunction(func) or getattr(self, "_invoke_sync_on_event_loop", False):
res = self.__call__(**call_kwargs)
return await res if inspect.isawaitable(res) else res
res = await asyncio.to_thread(self.__call__, **call_kwargs)
return await res if inspect.isawaitable(res) else res
@overload
async def invoke(
self,
@@ -679,8 +691,7 @@ class FunctionTool(SerializationMixin):
if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined]
logger.info(f"Function name: {self.name}")
logger.debug(f"Function arguments: {observable_kwargs}")
res = self.__call__(**call_kwargs)
result = await res if inspect.isawaitable(res) else res
result = await self._invoke_function(call_kwargs)
if skip_parsing:
logger.info(f"Function {self.name} succeeded.")
logger.debug(f"Function result: {type(result).__name__}")
@@ -730,8 +741,7 @@ class FunctionTool(SerializationMixin):
start_time_stamp = perf_counter()
end_time_stamp: float | None = None
try:
res = self.__call__(**call_kwargs)
result = await res if inspect.isawaitable(res) else res
result = await self._invoke_function(call_kwargs)
end_time_stamp = perf_counter()
except Exception as exception:
end_time_stamp = perf_counter()
@@ -1,4 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import threading
from typing import Annotated, Any, Literal, get_args, get_origin
from unittest.mock import Mock
@@ -1346,6 +1348,45 @@ async def test_invoke_skip_parsing_awaits_async_functions() -> None:
assert raw == 42
async def test_invoke_sync_tool_does_not_block_event_loop() -> None:
release_tool = threading.Event()
tool_thread_ids: list[int] = []
event_loop_thread_id = threading.get_ident()
@tool
def wait_for_release() -> str:
tool_thread_ids.append(threading.get_ident())
return "released" if release_tool.wait(timeout=0.2) else "timed out"
async def release_soon() -> None:
await asyncio.sleep(0.01)
release_tool.set()
tool_task = asyncio.create_task(wait_for_release.invoke(skip_parsing=True))
release_task = asyncio.create_task(release_soon())
assert await asyncio.wait_for(tool_task, timeout=1) == "released"
await release_task
assert tool_thread_ids
assert tool_thread_ids[0] != event_loop_thread_id
async def test_invoke_sync_tool_can_stay_on_event_loop() -> None:
event_loop_thread_id = threading.get_ident()
tool_thread_ids: list[int] = []
@tool
def needs_event_loop() -> str:
tool_thread_ids.append(threading.get_ident())
asyncio.get_running_loop()
return "ok"
needs_event_loop._invoke_sync_on_event_loop = True
assert await needs_event_loop.invoke(skip_parsing=True) == "ok"
assert tool_thread_ids == [event_loop_thread_id]
async def test_invoke_skip_parsing_bypasses_configured_result_parser() -> None:
"""The tool's own result_parser is bypassed when skip_parsing=True is requested."""
parser_calls: list[Any] = []