Python: Add a BackgroundAgentsProvider for python (#6069)

* Add a BackgroundAgentsProvider for python

* Address PR comments and fix linting warnings

* Address PR comment
This commit is contained in:
westey
2026-05-27 10:12:01 +01:00
committed by GitHub
Unverified
parent 3242d8a4c4
commit ae989b92e7
3 changed files with 1069 additions and 0 deletions
@@ -79,6 +79,12 @@ from ._evaluation import (
tool_calls_present,
)
from ._feature_stage import ExperimentalFeature, ReleaseCandidateFeature
from ._harness._background_agents import (
DEFAULT_BACKGROUND_AGENTS_SOURCE_ID,
BackgroundAgentsProvider,
BackgroundTaskInfo,
BackgroundTaskStatus,
)
from ._harness._memory import (
DEFAULT_MEMORY_SOURCE_ID,
MemoryContextProvider,
@@ -297,6 +303,7 @@ __all__ = [
"AGENT_FRAMEWORK_USER_AGENT",
"APP_INFO",
"COMPACTION_STATE_KEY",
"DEFAULT_BACKGROUND_AGENTS_SOURCE_ID",
"DEFAULT_MAX_ITERATIONS",
"DEFAULT_MEMORY_SOURCE_ID",
"DEFAULT_MODE_SOURCE_ID",
@@ -332,6 +339,9 @@ __all__ = [
"AgentSession",
"AggregatingSkillsSource",
"Annotation",
"BackgroundAgentsProvider",
"BackgroundTaskInfo",
"BackgroundTaskStatus",
"BaseAgent",
"BaseChatClient",
"BaseEmbeddingClient",
@@ -0,0 +1,521 @@
# Copyright (c) Microsoft. All rights reserved.
"""BackgroundAgentsProvider: enables an agent to delegate work to background sub-agents asynchronously.
This module provides :class:`BackgroundAgentsProvider`, a context provider that allows
a parent agent to start background tasks on child agents, wait for their completion,
and retrieve results. Each background task runs in its own session concurrently.
"""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, MutableMapping, Sequence
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, ClassVar, cast
from .._agents import SupportsAgentRun
from .._feature_stage import ExperimentalFeature, experimental
from .._serialization import SerializationMixin
from .._sessions import AgentSession, ContextProvider, SessionContext
from .._tools import tool
from .._types import AgentResponse, Message
# Default ``source_id`` for BackgroundAgentsProvider; it is the key under which the
# provider's serializable task state is stored in ``session.state``.
DEFAULT_BACKGROUND_AGENTS_SOURCE_ID = "background_agents"
# Default instruction template injected before each run. The literal
# ``{background_agents}`` placeholder is substituted with the agent listing via
# ``str.replace`` (not ``str.format``), so other braces in the text are left alone.
DEFAULT_BACKGROUND_AGENTS_INSTRUCTIONS = """\
## Background Agents
You have access to background agents that can perform work on your behalf.
- Use the `background_agents_*` tools to start tasks on background agents and check their results.
- Creating a background task does not block, and background tasks run concurrently.
- Important: Always wait for outstanding tasks to finish before you finish processing.
- Important: After retrieving results from a completed task, clear it with \
background_agents_clear_completed_task to free memory, unless you plan to continue it with \
background_agents_continue_task.
{background_agents}"""
class BackgroundTaskStatus(str, Enum):
    """Lifecycle states of a background task.

    The members are ``str`` values so they can be stored directly in
    serialized session state and compared against plain strings.
    """

    # The task has been scheduled and no outcome has been recorded yet.
    RUNNING = "running"
    # The agent run returned a response without raising.
    COMPLETED = "completed"
    # The agent run raised an exception or was cancelled.
    FAILED = "failed"
    # The task is recorded as running but no in-flight handle exists for it.
    LOST = "lost"
@experimental(feature_id=ExperimentalFeature.HARNESS)
class BackgroundTaskInfo(SerializationMixin):
    """Metadata for a single background task.

    Instances are the serializable half of a background task: they are stored
    as plain dicts in session state, while live handles are tracked separately.

    Attributes:
        id: Numeric task identifier.
        agent_name: Name of the background agent the task was started on.
        description: Human-readable description of the task.
        status: Current lifecycle status.
        result_text: Text output of the run, or ``None`` if not available.
        error_text: Failure description, or ``None`` if not available.
    """

    DEFAULT_EXCLUDE: ClassVar[set[str]] = set()

    id: int
    agent_name: str
    description: str
    status: BackgroundTaskStatus
    result_text: str | None
    error_text: str | None

    __slots__ = ("agent_name", "description", "error_text", "id", "result_text", "status")

    def __init__(
        self,
        id: int,
        agent_name: str,
        description: str,
        status: BackgroundTaskStatus = BackgroundTaskStatus.RUNNING,
        result_text: str | None = None,
        error_text: str | None = None,
    ) -> None:
        """Initialize a background task info entry.

        Args:
            id: Numeric task identifier.
            agent_name: Name of the agent running the task.
            description: Human-readable description of the task.
            status: Initial status; defaults to ``RUNNING``.
            result_text: Text output, if already known.
            error_text: Failure description, if already known.
        """
        self.id = id
        self.agent_name = agent_name
        self.description = description
        self.status = status
        self.result_text = result_text
        self.error_text = error_text

    def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
        """Serialize for session state persistence.

        Keyword Args:
            exclude: Optional set of field names to leave out of the result.
            exclude_none: When true (the default), ``result_text`` and ``error_text``
                are omitted if they are ``None``.

        Returns:
            A dict with the task fields; ``status`` is stored as its string value.
        """
        data: dict[str, Any] = {
            "id": self.id,
            "agent_name": self.agent_name,
            "description": self.description,
            "status": self.status.value,
        }
        if not exclude_none or self.result_text is not None:
            data["result_text"] = self.result_text
        if not exclude_none or self.error_text is not None:
            data["error_text"] = self.error_text
        # Honour the caller's exclusions instead of silently discarding them.
        # With the default (``None``) the output is unchanged.
        if exclude:
            for key in exclude:
                data.pop(key, None)
        return data

    @classmethod
    def from_dict(cls, data: MutableMapping[str, Any], **kwargs: Any) -> BackgroundTaskInfo:
        """Deserialize from session state.

        Args:
            data: Mapping with at least ``id``, ``agent_name``, ``description`` and
                ``status``; ``result_text`` and ``error_text`` are optional.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            The reconstructed task info.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``status`` is not a valid ``BackgroundTaskStatus`` value.
        """
        return cls(
            id=data["id"],
            agent_name=data["agent_name"],
            description=data["description"],
            status=BackgroundTaskStatus(data["status"]),
            result_text=data.get("result_text"),
            error_text=data.get("error_text"),
        )
@dataclass
class _RuntimeState:
    """Non-serializable per-session runtime state for background tasks.

    Holds the live objects that accompany the serialized task records:
    asyncio task handles and the child-agent sessions they run in.
    """

    # asyncio tasks keyed by task ID. Entries are popped once a task is
    # finalized or cleared.
    in_flight_tasks: dict[int, asyncio.Task[AgentResponse[Any]]] = field(
        default_factory=lambda: {}  # pyright: ignore[reportUnknownLambdaType]
    )
    # Child-agent sessions keyed by task ID. Kept after completion so a task
    # can be continued on the same session; popped when the task is cleared.
    background_sessions: dict[int, AgentSession] = field(
        default_factory=lambda: {}  # pyright: ignore[reportUnknownLambdaType]
    )
# ---------------------------------------------------------------------------
# Module-level helper functions (following ModeProvider pattern)
# ---------------------------------------------------------------------------
async def _run_agent(awaitable: Awaitable[AgentResponse[Any]]) -> AgentResponse[Any]:
"""Wrap an Awaitable in a proper coroutine for use with asyncio.create_task."""
return await awaitable
def _validate_and_build_agent_dict(agents: Sequence[SupportsAgentRun]) -> dict[str, SupportsAgentRun]:
"""Validate agents and build a case-insensitive lookup dict.
Raises:
ValueError: If agents is empty, an agent has no name, or names are not unique.
"""
if not agents:
raise ValueError("At least one background agent must be provided.")
agent_dict: dict[str, SupportsAgentRun] = {}
for agent in agents:
name = agent.name
if not name or not name.strip():
raise ValueError("All background agents must have a non-empty name.")
key = name.lower()
if key in agent_dict:
raise ValueError(
f"Duplicate background agent name: '{name}'. Agent names must be unique (case-insensitive)."
)
agent_dict[key] = agent
return agent_dict
def _build_agent_list_text(agents: dict[str, SupportsAgentRun]) -> str:
"""Build text listing available background agents."""
lines = ["Available background agents:"]
for agent in agents.values():
line = f"- {agent.name}"
if agent.description:
line += f": {agent.description}"
lines.append(line)
return "\n".join(lines)
def _get_provider_state(session: AgentSession, *, source_id: str) -> dict[str, Any]:
"""Load or initialize serializable provider state from session."""
state = session.state.get(source_id)
if state is None:
initial: dict[str, Any] = {"next_task_id": 1, "tasks": []}
session.state[source_id] = initial
return initial
return cast(dict[str, Any], state)
def _save_provider_state(session: AgentSession, state: dict[str, Any], *, source_id: str) -> None:
"""Persist serializable state to session."""
session.state[source_id] = state
def _get_tasks(state: dict[str, Any]) -> list[BackgroundTaskInfo]:
"""Parse task list from state dict."""
return [BackgroundTaskInfo.from_dict(t) for t in state.get("tasks", [])]
def _save_tasks(state: dict[str, Any], tasks: list[BackgroundTaskInfo]) -> None:
"""Serialize task list back to state dict."""
state["tasks"] = [t.to_dict() for t in tasks]
def _finalize_task(
    task_info: BackgroundTaskInfo,
    completed_task: asyncio.Task[AgentResponse[Any]],
    runtime: _RuntimeState,
) -> None:
    """Record the outcome of a finished asyncio task on ``task_info``.

    Args:
        task_info: The task record to update in place.
        completed_task: The asyncio task for this record; it must already be
            done, otherwise ``exception()`` raises ``InvalidStateError``.
        runtime: Runtime state whose in-flight entry for this task is removed.
    """
    if completed_task.cancelled():
        # Check cancellation first: exception() would raise CancelledError.
        task_info.status = BackgroundTaskStatus.FAILED
        task_info.error_text = "Task was canceled."
    else:
        exception = completed_task.exception()
        if exception is not None:
            task_info.status = BackgroundTaskStatus.FAILED
            # Some exceptions (e.g. TimeoutError()) stringify to "", which would
            # hide the cause entirely; fall back to the exception type name.
            task_info.error_text = str(exception) or type(exception).__name__
        else:
            task_info.status = BackgroundTaskStatus.COMPLETED
            task_info.result_text = completed_task.result().text
    # The handle is no longer needed once the outcome is recorded.
    runtime.in_flight_tasks.pop(task_info.id, None)
def _refresh_task_state(
    session: AgentSession, state: dict[str, Any], runtime: _RuntimeState, *, source_id: str
) -> list[BackgroundTaskInfo]:
    """Bring RUNNING task records up to date and return the full task list.

    A RUNNING record with no in-flight handle is marked LOST; one whose handle
    is done is finalized. State is written back only if something changed.
    """
    task_infos = _get_tasks(state)
    dirty = False

    for info in task_infos:
        if info.status == BackgroundTaskStatus.RUNNING:
            handle = runtime.in_flight_tasks.get(info.id)
            if handle is None:
                # No live handle exists for this record, so its outcome is unrecoverable.
                info.status = BackgroundTaskStatus.LOST
                dirty = True
            elif handle.done():
                _finalize_task(info, handle, runtime)
                dirty = True

    if dirty:
        _save_tasks(state, task_infos)
        _save_provider_state(session, state, source_id=source_id)

    return task_infos
# ---------------------------------------------------------------------------
# Provider class
# ---------------------------------------------------------------------------
@experimental(feature_id=ExperimentalFeature.HARNESS)
class BackgroundAgentsProvider(ContextProvider):
    """Context provider that enables an agent to delegate work to background sub-agents.

    The ``BackgroundAgentsProvider`` allows a parent agent to start background tasks on child agents,
    wait for their completion, and retrieve results. Each background task runs in its own session and
    executes concurrently.

    This provider exposes the following tools to the agent:

    - ``background_agents_start_task`` — Start a background task on a named agent with text input.
    - ``background_agents_wait_for_first_completion`` — Block until the first of the specified tasks completes.
    - ``background_agents_get_task_results`` — Retrieve the text output of a completed background task.
    - ``background_agents_get_all_tasks`` — List all background tasks with their IDs, statuses, and descriptions.
    - ``background_agents_continue_task`` — Send follow-up input to a completed task's session to resume work.
    - ``background_agents_clear_completed_task`` — Remove a completed task and release its session.
    """

    def __init__(
        self,
        agents: Sequence[SupportsAgentRun],
        *,
        source_id: str = DEFAULT_BACKGROUND_AGENTS_SOURCE_ID,
        instructions: str | None = None,
    ) -> None:
        """Initialize the background agents provider.

        Args:
            agents: Collection of background agents available for delegation.
                Each agent must have a non-empty, unique name (case-insensitive).

        Keyword Args:
            source_id: Unique source ID for serializable task state in session.
            instructions: Optional instruction override. May include ``{background_agents}``
                placeholder which will be replaced with the agent listing.

        Raises:
            ValueError: If agents is empty, an agent has no name, or names are not unique.
        """
        super().__init__(source_id)
        self._agents = _validate_and_build_agent_dict(agents)

        # Build instructions with agent listing.
        base_instructions = instructions if instructions is not None else DEFAULT_BACKGROUND_AGENTS_INSTRUCTIONS
        agent_list_text = _build_agent_list_text(self._agents)
        self._instructions = base_instructions.replace("{background_agents}", agent_list_text)

        # Per-session runtime state (non-serializable), keyed by session_id.
        # Note: Runtime state (in-flight asyncio.Task objects, child AgentSession handles)
        # is inherently non-serializable and cannot survive process restarts. If the provider
        # instance is lost, _refresh_task_state() marks orphaned tasks as LOST.
        self._runtime: dict[str, _RuntimeState] = {}

    def _get_runtime(self, session: AgentSession) -> _RuntimeState:
        """Get or create runtime state for a session."""
        session_id = session.session_id
        if session_id not in self._runtime:
            self._runtime[session_id] = _RuntimeState()
        return self._runtime[session_id]

    async def before_run(
        self,
        *,
        agent: Any,
        session: AgentSession,
        context: SessionContext,
        state: dict[str, Any],
    ) -> None:
        """Inject background agent tools and instructions before the model runs.

        The tools are closures bound to this session's provider state and runtime
        state. If any tasks exist, a summary of their refreshed statuses is added
        as a context message.

        Keyword Args:
            agent: Unused.
            session: The parent agent's session; holds the serializable task state.
            context: Session context that receives instructions, tools and messages.
            state: Unused; provider state is read from ``session.state`` instead.
        """
        del agent, state
        provider_state = _get_provider_state(session, source_id=self.source_id)
        runtime = self._get_runtime(session)
        source_id = self.source_id

        # NOTE: the tool docstrings below are kept verbatim; they are the
        # model-facing descriptions of the tools.

        @tool(name="background_agents_start_task", approval_mode="never_require")
        def background_agents_start_task(agent_name: str, input: str, description: str) -> str:
            """Start a background task on a named agent. Returns a confirmation with the task ID."""
            key = agent_name.lower()
            if key not in self._agents:
                available = ", ".join(a.name or "" for a in self._agents.values())
                return f"Error: No background agent found with name '{agent_name}'. Available agents: {available}"
            bg_agent = self._agents[key]

            # Do everything that can fail (parsing stored tasks, creating the session,
            # scheduling the run) BEFORE mutating persisted state. Otherwise a synchronous
            # failure would leave a phantom RUNNING record (later reported as LOST) and
            # consume a task ID for a task that never started.
            tasks = _get_tasks(provider_state)
            task_id = provider_state.get("next_task_id", 1)

            # Create a dedicated session for this background task.
            sub_session = bg_agent.create_session()

            # Start the task concurrently.
            async_task = asyncio.create_task(_run_agent(bg_agent.run(input, session=sub_session)))

            # The run is scheduled; record it.
            provider_state["next_task_id"] = task_id + 1
            tasks.append(
                BackgroundTaskInfo(
                    id=task_id,
                    agent_name=agent_name,
                    description=description,
                )
            )
            _save_tasks(provider_state, tasks)
            runtime.in_flight_tasks[task_id] = async_task
            runtime.background_sessions[task_id] = sub_session

            _save_provider_state(session, provider_state, source_id=source_id)
            return f"Background task {task_id} started on agent '{agent_name}'."

        @tool(name="background_agents_wait_for_first_completion", approval_mode="never_require")
        async def background_agents_wait_for_first_completion(task_ids: list[int]) -> str:
            """Block until the first of the specified background tasks completes. Returns the completed task's ID."""
            if not task_ids:
                return "Error: No task IDs provided."

            # Collect in-flight tasks matching the requested IDs.
            waitable: list[tuple[int, asyncio.Task[AgentResponse[Any]]]] = []
            for tid in task_ids:
                in_flight = runtime.in_flight_tasks.get(tid)
                if in_flight is not None:
                    waitable.append((tid, in_flight))

            if not waitable:
                # Refresh state to catch any that completed.
                tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
                already_complete = next(
                    (t for t in tasks if t.id in task_ids and t.status != BackgroundTaskStatus.RUNNING), None
                )
                if already_complete is not None:
                    return (
                        f"Task {already_complete.id} is not running; current status: {already_complete.status.value}."
                    )
                return "Error: None of the specified task IDs correspond to running tasks."

            # Wait for the first one to complete.
            done, _ = await asyncio.wait(
                [t for _, t in waitable],
                return_when=asyncio.FIRST_COMPLETED,
            )

            # Find which ID completed. Several may be done; the first in request order
            # is reported and the rest are picked up by a later refresh.
            completed_id: int | None = None
            for tid, task in waitable:
                if task in done:
                    completed_id = tid
                    break

            # Finalize the completed task.
            tasks = _get_tasks(provider_state)
            task_info = next((t for t in tasks if t.id == completed_id), None)
            if task_info is not None and completed_id is not None:
                completed_task = runtime.in_flight_tasks.get(completed_id)
                if completed_task is not None:
                    _finalize_task(task_info, completed_task, runtime)
                    _save_tasks(provider_state, tasks)
                    _save_provider_state(session, provider_state, source_id=source_id)

            status_str = task_info.status.value if task_info else "Unknown"
            return f"Task {completed_id} finished with status: {status_str}."

        @tool(name="background_agents_get_task_results", approval_mode="never_require")
        def background_agents_get_task_results(task_id: int) -> str:
            """Get the text output of a background task by its ID."""
            tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
            task_info = next((t for t in tasks if t.id == task_id), None)
            if task_info is None:
                return f"Error: No task found with ID {task_id}."

            if task_info.status == BackgroundTaskStatus.COMPLETED:
                return task_info.result_text or "(no output)"
            if task_info.status == BackgroundTaskStatus.FAILED:
                return f"Task failed: {task_info.error_text or 'Unknown error'}"
            if task_info.status == BackgroundTaskStatus.LOST:
                return "Task state was lost (reference unavailable)."
            if task_info.status == BackgroundTaskStatus.RUNNING:
                return f"Task {task_id} is still running."
            return f"Task {task_id} has status: {task_info.status.value}."

        @tool(name="background_agents_get_all_tasks", approval_mode="never_require")
        def background_agents_get_all_tasks() -> str:
            """List all background tasks with their IDs, statuses, agent names, and descriptions."""
            tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
            if not tasks:
                return "No tasks."
            lines = ["Tasks:"]
            for t in tasks:
                lines.append(f"- Task {t.id} [{t.status.value}] ({t.agent_name}): {t.description}")
            return "\n".join(lines)

        @tool(name="background_agents_continue_task", approval_mode="never_require")
        def background_agents_continue_task(task_id: int, text: str) -> str:
            """Send follow-up input to a completed or failed task to resume its work."""
            tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
            task_info = next((t for t in tasks if t.id == task_id), None)
            if task_info is None:
                return f"Error: No task found with ID {task_id}."

            if task_info.status == BackgroundTaskStatus.LOST:
                return (
                    f"Error: Task {task_id} cannot be continued because its session was lost. Start a new task instead."
                )

            if task_info.status == BackgroundTaskStatus.RUNNING:
                return f"Error: Task {task_id} is still running. Wait for it to complete before continuing."

            key = task_info.agent_name.lower()
            if key not in self._agents:
                return f"Error: Agent '{task_info.agent_name}' is no longer available."

            sub_session = runtime.background_sessions.get(task_id)
            if sub_session is None:
                return f"Error: Session for task {task_id} is no longer available."

            bg_agent = self._agents[key]

            # Schedule the new run on the existing session BEFORE resetting the record.
            # If scheduling fails synchronously, the previous result/error stays intact
            # instead of being wiped and the task later reported as LOST.
            async_task = asyncio.create_task(_run_agent(bg_agent.run(text, session=sub_session)))

            # Reset task state now that the new run is in flight.
            task_info.status = BackgroundTaskStatus.RUNNING
            task_info.result_text = None
            task_info.error_text = None
            _save_tasks(provider_state, tasks)
            runtime.in_flight_tasks[task_id] = async_task

            _save_provider_state(session, provider_state, source_id=source_id)
            return f"Task {task_id} continued with new input."

        @tool(name="background_agents_clear_completed_task", approval_mode="never_require")
        def background_agents_clear_completed_task(task_id: int) -> str:
            """Remove a completed or failed task and release its session to free memory."""
            tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
            task_info = next((t for t in tasks if t.id == task_id), None)
            if task_info is None:
                return f"Error: No task found with ID {task_id}."

            if task_info.status == BackgroundTaskStatus.RUNNING:
                return f"Error: Task {task_id} is still running. Wait for it to complete before clearing."

            # Remove the task from state.
            tasks = [t for t in tasks if t.id != task_id]
            _save_tasks(provider_state, tasks)

            # Clean up runtime references.
            runtime.in_flight_tasks.pop(task_id, None)
            runtime.background_sessions.pop(task_id, None)

            _save_provider_state(session, provider_state, source_id=source_id)
            return f"Task {task_id} cleared."

        # Inject instructions and tools.
        context.extend_instructions(self.source_id, [self._instructions])
        context.extend_tools(
            self.source_id,
            [
                background_agents_start_task,
                background_agents_wait_for_first_completion,
                background_agents_get_task_results,
                background_agents_get_all_tasks,
                background_agents_continue_task,
                background_agents_clear_completed_task,
            ],
        )

        # Include current task status as context message if there are tasks.
        # Refresh first to get accurate statuses for any tasks that completed between turns.
        tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
        if tasks:
            status_lines = ["### Current background tasks"]
            for t in tasks:
                status_lines.append(f"- Task {t.id} [{t.status.value}] ({t.agent_name}): {t.description}")
            context.extend_messages(
                self.source_id,
                [Message(role="user", contents=["\n".join(status_lines)])],
            )
@@ -0,0 +1,538 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import asyncio
from typing import Any
import pytest
from agent_framework import (
AgentResponse,
AgentSession,
BackgroundAgentsProvider,
BackgroundTaskInfo,
BackgroundTaskStatus,
Message,
)
from agent_framework._sessions import SessionContext
# Suppress "coroutine was never awaited" warnings from task cancellation in tests.
# This occurs when cancelling tasks that wrap coroutines through _run_agent().
pytestmark = pytest.mark.filterwarnings("ignore::RuntimeWarning:asyncio")
# --- Test Helpers ---
class _FakeAgent:
"""Minimal agent stub for testing background agent delegation."""
def __init__(
self,
name: str,
description: str | None = None,
*,
response_text: str = "done",
delay: float = 0.0,
should_fail: bool = False,
):
self.id = f"agent-{name}"
self.name = name
self.description = description
self._response_text = response_text
self._delay = delay
self._should_fail = should_fail
def create_session(self, *, session_id: str | None = None) -> AgentSession:
return AgentSession(session_id=session_id)
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
return AgentSession(service_session_id=service_session_id, session_id=session_id)
async def run(
self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any
) -> AgentResponse[Any]:
if self._delay > 0:
await asyncio.sleep(self._delay)
if self._should_fail:
raise RuntimeError("Agent execution failed")
return AgentResponse(messages=[Message(role="assistant", contents=[self._response_text])])
def _make_provider(*agents: _FakeAgent) -> BackgroundAgentsProvider:
    """Build a provider from the positional fake agents (passed on as a tuple)."""
    provider = BackgroundAgentsProvider(agents)
    return provider
def _make_session() -> AgentSession:
    """Return a fresh ``AgentSession`` built with default arguments."""
    session = AgentSession()
    return session
async def _get_tools(provider: BackgroundAgentsProvider, session: AgentSession) -> dict[str, Any]:
    """Invoke ``before_run`` on a fresh context and index the injected tools by name.

    Tools without a ``name`` attribute are keyed by their ``str()`` form.
    """
    ctx = SessionContext(input_messages=[])
    await provider.before_run(agent=None, session=session, context=ctx, state={})
    return {(entry.name if hasattr(entry, "name") else str(entry)): entry for entry in ctx.tools}
async def _invoke_tool(tool_obj: Any, **kwargs: Any) -> str:
"""Invoke a FunctionTool and return the raw result string."""
return await tool_obj.invoke(arguments=kwargs, skip_parsing=True)
# --- Constructor Tests ---
def test_constructor_requires_at_least_one_agent() -> None:
    """Constructing the provider with no agents raises ValueError."""
    no_agents: list[_FakeAgent] = []
    with pytest.raises(ValueError, match="At least one background agent"):
        BackgroundAgentsProvider(no_agents)
def test_constructor_requires_agent_names() -> None:
    """An agent whose name is the empty string is rejected."""
    nameless = [_FakeAgent("")]
    with pytest.raises(ValueError, match="non-empty name"):
        BackgroundAgentsProvider(nameless)
def test_constructor_rejects_duplicate_names() -> None:
    """Names that differ only by case count as duplicates."""
    clashing = [_FakeAgent("Research"), _FakeAgent("research")]
    with pytest.raises(ValueError, match="Duplicate background agent name"):
        BackgroundAgentsProvider(clashing)
def test_constructor_valid_agents() -> None:
    """Distinctly named agents are accepted and the default source ID is used."""
    alpha, beta = _FakeAgent("Alpha"), _FakeAgent("Beta")
    provider = BackgroundAgentsProvider([alpha, beta])
    assert provider.source_id == "background_agents"
def test_constructor_custom_source_id() -> None:
    """A caller-supplied source_id overrides the default."""
    only_agent = _FakeAgent("Agent1")
    provider = BackgroundAgentsProvider([only_agent], source_id="custom_bg")
    assert provider.source_id == "custom_bg"
# --- Tool Injection Tests ---
async def test_before_run_injects_six_tools() -> None:
    """before_run exposes exactly the six background-agent tools."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker"))
    injected = await _get_tools(provider, session)

    assert len(injected) == 6
    assert set(injected.keys()) == {
        "background_agents_start_task",
        "background_agents_wait_for_first_completion",
        "background_agents_get_task_results",
        "background_agents_get_all_tasks",
        "background_agents_continue_task",
        "background_agents_clear_completed_task",
    }
async def test_before_run_injects_instructions() -> None:
    """The injected instructions list each agent's name and description."""
    provider = _make_provider(_FakeAgent("ResearchBot", "Does research"))
    session = _make_session()
    ctx = SessionContext(input_messages=[])

    await provider.before_run(agent=None, session=session, context=ctx, state={})

    combined = " ".join(ctx.instructions)
    for expected_fragment in ("ResearchBot", "Does research"):
        assert expected_fragment in combined
# --- Start Task Tests ---
async def test_start_task_success() -> None:
    """Starting a task on a known agent returns a confirmation with ID 1."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker", response_text="result"))
    start_task = (await _get_tools(provider, session))["background_agents_start_task"]

    confirmation = await _invoke_tool(
        start_task,
        agent_name="Worker",
        input="do something",
        description="test task",
    )

    assert "task 1 started" in confirmation.lower()
    assert "Worker" in confirmation
async def test_start_task_unknown_agent() -> None:
    """An unregistered agent name yields an error that echoes the name."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker"))
    start_task = (await _get_tools(provider, session))["background_agents_start_task"]

    reply = await _invoke_tool(
        start_task,
        agent_name="NonExistent",
        input="do something",
        description="test",
    )

    assert "Error" in reply
    assert "NonExistent" in reply
async def test_start_task_increments_ids() -> None:
    """Consecutive starts receive task IDs 1 and 2."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker"))
    start_task = (await _get_tools(provider, session))["background_agents_start_task"]

    first_reply = await _invoke_tool(
        start_task,
        agent_name="Worker",
        input="task 1",
        description="first",
    )
    second_reply = await _invoke_tool(
        start_task,
        agent_name="Worker",
        input="task 2",
        description="second",
    )

    assert "task 1 started" in first_reply.lower()
    assert "task 2 started" in second_reply.lower()
# --- Get All Tasks Tests ---
async def test_get_all_tasks_empty() -> None:
    """With nothing started, the listing tool reports that there are no tasks."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker"))
    list_tasks = (await _get_tools(provider, session))["background_agents_get_all_tasks"]

    listing = await _invoke_tool(list_tasks)

    assert "No tasks" in listing
