mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: run sync tools off the event loop (#5773)
* fix: run sync tools off event loop * chore: silence harness tool marker type check
This commit is contained in:
committed by
GitHub
Unverified
parent
c3901a4ddd
commit
f29bae8fbc
@@ -349,6 +349,8 @@ class BackgroundAgentsProvider(ContextProvider):
|
||||
_save_provider_state(session, provider_state, source_id=source_id)
|
||||
return f"Background task {task_id} started on agent '{agent_name}'."
|
||||
|
||||
background_agents_start_task._invoke_sync_on_event_loop = True # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
@tool(name="background_agents_wait_for_first_completion", approval_mode="never_require")
|
||||
async def background_agents_wait_for_first_completion(task_ids: list[int]) -> str:
|
||||
"""Block until the first of the specified background tasks completes. Returns the completed task's ID."""
|
||||
@@ -471,6 +473,8 @@ class BackgroundAgentsProvider(ContextProvider):
|
||||
_save_provider_state(session, provider_state, source_id=source_id)
|
||||
return f"Task {task_id} continued with new input."
|
||||
|
||||
background_agents_continue_task._invoke_sync_on_event_loop = True # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
@tool(name="background_agents_clear_completed_task", approval_mode="never_require")
|
||||
def background_agents_clear_completed_task(task_id: int) -> str:
|
||||
"""Remove a completed or failed task and release its session to free memory."""
|
||||
|
||||
@@ -292,6 +292,7 @@ class FunctionTool(SerializationMixin):
|
||||
"_cached_parameters",
|
||||
"_input_schema",
|
||||
"_schema_supplied",
|
||||
"_invoke_sync_on_event_loop",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@@ -366,6 +367,7 @@ class FunctionTool(SerializationMixin):
|
||||
self.description = description
|
||||
self.kind = kind
|
||||
self.additional_properties = additional_properties
|
||||
self._invoke_sync_on_event_loop = False
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
@@ -537,6 +539,16 @@ class FunctionTool(SerializationMixin):
|
||||
self.invocation_exception_count += 1
|
||||
raise
|
||||
|
||||
async def _invoke_function(self, call_kwargs: Mapping[str, Any]) -> Any:
|
||||
"""Run sync tools off the event loop during async invocation."""
|
||||
func = self.func.func if isinstance(self.func, FunctionTool) else self.func
|
||||
if inspect.iscoroutinefunction(func) or getattr(self, "_invoke_sync_on_event_loop", False):
|
||||
res = self.__call__(**call_kwargs)
|
||||
return await res if inspect.isawaitable(res) else res
|
||||
|
||||
res = await asyncio.to_thread(self.__call__, **call_kwargs)
|
||||
return await res if inspect.isawaitable(res) else res
|
||||
|
||||
@overload
|
||||
async def invoke(
|
||||
self,
|
||||
@@ -679,8 +691,7 @@ class FunctionTool(SerializationMixin):
|
||||
if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined]
|
||||
logger.info(f"Function name: {self.name}")
|
||||
logger.debug(f"Function arguments: {observable_kwargs}")
|
||||
res = self.__call__(**call_kwargs)
|
||||
result = await res if inspect.isawaitable(res) else res
|
||||
result = await self._invoke_function(call_kwargs)
|
||||
if skip_parsing:
|
||||
logger.info(f"Function {self.name} succeeded.")
|
||||
logger.debug(f"Function result: {type(result).__name__}")
|
||||
@@ -730,8 +741,7 @@ class FunctionTool(SerializationMixin):
|
||||
start_time_stamp = perf_counter()
|
||||
end_time_stamp: float | None = None
|
||||
try:
|
||||
res = self.__call__(**call_kwargs)
|
||||
result = await res if inspect.isawaitable(res) else res
|
||||
result = await self._invoke_function(call_kwargs)
|
||||
end_time_stamp = perf_counter()
|
||||
except Exception as exception:
|
||||
end_time_stamp = perf_counter()
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Annotated, Any, Literal, get_args, get_origin
|
||||
from unittest.mock import Mock
|
||||
|
||||
@@ -1346,6 +1348,45 @@ async def test_invoke_skip_parsing_awaits_async_functions() -> None:
|
||||
assert raw == 42
|
||||
|
||||
|
||||
async def test_invoke_sync_tool_does_not_block_event_loop() -> None:
|
||||
release_tool = threading.Event()
|
||||
tool_thread_ids: list[int] = []
|
||||
event_loop_thread_id = threading.get_ident()
|
||||
|
||||
@tool
|
||||
def wait_for_release() -> str:
|
||||
tool_thread_ids.append(threading.get_ident())
|
||||
return "released" if release_tool.wait(timeout=0.2) else "timed out"
|
||||
|
||||
async def release_soon() -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
release_tool.set()
|
||||
|
||||
tool_task = asyncio.create_task(wait_for_release.invoke(skip_parsing=True))
|
||||
release_task = asyncio.create_task(release_soon())
|
||||
|
||||
assert await asyncio.wait_for(tool_task, timeout=1) == "released"
|
||||
await release_task
|
||||
assert tool_thread_ids
|
||||
assert tool_thread_ids[0] != event_loop_thread_id
|
||||
|
||||
|
||||
async def test_invoke_sync_tool_can_stay_on_event_loop() -> None:
|
||||
event_loop_thread_id = threading.get_ident()
|
||||
tool_thread_ids: list[int] = []
|
||||
|
||||
@tool
|
||||
def needs_event_loop() -> str:
|
||||
tool_thread_ids.append(threading.get_ident())
|
||||
asyncio.get_running_loop()
|
||||
return "ok"
|
||||
|
||||
needs_event_loop._invoke_sync_on_event_loop = True
|
||||
|
||||
assert await needs_event_loop.invoke(skip_parsing=True) == "ok"
|
||||
assert tool_thread_ids == [event_loop_thread_id]
|
||||
|
||||
|
||||
async def test_invoke_skip_parsing_bypasses_configured_result_parser() -> None:
|
||||
"""The tool's own result_parser is bypassed when skip_parsing=True is requested."""
|
||||
parser_calls: list[Any] = []
|
||||
|
||||
Reference in New Issue
Block a user