mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add a BackgroundAgentsProvider for python (#6069)
* Add a BackgroundAgentsProvider for python * Address PR comments and fix linting warnings * Address PR comment
This commit is contained in:
committed by
GitHub
Unverified
parent
3242d8a4c4
commit
ae989b92e7
@@ -79,6 +79,12 @@ from ._evaluation import (
|
||||
tool_calls_present,
|
||||
)
|
||||
from ._feature_stage import ExperimentalFeature, ReleaseCandidateFeature
|
||||
from ._harness._background_agents import (
|
||||
DEFAULT_BACKGROUND_AGENTS_SOURCE_ID,
|
||||
BackgroundAgentsProvider,
|
||||
BackgroundTaskInfo,
|
||||
BackgroundTaskStatus,
|
||||
)
|
||||
from ._harness._memory import (
|
||||
DEFAULT_MEMORY_SOURCE_ID,
|
||||
MemoryContextProvider,
|
||||
@@ -297,6 +303,7 @@ __all__ = [
|
||||
"AGENT_FRAMEWORK_USER_AGENT",
|
||||
"APP_INFO",
|
||||
"COMPACTION_STATE_KEY",
|
||||
"DEFAULT_BACKGROUND_AGENTS_SOURCE_ID",
|
||||
"DEFAULT_MAX_ITERATIONS",
|
||||
"DEFAULT_MEMORY_SOURCE_ID",
|
||||
"DEFAULT_MODE_SOURCE_ID",
|
||||
@@ -332,6 +339,9 @@ __all__ = [
|
||||
"AgentSession",
|
||||
"AggregatingSkillsSource",
|
||||
"Annotation",
|
||||
"BackgroundAgentsProvider",
|
||||
"BackgroundTaskInfo",
|
||||
"BackgroundTaskStatus",
|
||||
"BaseAgent",
|
||||
"BaseChatClient",
|
||||
"BaseEmbeddingClient",
|
||||
|
||||
@@ -0,0 +1,521 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""BackgroundAgentsProvider: enables an agent to delegate work to background sub-agents asynchronously.
|
||||
|
||||
This module provides :class:`BackgroundAgentsProvider`, a context provider that allows
|
||||
a parent agent to start background tasks on child agents, wait for their completion,
|
||||
and retrieve results. Each background task runs in its own session concurrently.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, MutableMapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, cast
|
||||
|
||||
from .._agents import SupportsAgentRun
|
||||
from .._feature_stage import ExperimentalFeature, experimental
|
||||
from .._serialization import SerializationMixin
|
||||
from .._sessions import AgentSession, ContextProvider, SessionContext
|
||||
from .._tools import tool
|
||||
from .._types import AgentResponse, Message
|
||||
|
||||
DEFAULT_BACKGROUND_AGENTS_SOURCE_ID = "background_agents"
|
||||
|
||||
DEFAULT_BACKGROUND_AGENTS_INSTRUCTIONS = """\
|
||||
## Background Agents
|
||||
|
||||
You have access to background agents that can perform work on your behalf.
|
||||
|
||||
- Use the `background_agents_*` tools to start tasks on background agents and check their results.
|
||||
- Creating a background task does not block, and background tasks run concurrently.
|
||||
- Important: Always wait for outstanding tasks to finish before you finish processing.
|
||||
- Important: After retrieving results from a completed task, clear it with \
|
||||
background_agents_clear_completed_task to free memory, unless you plan to continue it with \
|
||||
background_agents_continue_task.
|
||||
|
||||
{background_agents}"""
|
||||
|
||||
|
||||
class BackgroundTaskStatus(str, Enum):
|
||||
"""Status of a background task."""
|
||||
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
LOST = "lost"
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.HARNESS)
|
||||
class BackgroundTaskInfo(SerializationMixin):
|
||||
"""Metadata for a single background task."""
|
||||
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = set()
|
||||
|
||||
id: int
|
||||
agent_name: str
|
||||
description: str
|
||||
status: BackgroundTaskStatus
|
||||
result_text: str | None
|
||||
error_text: str | None
|
||||
__slots__ = ("agent_name", "description", "error_text", "id", "result_text", "status")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
agent_name: str,
|
||||
description: str,
|
||||
status: BackgroundTaskStatus = BackgroundTaskStatus.RUNNING,
|
||||
result_text: str | None = None,
|
||||
error_text: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize a background task info entry."""
|
||||
self.id = id
|
||||
self.agent_name = agent_name
|
||||
self.description = description
|
||||
self.status = status
|
||||
self.result_text = result_text
|
||||
self.error_text = error_text
|
||||
|
||||
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
|
||||
"""Serialize for session state persistence."""
|
||||
del exclude
|
||||
data: dict[str, Any] = {
|
||||
"id": self.id,
|
||||
"agent_name": self.agent_name,
|
||||
"description": self.description,
|
||||
"status": self.status.value,
|
||||
}
|
||||
if not exclude_none or self.result_text is not None:
|
||||
data["result_text"] = self.result_text
|
||||
if not exclude_none or self.error_text is not None:
|
||||
data["error_text"] = self.error_text
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: MutableMapping[str, Any], **kwargs: Any) -> BackgroundTaskInfo:
|
||||
"""Deserialize from session state."""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
agent_name=data["agent_name"],
|
||||
description=data["description"],
|
||||
status=BackgroundTaskStatus(data["status"]),
|
||||
result_text=data.get("result_text"),
|
||||
error_text=data.get("error_text"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RuntimeState:
|
||||
"""Non-serializable per-session runtime state for background tasks."""
|
||||
|
||||
in_flight_tasks: dict[int, asyncio.Task[AgentResponse[Any]]] = field(
|
||||
default_factory=lambda: {} # pyright: ignore[reportUnknownLambdaType]
|
||||
)
|
||||
background_sessions: dict[int, AgentSession] = field(
|
||||
default_factory=lambda: {} # pyright: ignore[reportUnknownLambdaType]
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level helper functions (following ModeProvider pattern)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _run_agent(awaitable: Awaitable[AgentResponse[Any]]) -> AgentResponse[Any]:
|
||||
"""Wrap an Awaitable in a proper coroutine for use with asyncio.create_task."""
|
||||
return await awaitable
|
||||
|
||||
|
||||
def _validate_and_build_agent_dict(agents: Sequence[SupportsAgentRun]) -> dict[str, SupportsAgentRun]:
|
||||
"""Validate agents and build a case-insensitive lookup dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If agents is empty, an agent has no name, or names are not unique.
|
||||
"""
|
||||
if not agents:
|
||||
raise ValueError("At least one background agent must be provided.")
|
||||
|
||||
agent_dict: dict[str, SupportsAgentRun] = {}
|
||||
for agent in agents:
|
||||
name = agent.name
|
||||
if not name or not name.strip():
|
||||
raise ValueError("All background agents must have a non-empty name.")
|
||||
key = name.lower()
|
||||
if key in agent_dict:
|
||||
raise ValueError(
|
||||
f"Duplicate background agent name: '{name}'. Agent names must be unique (case-insensitive)."
|
||||
)
|
||||
agent_dict[key] = agent
|
||||
return agent_dict
|
||||
|
||||
|
||||
def _build_agent_list_text(agents: dict[str, SupportsAgentRun]) -> str:
|
||||
"""Build text listing available background agents."""
|
||||
lines = ["Available background agents:"]
|
||||
for agent in agents.values():
|
||||
line = f"- {agent.name}"
|
||||
if agent.description:
|
||||
line += f": {agent.description}"
|
||||
lines.append(line)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _get_provider_state(session: AgentSession, *, source_id: str) -> dict[str, Any]:
|
||||
"""Load or initialize serializable provider state from session."""
|
||||
state = session.state.get(source_id)
|
||||
if state is None:
|
||||
initial: dict[str, Any] = {"next_task_id": 1, "tasks": []}
|
||||
session.state[source_id] = initial
|
||||
return initial
|
||||
return cast(dict[str, Any], state)
|
||||
|
||||
|
||||
def _save_provider_state(session: AgentSession, state: dict[str, Any], *, source_id: str) -> None:
|
||||
"""Persist serializable state to session."""
|
||||
session.state[source_id] = state
|
||||
|
||||
|
||||
def _get_tasks(state: dict[str, Any]) -> list[BackgroundTaskInfo]:
|
||||
"""Parse task list from state dict."""
|
||||
return [BackgroundTaskInfo.from_dict(t) for t in state.get("tasks", [])]
|
||||
|
||||
|
||||
def _save_tasks(state: dict[str, Any], tasks: list[BackgroundTaskInfo]) -> None:
|
||||
"""Serialize task list back to state dict."""
|
||||
state["tasks"] = [t.to_dict() for t in tasks]
|
||||
|
||||
|
||||
def _finalize_task(
|
||||
task_info: BackgroundTaskInfo,
|
||||
completed_task: asyncio.Task[AgentResponse[Any]],
|
||||
runtime: _RuntimeState,
|
||||
) -> None:
|
||||
"""Extract results from a completed asyncio task and update task info."""
|
||||
if completed_task.cancelled():
|
||||
task_info.status = BackgroundTaskStatus.FAILED
|
||||
task_info.error_text = "Task was canceled."
|
||||
else:
|
||||
exception = completed_task.exception()
|
||||
if exception is not None:
|
||||
task_info.status = BackgroundTaskStatus.FAILED
|
||||
task_info.error_text = str(exception)
|
||||
else:
|
||||
task_info.status = BackgroundTaskStatus.COMPLETED
|
||||
task_info.result_text = completed_task.result().text
|
||||
runtime.in_flight_tasks.pop(task_info.id, None)
|
||||
|
||||
|
||||
def _refresh_task_state(
|
||||
session: AgentSession, state: dict[str, Any], runtime: _RuntimeState, *, source_id: str
|
||||
) -> list[BackgroundTaskInfo]:
|
||||
"""Refresh status of in-flight tasks and return updated task list."""
|
||||
tasks = _get_tasks(state)
|
||||
changed = False
|
||||
|
||||
for task_info in tasks:
|
||||
if task_info.status != BackgroundTaskStatus.RUNNING:
|
||||
continue
|
||||
|
||||
in_flight = runtime.in_flight_tasks.get(task_info.id)
|
||||
if in_flight is None:
|
||||
task_info.status = BackgroundTaskStatus.LOST
|
||||
changed = True
|
||||
continue
|
||||
|
||||
if in_flight.done():
|
||||
_finalize_task(task_info, in_flight, runtime)
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
_save_tasks(state, tasks)
|
||||
_save_provider_state(session, state, source_id=source_id)
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.HARNESS)
|
||||
class BackgroundAgentsProvider(ContextProvider):
|
||||
"""Context provider that enables an agent to delegate work to background sub-agents.
|
||||
|
||||
The ``BackgroundAgentsProvider`` allows a parent agent to start background tasks on child agents,
|
||||
wait for their completion, and retrieve results. Each background task runs in its own session and
|
||||
executes concurrently.
|
||||
|
||||
This provider exposes the following tools to the agent:
|
||||
|
||||
- ``background_agents_start_task`` — Start a background task on a named agent with text input.
|
||||
- ``background_agents_wait_for_first_completion`` — Block until the first of the specified tasks completes.
|
||||
- ``background_agents_get_task_results`` — Retrieve the text output of a completed background task.
|
||||
- ``background_agents_get_all_tasks`` — List all background tasks with their IDs, statuses, and descriptions.
|
||||
- ``background_agents_continue_task`` — Send follow-up input to a completed task's session to resume work.
|
||||
- ``background_agents_clear_completed_task`` — Remove a completed task and release its session.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agents: Sequence[SupportsAgentRun],
|
||||
*,
|
||||
source_id: str = DEFAULT_BACKGROUND_AGENTS_SOURCE_ID,
|
||||
instructions: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the background agents provider.
|
||||
|
||||
Args:
|
||||
agents: Collection of background agents available for delegation.
|
||||
Each agent must have a non-empty, unique name (case-insensitive).
|
||||
|
||||
Keyword Args:
|
||||
source_id: Unique source ID for serializable task state in session.
|
||||
instructions: Optional instruction override. May include ``{background_agents}``
|
||||
placeholder which will be replaced with the agent listing.
|
||||
|
||||
Raises:
|
||||
ValueError: If agents is empty, an agent has no name, or names are not unique.
|
||||
"""
|
||||
super().__init__(source_id)
|
||||
|
||||
self._agents = _validate_and_build_agent_dict(agents)
|
||||
|
||||
# Build instructions with agent listing.
|
||||
base_instructions = instructions if instructions is not None else DEFAULT_BACKGROUND_AGENTS_INSTRUCTIONS
|
||||
agent_list_text = _build_agent_list_text(self._agents)
|
||||
self._instructions = base_instructions.replace("{background_agents}", agent_list_text)
|
||||
|
||||
# Per-session runtime state (non-serializable), keyed by session_id.
|
||||
# Note: Runtime state (in-flight asyncio.Task objects, child AgentSession handles)
|
||||
# is inherently non-serializable and cannot survive process restarts. If the provider
|
||||
# instance is lost, _refresh_task_state() marks orphaned tasks as LOST.
|
||||
self._runtime: dict[str, _RuntimeState] = {}
|
||||
|
||||
def _get_runtime(self, session: AgentSession) -> _RuntimeState:
|
||||
"""Get or create runtime state for a session."""
|
||||
session_id = session.session_id
|
||||
if session_id not in self._runtime:
|
||||
self._runtime[session_id] = _RuntimeState()
|
||||
return self._runtime[session_id]
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: SessionContext,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
"""Inject background agent tools and instructions before the model runs."""
|
||||
del agent, state
|
||||
|
||||
provider_state = _get_provider_state(session, source_id=self.source_id)
|
||||
runtime = self._get_runtime(session)
|
||||
source_id = self.source_id
|
||||
|
||||
@tool(name="background_agents_start_task", approval_mode="never_require")
|
||||
def background_agents_start_task(agent_name: str, input: str, description: str) -> str:
|
||||
"""Start a background task on a named agent. Returns a confirmation with the task ID."""
|
||||
key = agent_name.lower()
|
||||
if key not in self._agents:
|
||||
available = ", ".join(a.name or "" for a in self._agents.values())
|
||||
return f"Error: No background agent found with name '{agent_name}'. Available agents: {available}"
|
||||
|
||||
bg_agent = self._agents[key]
|
||||
task_id = provider_state.get("next_task_id", 1)
|
||||
provider_state["next_task_id"] = task_id + 1
|
||||
|
||||
task_info = BackgroundTaskInfo(
|
||||
id=task_id,
|
||||
agent_name=agent_name,
|
||||
description=description,
|
||||
)
|
||||
tasks = _get_tasks(provider_state)
|
||||
tasks.append(task_info)
|
||||
_save_tasks(provider_state, tasks)
|
||||
|
||||
# Create a dedicated session for this background task.
|
||||
sub_session = bg_agent.create_session()
|
||||
|
||||
# Start the task concurrently.
|
||||
async_task = asyncio.create_task(_run_agent(bg_agent.run(input, session=sub_session)))
|
||||
runtime.in_flight_tasks[task_id] = async_task
|
||||
runtime.background_sessions[task_id] = sub_session
|
||||
|
||||
_save_provider_state(session, provider_state, source_id=source_id)
|
||||
return f"Background task {task_id} started on agent '{agent_name}'."
|
||||
|
||||
@tool(name="background_agents_wait_for_first_completion", approval_mode="never_require")
|
||||
async def background_agents_wait_for_first_completion(task_ids: list[int]) -> str:
|
||||
"""Block until the first of the specified background tasks completes. Returns the completed task's ID."""
|
||||
if not task_ids:
|
||||
return "Error: No task IDs provided."
|
||||
|
||||
# Collect in-flight tasks matching the requested IDs.
|
||||
waitable: list[tuple[int, asyncio.Task[AgentResponse[Any]]]] = []
|
||||
for tid in task_ids:
|
||||
in_flight = runtime.in_flight_tasks.get(tid)
|
||||
if in_flight is not None:
|
||||
waitable.append((tid, in_flight))
|
||||
|
||||
if not waitable:
|
||||
# Refresh state to catch any that completed.
|
||||
tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
|
||||
already_complete = next(
|
||||
(t for t in tasks if t.id in task_ids and t.status != BackgroundTaskStatus.RUNNING), None
|
||||
)
|
||||
if already_complete is not None:
|
||||
return (
|
||||
f"Task {already_complete.id} is not running; current status: {already_complete.status.value}."
|
||||
)
|
||||
return "Error: None of the specified task IDs correspond to running tasks."
|
||||
|
||||
# Wait for the first one to complete.
|
||||
done, _ = await asyncio.wait(
|
||||
[t for _, t in waitable],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
# Find which ID completed.
|
||||
completed_id: int | None = None
|
||||
for tid, task in waitable:
|
||||
if task in done:
|
||||
completed_id = tid
|
||||
break
|
||||
|
||||
# Finalize the completed task.
|
||||
tasks = _get_tasks(provider_state)
|
||||
task_info = next((t for t in tasks if t.id == completed_id), None)
|
||||
if task_info is not None and completed_id is not None:
|
||||
completed_task = runtime.in_flight_tasks.get(completed_id)
|
||||
if completed_task is not None:
|
||||
_finalize_task(task_info, completed_task, runtime)
|
||||
_save_tasks(provider_state, tasks)
|
||||
_save_provider_state(session, provider_state, source_id=source_id)
|
||||
|
||||
status_str = task_info.status.value if task_info else "Unknown"
|
||||
return f"Task {completed_id} finished with status: {status_str}."
|
||||
|
||||
@tool(name="background_agents_get_task_results", approval_mode="never_require")
|
||||
def background_agents_get_task_results(task_id: int) -> str:
|
||||
"""Get the text output of a background task by its ID."""
|
||||
tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
|
||||
task_info = next((t for t in tasks if t.id == task_id), None)
|
||||
|
||||
if task_info is None:
|
||||
return f"Error: No task found with ID {task_id}."
|
||||
|
||||
if task_info.status == BackgroundTaskStatus.COMPLETED:
|
||||
return task_info.result_text or "(no output)"
|
||||
if task_info.status == BackgroundTaskStatus.FAILED:
|
||||
return f"Task failed: {task_info.error_text or 'Unknown error'}"
|
||||
if task_info.status == BackgroundTaskStatus.LOST:
|
||||
return "Task state was lost (reference unavailable)."
|
||||
if task_info.status == BackgroundTaskStatus.RUNNING:
|
||||
return f"Task {task_id} is still running."
|
||||
return f"Task {task_id} has status: {task_info.status.value}."
|
||||
|
||||
@tool(name="background_agents_get_all_tasks", approval_mode="never_require")
|
||||
def background_agents_get_all_tasks() -> str:
|
||||
"""List all background tasks with their IDs, statuses, agent names, and descriptions."""
|
||||
tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
|
||||
|
||||
if not tasks:
|
||||
return "No tasks."
|
||||
|
||||
lines = ["Tasks:"]
|
||||
for t in tasks:
|
||||
lines.append(f"- Task {t.id} [{t.status.value}] ({t.agent_name}): {t.description}")
|
||||
return "\n".join(lines)
|
||||
|
||||
@tool(name="background_agents_continue_task", approval_mode="never_require")
|
||||
def background_agents_continue_task(task_id: int, text: str) -> str:
|
||||
"""Send follow-up input to a completed or failed task to resume its work."""
|
||||
tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
|
||||
task_info = next((t for t in tasks if t.id == task_id), None)
|
||||
|
||||
if task_info is None:
|
||||
return f"Error: No task found with ID {task_id}."
|
||||
|
||||
if task_info.status == BackgroundTaskStatus.LOST:
|
||||
return (
|
||||
f"Error: Task {task_id} cannot be continued because its session was lost. Start a new task instead."
|
||||
)
|
||||
|
||||
if task_info.status == BackgroundTaskStatus.RUNNING:
|
||||
return f"Error: Task {task_id} is still running. Wait for it to complete before continuing."
|
||||
|
||||
key = task_info.agent_name.lower()
|
||||
if key not in self._agents:
|
||||
return f"Error: Agent '{task_info.agent_name}' is no longer available."
|
||||
|
||||
sub_session = runtime.background_sessions.get(task_id)
|
||||
if sub_session is None:
|
||||
return f"Error: Session for task {task_id} is no longer available."
|
||||
|
||||
bg_agent = self._agents[key]
|
||||
|
||||
# Reset task state and start a new run on the existing session.
|
||||
task_info.status = BackgroundTaskStatus.RUNNING
|
||||
task_info.result_text = None
|
||||
task_info.error_text = None
|
||||
_save_tasks(provider_state, tasks)
|
||||
|
||||
async_task = asyncio.create_task(_run_agent(bg_agent.run(text, session=sub_session)))
|
||||
runtime.in_flight_tasks[task_id] = async_task
|
||||
|
||||
_save_provider_state(session, provider_state, source_id=source_id)
|
||||
return f"Task {task_id} continued with new input."
|
||||
|
||||
@tool(name="background_agents_clear_completed_task", approval_mode="never_require")
|
||||
def background_agents_clear_completed_task(task_id: int) -> str:
|
||||
"""Remove a completed or failed task and release its session to free memory."""
|
||||
tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
|
||||
task_info = next((t for t in tasks if t.id == task_id), None)
|
||||
|
||||
if task_info is None:
|
||||
return f"Error: No task found with ID {task_id}."
|
||||
|
||||
if task_info.status == BackgroundTaskStatus.RUNNING:
|
||||
return f"Error: Task {task_id} is still running. Wait for it to complete before clearing."
|
||||
|
||||
# Remove the task from state.
|
||||
tasks = [t for t in tasks if t.id != task_id]
|
||||
_save_tasks(provider_state, tasks)
|
||||
|
||||
# Clean up runtime references.
|
||||
runtime.in_flight_tasks.pop(task_id, None)
|
||||
runtime.background_sessions.pop(task_id, None)
|
||||
|
||||
_save_provider_state(session, provider_state, source_id=source_id)
|
||||
return f"Task {task_id} cleared."
|
||||
|
||||
# Inject instructions and current task status.
|
||||
context.extend_instructions(self.source_id, [self._instructions])
|
||||
context.extend_tools(
|
||||
self.source_id,
|
||||
[
|
||||
background_agents_start_task,
|
||||
background_agents_wait_for_first_completion,
|
||||
background_agents_get_task_results,
|
||||
background_agents_get_all_tasks,
|
||||
background_agents_continue_task,
|
||||
background_agents_clear_completed_task,
|
||||
],
|
||||
)
|
||||
|
||||
# Include current task status as context message if there are tasks.
|
||||
# Refresh first to get accurate statuses for any tasks that completed between turns.
|
||||
tasks = _refresh_task_state(session, provider_state, runtime, source_id=source_id)
|
||||
if tasks:
|
||||
status_lines = ["### Current background tasks"]
|
||||
for t in tasks:
|
||||
status_lines.append(f"- Task {t.id} [{t.status.value}] ({t.agent_name}): {t.description}")
|
||||
context.extend_messages(
|
||||
self.source_id,
|
||||
[Message(role="user", contents=["\n".join(status_lines)])],
|
||||
)
|
||||
@@ -0,0 +1,538 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
AgentResponse,
|
||||
AgentSession,
|
||||
BackgroundAgentsProvider,
|
||||
BackgroundTaskInfo,
|
||||
BackgroundTaskStatus,
|
||||
Message,
|
||||
)
|
||||
from agent_framework._sessions import SessionContext
|
||||
|
||||
# Suppress "coroutine was never awaited" warnings from task cancellation in tests.
|
||||
# This occurs when cancelling tasks that wrap coroutines through _run_agent().
|
||||
pytestmark = pytest.mark.filterwarnings("ignore::RuntimeWarning:asyncio")
|
||||
|
||||
# --- Test Helpers ---
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
"""Minimal agent stub for testing background agent delegation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str | None = None,
|
||||
*,
|
||||
response_text: str = "done",
|
||||
delay: float = 0.0,
|
||||
should_fail: bool = False,
|
||||
):
|
||||
self.id = f"agent-{name}"
|
||||
self.name = name
|
||||
self.description = description
|
||||
self._response_text = response_text
|
||||
self._delay = delay
|
||||
self._should_fail = should_fail
|
||||
|
||||
def create_session(self, *, session_id: str | None = None) -> AgentSession:
|
||||
return AgentSession(session_id=session_id)
|
||||
|
||||
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
|
||||
return AgentSession(service_session_id=service_session_id, session_id=session_id)
|
||||
|
||||
async def run(
|
||||
self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any
|
||||
) -> AgentResponse[Any]:
|
||||
if self._delay > 0:
|
||||
await asyncio.sleep(self._delay)
|
||||
if self._should_fail:
|
||||
raise RuntimeError("Agent execution failed")
|
||||
return AgentResponse(messages=[Message(role="assistant", contents=[self._response_text])])
|
||||
|
||||
|
||||
def _make_provider(*agents: _FakeAgent) -> BackgroundAgentsProvider:
|
||||
"""Create a provider with given agents."""
|
||||
return BackgroundAgentsProvider(agents)
|
||||
|
||||
|
||||
def _make_session() -> AgentSession:
|
||||
"""Create a session for testing."""
|
||||
return AgentSession()
|
||||
|
||||
|
||||
async def _get_tools(provider: BackgroundAgentsProvider, session: AgentSession) -> dict[str, Any]:
|
||||
"""Run before_run and return tools by name."""
|
||||
context = SessionContext(input_messages=[])
|
||||
await provider.before_run(agent=None, session=session, context=context, state={})
|
||||
tools_by_name: dict[str, Any] = {}
|
||||
for t in context.tools:
|
||||
tools_by_name[t.name if hasattr(t, "name") else str(t)] = t
|
||||
return tools_by_name
|
||||
|
||||
|
||||
async def _invoke_tool(tool_obj: Any, **kwargs: Any) -> str:
|
||||
"""Invoke a FunctionTool and return the raw result string."""
|
||||
return await tool_obj.invoke(arguments=kwargs, skip_parsing=True)
|
||||
|
||||
|
||||
# --- Constructor Tests ---
|
||||
|
||||
|
||||
def test_constructor_requires_at_least_one_agent() -> None:
|
||||
"""Should reject empty agent list."""
|
||||
with pytest.raises(ValueError, match="At least one background agent"):
|
||||
BackgroundAgentsProvider([])
|
||||
|
||||
|
||||
def test_constructor_requires_agent_names() -> None:
|
||||
"""Should reject agents with no name."""
|
||||
agent = _FakeAgent("")
|
||||
with pytest.raises(ValueError, match="non-empty name"):
|
||||
BackgroundAgentsProvider([agent])
|
||||
|
||||
|
||||
def test_constructor_rejects_duplicate_names() -> None:
|
||||
"""Should reject duplicate agent names (case-insensitive)."""
|
||||
agent1 = _FakeAgent("Research")
|
||||
agent2 = _FakeAgent("research")
|
||||
with pytest.raises(ValueError, match="Duplicate background agent name"):
|
||||
BackgroundAgentsProvider([agent1, agent2])
|
||||
|
||||
|
||||
def test_constructor_valid_agents() -> None:
|
||||
"""Should succeed with valid unique agents."""
|
||||
provider = BackgroundAgentsProvider([_FakeAgent("Alpha"), _FakeAgent("Beta")])
|
||||
assert provider.source_id == "background_agents"
|
||||
|
||||
|
||||
def test_constructor_custom_source_id() -> None:
|
||||
"""Should accept custom source_id."""
|
||||
provider = BackgroundAgentsProvider([_FakeAgent("Agent1")], source_id="custom_bg")
|
||||
assert provider.source_id == "custom_bg"
|
||||
|
||||
|
||||
# --- Tool Injection Tests ---
|
||||
|
||||
|
||||
async def test_before_run_injects_six_tools() -> None:
|
||||
"""before_run should inject exactly 6 tools."""
|
||||
provider = _make_provider(_FakeAgent("Worker"))
|
||||
tools = await _get_tools(provider, _make_session())
|
||||
assert len(tools) == 6
|
||||
expected_names = {
|
||||
"background_agents_start_task",
|
||||
"background_agents_wait_for_first_completion",
|
||||
"background_agents_get_task_results",
|
||||
"background_agents_get_all_tasks",
|
||||
"background_agents_continue_task",
|
||||
"background_agents_clear_completed_task",
|
||||
}
|
||||
assert set(tools.keys()) == expected_names
|
||||
|
||||
|
||||
async def test_before_run_injects_instructions() -> None:
|
||||
"""before_run should inject instructions mentioning agent names."""
|
||||
provider = _make_provider(_FakeAgent("ResearchBot", "Does research"))
|
||||
context = SessionContext(input_messages=[])
|
||||
session = _make_session()
|
||||
await provider.before_run(agent=None, session=session, context=context, state={})
|
||||
all_instructions = " ".join(context.instructions)
|
||||
assert "ResearchBot" in all_instructions
|
||||
assert "Does research" in all_instructions
|
||||
|
||||
|
||||
# --- Start Task Tests ---
|
||||
|
||||
|
||||
async def test_start_task_success() -> None:
|
||||
"""Should start a task and return confirmation."""
|
||||
provider = _make_provider(_FakeAgent("Worker", response_text="result"))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Worker",
|
||||
input="do something",
|
||||
description="test task",
|
||||
)
|
||||
assert "task 1 started" in result.lower()
|
||||
assert "Worker" in result
|
||||
|
||||
|
||||
async def test_start_task_unknown_agent() -> None:
|
||||
"""Should return error for unknown agent name."""
|
||||
provider = _make_provider(_FakeAgent("Worker"))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="NonExistent",
|
||||
input="do something",
|
||||
description="test",
|
||||
)
|
||||
assert "Error" in result
|
||||
assert "NonExistent" in result
|
||||
|
||||
|
||||
async def test_start_task_increments_ids() -> None:
|
||||
"""Task IDs should increment sequentially."""
|
||||
provider = _make_provider(_FakeAgent("Worker"))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
r1 = await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Worker",
|
||||
input="task 1",
|
||||
description="first",
|
||||
)
|
||||
r2 = await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Worker",
|
||||
input="task 2",
|
||||
description="second",
|
||||
)
|
||||
assert "task 1 started" in r1.lower()
|
||||
assert "task 2 started" in r2.lower()
|
||||
|
||||
|
||||
# --- Get All Tasks Tests ---
|
||||
|
||||
|
||||
async def test_get_all_tasks_empty() -> None:
|
||||
"""Should return 'No tasks.' when no tasks exist."""
|
||||
provider = _make_provider(_FakeAgent("Worker"))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
result = await _invoke_tool(tools["background_agents_get_all_tasks"])
|
||||
assert "No tasks" in result
|
||||
|
||||
|
||||
async def test_get_all_tasks_shows_tasks() -> None:
|
||||
"""Should list all tasks with status and description."""
|
||||
provider = _make_provider(_FakeAgent("Worker"))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Worker",
|
||||
input="hello",
|
||||
description="my task",
|
||||
)
|
||||
result = await _invoke_tool(tools["background_agents_get_all_tasks"])
|
||||
assert "my task" in result
|
||||
assert "Worker" in result
|
||||
|
||||
|
||||
# --- Wait for Completion Tests ---
|
||||
|
||||
|
||||
async def test_wait_for_first_completion() -> None:
|
||||
"""Should wait and return when a task completes."""
|
||||
provider = _make_provider(_FakeAgent("Fast", response_text="fast result", delay=0.01))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Fast",
|
||||
input="go",
|
||||
description="fast task",
|
||||
)
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_wait_for_first_completion"],
|
||||
task_ids=[1],
|
||||
)
|
||||
assert "finished" in result.lower()
|
||||
assert "completed" in result.lower()
|
||||
|
||||
|
||||
async def test_wait_empty_task_ids() -> None:
|
||||
"""Should return error for empty task_ids."""
|
||||
provider = _make_provider(_FakeAgent("Worker"))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_wait_for_first_completion"],
|
||||
task_ids=[],
|
||||
)
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
async def test_wait_no_running_tasks() -> None:
|
||||
"""Should return error when no specified tasks are running."""
|
||||
provider = _make_provider(_FakeAgent("Worker"))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_wait_for_first_completion"],
|
||||
task_ids=[999],
|
||||
)
|
||||
assert "Error" in result or "not running" in result.lower()
|
||||
|
||||
|
||||
# --- Get Task Results Tests ---
|
||||
|
||||
|
||||
async def test_get_task_results_completed() -> None:
|
||||
"""Should return result text for completed task."""
|
||||
provider = _make_provider(_FakeAgent("Worker", response_text="the answer", delay=0.01))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Worker",
|
||||
input="query",
|
||||
description="test",
|
||||
)
|
||||
# Wait for completion.
|
||||
await _invoke_tool(
|
||||
tools["background_agents_wait_for_first_completion"],
|
||||
task_ids=[1],
|
||||
)
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_get_task_results"],
|
||||
task_id=1,
|
||||
)
|
||||
assert result == "the answer"
|
||||
|
||||
|
||||
async def test_get_task_results_running() -> None:
|
||||
"""Should indicate task is still running."""
|
||||
provider = _make_provider(_FakeAgent("Slow", delay=10.0))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Slow",
|
||||
input="query",
|
||||
description="slow task",
|
||||
)
|
||||
try:
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_get_task_results"],
|
||||
task_id=1,
|
||||
)
|
||||
assert "still running" in result.lower()
|
||||
finally:
|
||||
runtime = provider._get_runtime(session)
|
||||
for task in list(runtime.in_flight_tasks.values()):
|
||||
task.cancel()
|
||||
await asyncio.gather(*runtime.in_flight_tasks.values(), return_exceptions=True)
|
||||
|
||||
|
||||
async def test_get_task_results_failed() -> None:
|
||||
"""Should return error text for failed task."""
|
||||
provider = _make_provider(_FakeAgent("Broken", should_fail=True, delay=0.01))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Broken",
|
||||
input="query",
|
||||
description="will fail",
|
||||
)
|
||||
await _invoke_tool(
|
||||
tools["background_agents_wait_for_first_completion"],
|
||||
task_ids=[1],
|
||||
)
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_get_task_results"],
|
||||
task_id=1,
|
||||
)
|
||||
assert "failed" in result.lower()
|
||||
|
||||
|
||||
async def test_get_task_results_not_found() -> None:
|
||||
"""Should return error for non-existent task."""
|
||||
provider = _make_provider(_FakeAgent("Worker"))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_get_task_results"],
|
||||
task_id=999,
|
||||
)
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
# --- Continue Task Tests ---
|
||||
|
||||
|
||||
async def test_continue_task_after_completion() -> None:
|
||||
"""Should be able to continue a completed task."""
|
||||
provider = _make_provider(_FakeAgent("Worker", response_text="first result", delay=0.01))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Worker",
|
||||
input="first input",
|
||||
description="continuable",
|
||||
)
|
||||
await _invoke_tool(
|
||||
tools["background_agents_wait_for_first_completion"],
|
||||
task_ids=[1],
|
||||
)
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_continue_task"],
|
||||
task_id=1,
|
||||
text="follow up",
|
||||
)
|
||||
assert "continued" in result.lower()
|
||||
|
||||
|
||||
async def test_continue_task_still_running() -> None:
|
||||
"""Should return error if task is still running."""
|
||||
provider = _make_provider(_FakeAgent("Slow", delay=10.0))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Slow",
|
||||
input="input",
|
||||
description="running",
|
||||
)
|
||||
try:
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_continue_task"],
|
||||
task_id=1,
|
||||
text="follow up",
|
||||
)
|
||||
assert "still running" in result.lower()
|
||||
finally:
|
||||
runtime = provider._get_runtime(session)
|
||||
for task in list(runtime.in_flight_tasks.values()):
|
||||
task.cancel()
|
||||
await asyncio.gather(*runtime.in_flight_tasks.values(), return_exceptions=True)
|
||||
|
||||
|
||||
async def test_continue_task_not_found() -> None:
|
||||
"""Should return error for non-existent task."""
|
||||
provider = _make_provider(_FakeAgent("Worker"))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_continue_task"],
|
||||
task_id=999,
|
||||
text="hello",
|
||||
)
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
# --- Clear Task Tests ---
|
||||
|
||||
|
||||
async def test_clear_completed_task() -> None:
|
||||
"""Should clear a completed task."""
|
||||
provider = _make_provider(_FakeAgent("Worker", response_text="done", delay=0.01))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Worker",
|
||||
input="task",
|
||||
description="clearable",
|
||||
)
|
||||
await _invoke_tool(
|
||||
tools["background_agents_wait_for_first_completion"],
|
||||
task_ids=[1],
|
||||
)
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_clear_completed_task"],
|
||||
task_id=1,
|
||||
)
|
||||
assert "cleared" in result.lower()
|
||||
|
||||
# Verify task is gone.
|
||||
all_tasks = await _invoke_tool(tools["background_agents_get_all_tasks"])
|
||||
assert "No tasks" in all_tasks
|
||||
|
||||
|
||||
async def test_clear_running_task_error() -> None:
|
||||
"""Should return error when clearing a running task."""
|
||||
provider = _make_provider(_FakeAgent("Slow", delay=10.0))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
await _invoke_tool(
|
||||
tools["background_agents_start_task"],
|
||||
agent_name="Slow",
|
||||
input="task",
|
||||
description="still going",
|
||||
)
|
||||
try:
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_clear_completed_task"],
|
||||
task_id=1,
|
||||
)
|
||||
assert "still running" in result.lower()
|
||||
finally:
|
||||
runtime = provider._get_runtime(session)
|
||||
for task in list(runtime.in_flight_tasks.values()):
|
||||
task.cancel()
|
||||
await asyncio.gather(*runtime.in_flight_tasks.values(), return_exceptions=True)
|
||||
|
||||
|
||||
async def test_clear_not_found() -> None:
|
||||
"""Should return error for non-existent task."""
|
||||
provider = _make_provider(_FakeAgent("Worker"))
|
||||
session = _make_session()
|
||||
tools = await _get_tools(provider, session)
|
||||
|
||||
result = await _invoke_tool(
|
||||
tools["background_agents_clear_completed_task"],
|
||||
task_id=999,
|
||||
)
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
# --- BackgroundTaskInfo Tests ---
|
||||
|
||||
|
||||
def test_task_info_serialization() -> None:
|
||||
"""BackgroundTaskInfo should round-trip through to_dict/from_dict."""
|
||||
info = BackgroundTaskInfo(
|
||||
id=1,
|
||||
agent_name="Worker",
|
||||
description="test task",
|
||||
status=BackgroundTaskStatus.COMPLETED,
|
||||
result_text="hello",
|
||||
)
|
||||
data = info.to_dict()
|
||||
restored = BackgroundTaskInfo.from_dict(data)
|
||||
assert restored.id == 1
|
||||
assert restored.agent_name == "Worker"
|
||||
assert restored.status == BackgroundTaskStatus.COMPLETED
|
||||
assert restored.result_text == "hello"
|
||||
assert restored.error_text is None
|
||||
|
||||
|
||||
def test_task_status_enum_values() -> None:
|
||||
"""BackgroundTaskStatus should have expected values."""
|
||||
assert BackgroundTaskStatus.RUNNING == "running"
|
||||
assert BackgroundTaskStatus.COMPLETED == "completed"
|
||||
assert BackgroundTaskStatus.FAILED == "failed"
|
||||
assert BackgroundTaskStatus.LOST == "lost"
|
||||
Reference in New Issue
Block a user