async def test_get_all_tasks_shows_tasks() -> None:
    """A started task appears in the listing with its description and agent name."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker"))
    toolbox = await _get_tools(provider, session)

    await _invoke_tool(
        toolbox["background_agents_start_task"],
        agent_name="Worker",
        input="hello",
        description="my task",
    )
    listing = await _invoke_tool(toolbox["background_agents_get_all_tasks"])

    assert "my task" in listing
    assert "Worker" in listing
# --- Wait for Completion Tests ---
async def test_wait_for_first_completion() -> None:
    """Waiting on a short task returns once it has finished as completed."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Fast", response_text="fast result", delay=0.01))
    toolbox = await _get_tools(provider, session)

    await _invoke_tool(
        toolbox["background_agents_start_task"],
        agent_name="Fast",
        input="go",
        description="fast task",
    )
    outcome = await _invoke_tool(
        toolbox["background_agents_wait_for_first_completion"],
        task_ids=[1],
    )

    lowered = outcome.lower()
    assert "finished" in lowered
    assert "completed" in lowered
async def test_wait_empty_task_ids() -> None:
    """Passing an empty ID list to the wait tool produces an error."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker"))
    wait_tool = (await _get_tools(provider, session))["background_agents_wait_for_first_completion"]

    outcome = await _invoke_tool(wait_tool, task_ids=[])

    assert "Error" in outcome
async def test_wait_no_running_tasks() -> None:
    """Waiting on IDs that match no running task reports an error or 'not running'."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker"))
    wait_tool = (await _get_tools(provider, session))["background_agents_wait_for_first_completion"]

    outcome = await _invoke_tool(wait_tool, task_ids=[999])

    assert "Error" in outcome or "not running" in outcome.lower()
# --- Get Task Results Tests ---
async def test_get_task_results_completed() -> None:
    """After completion, the results tool returns the agent's response text verbatim."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker", response_text="the answer", delay=0.01))
    toolbox = await _get_tools(provider, session)

    await _invoke_tool(
        toolbox["background_agents_start_task"],
        agent_name="Worker",
        input="query",
        description="test",
    )
    # Block until the run has finished so the result is available.
    await _invoke_tool(
        toolbox["background_agents_wait_for_first_completion"],
        task_ids=[1],
    )
    output = await _invoke_tool(
        toolbox["background_agents_get_task_results"],
        task_id=1,
    )

    assert output == "the answer"
async def test_get_task_results_running() -> None:
    """Asking for results of an unfinished task reports that it is still running."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Slow", delay=10.0))
    toolbox = await _get_tools(provider, session)

    await _invoke_tool(
        toolbox["background_agents_start_task"],
        agent_name="Slow",
        input="query",
        description="slow task",
    )
    try:
        output = await _invoke_tool(
            toolbox["background_agents_get_task_results"],
            task_id=1,
        )
        assert "still running" in output.lower()
    finally:
        # Cancel the long sleep so the test does not wait for it.
        runtime = provider._get_runtime(session)
        for pending in list(runtime.in_flight_tasks.values()):
            pending.cancel()
        await asyncio.gather(*runtime.in_flight_tasks.values(), return_exceptions=True)
async def test_get_task_results_failed() -> None:
    """A task whose agent raises is reported as failed by the results tool."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Broken", should_fail=True, delay=0.01))
    toolbox = await _get_tools(provider, session)

    await _invoke_tool(
        toolbox["background_agents_start_task"],
        agent_name="Broken",
        input="query",
        description="will fail",
    )
    await _invoke_tool(
        toolbox["background_agents_wait_for_first_completion"],
        task_ids=[1],
    )
    output = await _invoke_tool(
        toolbox["background_agents_get_task_results"],
        task_id=1,
    )

    assert "failed" in output.lower()
async def test_get_task_results_not_found() -> None:
    """Requesting results for an unknown task ID yields an error."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker"))
    results_tool = (await _get_tools(provider, session))["background_agents_get_task_results"]

    output = await _invoke_tool(results_tool, task_id=999)

    assert "Error" in output
# --- Continue Task Tests ---
async def test_continue_task_after_completion() -> None:
    """A completed task can be continued, and the continued run completes with a result."""
    provider = _make_provider(_FakeAgent("Worker", response_text="first result", delay=0.01))
    session = _make_session()
    tools = await _get_tools(provider, session)
    await _invoke_tool(
        tools["background_agents_start_task"],
        agent_name="Worker",
        input="first input",
        description="continuable",
    )
    await _invoke_tool(
        tools["background_agents_wait_for_first_completion"],
        task_ids=[1],
    )
    result = await _invoke_tool(
        tools["background_agents_continue_task"],
        task_id=1,
        text="follow up",
    )
    assert "continued" in result.lower()

    # Drain the continued run: otherwise the test would end with a pending
    # asyncio task, and the outcome of the continuation would go unverified.
    finished = await _invoke_tool(
        tools["background_agents_wait_for_first_completion"],
        task_ids=[1],
    )
    assert "completed" in finished.lower()
    final_output = await _invoke_tool(
        tools["background_agents_get_task_results"],
        task_id=1,
    )
    assert final_output == "first result"
async def test_continue_task_still_running() -> None:
    """Continuing a task that has not finished is refused with a 'still running' error."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Slow", delay=10.0))
    toolbox = await _get_tools(provider, session)

    await _invoke_tool(
        toolbox["background_agents_start_task"],
        agent_name="Slow",
        input="input",
        description="running",
    )
    try:
        reply = await _invoke_tool(
            toolbox["background_agents_continue_task"],
            task_id=1,
            text="follow up",
        )
        assert "still running" in reply.lower()
    finally:
        # Cancel the long sleep so the test does not wait for it.
        runtime = provider._get_runtime(session)
        for pending in list(runtime.in_flight_tasks.values()):
            pending.cancel()
        await asyncio.gather(*runtime.in_flight_tasks.values(), return_exceptions=True)
async def test_continue_task_not_found() -> None:
    """Continuing an unknown task ID yields an error."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker"))
    continue_tool = (await _get_tools(provider, session))["background_agents_continue_task"]

    reply = await _invoke_tool(continue_tool, task_id=999, text="hello")

    assert "Error" in reply
# --- Clear Task Tests ---
async def test_clear_completed_task() -> None:
    """Clearing a finished task removes it from the task listing."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker", response_text="done", delay=0.01))
    toolbox = await _get_tools(provider, session)

    await _invoke_tool(
        toolbox["background_agents_start_task"],
        agent_name="Worker",
        input="task",
        description="clearable",
    )
    await _invoke_tool(
        toolbox["background_agents_wait_for_first_completion"],
        task_ids=[1],
    )
    cleared = await _invoke_tool(
        toolbox["background_agents_clear_completed_task"],
        task_id=1,
    )
    assert "cleared" in cleared.lower()

    # The listing should now be empty.
    remaining = await _invoke_tool(toolbox["background_agents_get_all_tasks"])
    assert "No tasks" in remaining
async def test_clear_running_task_error() -> None:
    """Clearing a task that has not finished is refused with a 'still running' error."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Slow", delay=10.0))
    toolbox = await _get_tools(provider, session)

    await _invoke_tool(
        toolbox["background_agents_start_task"],
        agent_name="Slow",
        input="task",
        description="still going",
    )
    try:
        reply = await _invoke_tool(
            toolbox["background_agents_clear_completed_task"],
            task_id=1,
        )
        assert "still running" in reply.lower()
    finally:
        # Cancel the long sleep so the test does not wait for it.
        runtime = provider._get_runtime(session)
        for pending in list(runtime.in_flight_tasks.values()):
            pending.cancel()
        await asyncio.gather(*runtime.in_flight_tasks.values(), return_exceptions=True)
async def test_clear_not_found() -> None:
    """Clearing an unknown task ID yields an error."""
    session = _make_session()
    provider = _make_provider(_FakeAgent("Worker"))
    clear_tool = (await _get_tools(provider, session))["background_agents_clear_completed_task"]

    reply = await _invoke_tool(clear_tool, task_id=999)

    assert "Error" in reply
# --- BackgroundTaskInfo Tests ---
def test_task_info_serialization() -> None:
    """BackgroundTaskInfo should round-trip every field through to_dict/from_dict."""
    info = BackgroundTaskInfo(
        id=1,
        agent_name="Worker",
        description="test task",
        status=BackgroundTaskStatus.COMPLETED,
        result_text="hello",
    )
    data = info.to_dict()
    # With the default exclude_none=True, the unset error_text is omitted and
    # the status is stored as its plain string value.
    assert "error_text" not in data
    assert data["status"] == "completed"

    restored = BackgroundTaskInfo.from_dict(data)
    assert restored.id == 1
    assert restored.agent_name == "Worker"
    assert restored.description == "test task"
    assert restored.status == BackgroundTaskStatus.COMPLETED
    assert restored.result_text == "hello"
    assert restored.error_text is None
def test_task_status_enum_values() -> None:
    """Each status member compares equal to its expected string value."""
    expected_values = (
        (BackgroundTaskStatus.RUNNING, "running"),
        (BackgroundTaskStatus.COMPLETED, "completed"),
        (BackgroundTaskStatus.FAILED, "failed"),
        (BackgroundTaskStatus.LOST, "lost"),
    )
    for member, value in expected_values:
        assert member == value