Python: [BREAKING] Renamed AgentProtocol to SupportsAgentRun (#3717)

* Renamed AgentProtocol to AgentLike

* Resolved comments

* Renamed AgentLike to SupportsAgentRun

* Resolved comments
This commit is contained in:
Dmytro Struk
2026-02-06 09:53:21 -08:00
committed by GitHub
Unverified
parent ac17adb595
commit 15256bb616
55 changed files with 354 additions and 354 deletions
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator
from typing import Any, cast
from ag_ui.core import BaseEvent
from agent_framework import AgentProtocol
from agent_framework import SupportsAgentRun
from ._run import run_agent_stream
@@ -65,13 +65,13 @@ class AgentConfig:
class AgentFrameworkAgent:
"""Wraps Agent Framework agents for AG-UI protocol compatibility.
Translates between Agent Framework's AgentProtocol and AG-UI's event-based
Translates between Agent Framework's SupportsAgentRun and AG-UI's event-based
protocol. Follows a simple linear flow: RunStarted -> content events -> RunFinished.
"""
def __init__(
self,
agent: AgentProtocol,
agent: SupportsAgentRun,
name: str | None = None,
description: str | None = None,
state_schema: Any | None = None,
@@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, Sequence
from typing import Any
from ag_ui.encoder import EventEncoder
from agent_framework import AgentProtocol
from agent_framework import SupportsAgentRun
from fastapi import FastAPI
from fastapi.params import Depends
from fastapi.responses import StreamingResponse
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
def add_agent_framework_fastapi_endpoint(
app: FastAPI,
agent: AgentProtocol | AgentFrameworkAgent,
agent: SupportsAgentRun | AgentFrameworkAgent,
path: str = "/",
state_schema: Any | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
@@ -34,7 +34,7 @@ def add_agent_framework_fastapi_endpoint(
Args:
app: The FastAPI application
agent: The agent to expose (can be raw AgentProtocol or wrapped)
agent: The agent to expose (can be raw SupportsAgentRun or wrapped)
path: The endpoint path
state_schema: Optional state schema for shared state management; accepts dict or Pydantic model/class
predict_state_config: Optional predictive state update configuration.
@@ -47,7 +47,7 @@ def add_agent_framework_fastapi_endpoint(
authentication checks, rate limiting, or other middleware-like behavior.
Example: `dependencies=[Depends(verify_api_key)]`
"""
if isinstance(agent, AgentProtocol):
if isinstance(agent, SupportsAgentRun):
wrapped_agent = AgentFrameworkAgent(
agent=agent,
state_schema=state_schema,
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any
from agent_framework import BaseChatClient
if TYPE_CHECKING:
from agent_framework import AgentProtocol
from agent_framework import SupportsAgentRun
logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ def _collect_mcp_tool_functions(mcp_tools: list[Any]) -> list[Any]:
return functions
def collect_server_tools(agent: "AgentProtocol") -> list[Any]:
def collect_server_tools(agent: "SupportsAgentRun") -> list[Any]:
"""Collect server tools from an agent.
This includes both regular tools from default_options and MCP tools.
@@ -64,7 +64,7 @@ def collect_server_tools(agent: "AgentProtocol") -> list[Any]:
return server_tools
def register_additional_client_tools(agent: "AgentProtocol", client_tools: list[Any] | None) -> None:
def register_additional_client_tools(agent: "SupportsAgentRun", client_tools: list[Any] | None) -> None:
"""Register client tools as additional declaration-only tools to avoid server execution.
Args:
@@ -25,10 +25,10 @@ from ag_ui.core import (
ToolCallStartEvent,
)
from agent_framework import (
AgentProtocol,
AgentThread,
ChatMessage,
Content,
SupportsAgentRun,
prepare_function_call_results,
)
from agent_framework._middleware import FunctionMiddlewarePipeline
@@ -579,7 +579,7 @@ def _handle_step_based_approval(messages: list[Any]) -> list[BaseEvent]:
async def _resolve_approval_responses(
messages: list[Any],
tools: list[Any],
agent: AgentProtocol,
agent: SupportsAgentRun,
run_kwargs: dict[str, Any],
) -> None:
"""Execute approved function calls and replace approval content with results.
@@ -741,7 +741,7 @@ def _build_messages_snapshot(
async def run_agent_stream(
input_data: dict[str, Any],
agent: AgentProtocol,
agent: SupportsAgentRun,
config: "AgentConfig",
) -> "AsyncGenerator[BaseEvent, None]":
"""Run agent and yield AG-UI events.
@@ -9,7 +9,6 @@ from typing import Any, Generic, Literal, cast, overload
import pytest
from agent_framework import (
AgentProtocol,
AgentResponse,
AgentResponseUpdate,
AgentThread,
@@ -20,6 +19,7 @@ from agent_framework import (
ChatResponse,
ChatResponseUpdate,
Content,
SupportsAgentRun,
)
from agent_framework._clients import TOptions_co
from agent_framework._middleware import ChatMiddlewareLayer
@@ -149,8 +149,8 @@ def stream_from_updates(updates: list[ChatResponseUpdate]) -> StreamFn:
return _stream
class StubAgent(AgentProtocol):
"""Minimal AgentProtocol stub for orchestrator tests."""
class StubAgent(SupportsAgentRun):
"""Minimal SupportsAgentRun stub for orchestrator tests."""
def __init__(
self,
@@ -238,6 +238,6 @@ def stream_from_updates_fixture() -> Callable[[list[ChatResponseUpdate]], Stream
@pytest.fixture
def stub_agent() -> type[AgentProtocol]:
def stub_agent() -> type[SupportsAgentRun]:
"""Return the StubAgent class for creating test instances."""
return StubAgent # type: ignore[return-value]
@@ -26,7 +26,7 @@ def build_chat_client(streaming_chat_client_stub, stream_from_updates_fixture):
async def test_add_endpoint_with_agent_protocol(build_chat_client):
"""Test adding endpoint with raw AgentProtocol."""
"""Test adding endpoint with raw SupportsAgentRun."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, TypeVar, cast
import azure.durable_functions as df
import azure.functions as func
from agent_framework import AgentProtocol, get_logger
from agent_framework import SupportsAgentRun, get_logger
from agent_framework_durabletask import (
DEFAULT_MAX_POLL_RETRIES,
DEFAULT_POLL_INTERVAL_SECONDS,
@@ -51,12 +51,12 @@ class AgentMetadata:
"""Metadata for a registered agent.
Attributes:
agent: The agent instance implementing AgentProtocol
agent: The agent instance implementing SupportsAgentRun
http_endpoint_enabled: Whether HTTP endpoint is enabled for this agent
mcp_tool_enabled: Whether MCP tool endpoint is enabled for this agent
"""
agent: AgentProtocol
agent: SupportsAgentRun
http_endpoint_enabled: bool
mcp_tool_enabled: bool
@@ -145,7 +145,7 @@ class AgentFunctionApp(DFAppBase):
- Full access to all Azure Functions capabilities
Attributes:
agents: Dictionary of agent name to AgentProtocol instance
agents: Dictionary of agent name to SupportsAgentRun instance
enable_health_check: Whether health check endpoint is enabled
enable_http_endpoints: Whether HTTP endpoints are created for agents
enable_mcp_tool_trigger: Whether MCP tool triggers are created for agents
@@ -160,7 +160,7 @@ class AgentFunctionApp(DFAppBase):
def __init__(
self,
agents: list[AgentProtocol] | None = None,
agents: list[SupportsAgentRun] | None = None,
http_auth_level: func.AuthLevel = func.AuthLevel.FUNCTION,
enable_health_check: bool = True,
enable_http_endpoints: bool = True,
@@ -222,17 +222,17 @@ class AgentFunctionApp(DFAppBase):
logger.debug("[AgentFunctionApp] Initialization complete")
@property
def agents(self) -> dict[str, AgentProtocol]:
def agents(self) -> dict[str, SupportsAgentRun]:
"""Returns dict of agent names to agent instances.
Returns:
Dictionary mapping agent names to their AgentProtocol instances.
Dictionary mapping agent names to their SupportsAgentRun instances.
"""
return {name: metadata.agent for name, metadata in self._agent_metadata.items()}
def add_agent(
self,
agent: AgentProtocol,
agent: SupportsAgentRun,
callback: AgentResponseCallbackProtocol | None = None,
enable_http_endpoint: bool | None = None,
enable_mcp_tool_trigger: bool | None = None,
@@ -240,7 +240,7 @@ class AgentFunctionApp(DFAppBase):
"""Add an agent to the function app after initialization.
Args:
agent: The Microsoft Agent Framework agent instance (must implement AgentProtocol)
agent: The Microsoft Agent Framework agent instance (must implement SupportsAgentRun)
The agent must have a 'name' attribute.
callback: Optional callback invoked during agent execution
enable_http_endpoint: Optional flag to enable/disable HTTP endpoint for this agent.
@@ -322,7 +322,7 @@ class AgentFunctionApp(DFAppBase):
def _setup_agent_functions(
self,
agent: AgentProtocol,
agent: SupportsAgentRun,
agent_name: str,
callback: AgentResponseCallbackProtocol | None,
enable_http_endpoint: bool,
@@ -484,7 +484,7 @@ class AgentFunctionApp(DFAppBase):
def _setup_agent_entity(
self,
agent: AgentProtocol,
agent: SupportsAgentRun,
agent_name: str,
callback: AgentResponseCallbackProtocol | None,
) -> None:
@@ -12,7 +12,7 @@ from collections.abc import Callable
from typing import Any, cast
import azure.durable_functions as df
from agent_framework import AgentProtocol, get_logger
from agent_framework import SupportsAgentRun, get_logger
from agent_framework_durabletask import (
AgentEntity,
AgentEntityStateProviderMixin,
@@ -46,13 +46,13 @@ class AzureFunctionEntityStateProvider(AgentEntityStateProviderMixin):
def create_agent_entity(
agent: AgentProtocol,
agent: SupportsAgentRun,
callback: AgentResponseCallbackProtocol | None = None,
) -> Callable[[df.DurableEntityContext], None]:
"""Factory function to create an agent entity class.
Args:
agent: The Microsoft Agent Framework agent instance (must implement AgentProtocol)
agent: The Microsoft Agent Framework agent instance (must implement SupportsAgentRun)
callback: Optional callback invoked during streaming and final responses
Returns:
+1 -1
View File
@@ -5,7 +5,7 @@ and [OpenAI ChatKit (Python)](https://github.com/openai/chatkit-python/).
Specifically, it mirrors the [Agent SDK integration](https://github.com/openai/chatkit-python/blob/main/docs/server.md#agents-sdk-integration), and provides the following helpers:
- `stream_agent_response`: A helper to convert a streamed `AgentResponseUpdate`
from a Microsoft Agent Framework agent that implements `AgentProtocol` to ChatKit events.
from a Microsoft Agent Framework agent that implements `SupportsAgentRun` to ChatKit events.
- `ThreadItemConverter`: A extendable helper class to convert ChatKit thread items to
`ChatMessage` objects that can be consumed by an Agent Framework agent.
- `simple_to_agent_input`: A helper function that uses the default implementation
+1 -1
View File
@@ -25,7 +25,7 @@ agent_framework/
### Agents (`_agents.py`)
- **`AgentProtocol`** - Protocol defining the agent interface
- **`SupportsAgentRun`** - Protocol defining the agent interface
- **`BaseAgent`** - Abstract base class for agents
- **`ChatAgent`** - Main agent class wrapping a chat client with tools, instructions, and middleware
+10 -10
View File
@@ -163,14 +163,14 @@ class _RunContext(TypedDict):
finalize_kwargs: dict[str, Any]
__all__ = ["AgentProtocol", "BareAgent", "BaseAgent", "ChatAgent", "RawChatAgent"]
__all__ = ["BareAgent", "BaseAgent", "ChatAgent", "RawChatAgent", "SupportsAgentRun"]
# region Agent Protocol
@runtime_checkable
class AgentProtocol(Protocol):
class SupportsAgentRun(Protocol):
"""A protocol for an agent that can be invoked.
This protocol defines the interface that all agents must implement,
@@ -185,11 +185,11 @@ class AgentProtocol(Protocol):
Examples:
.. code-block:: python
from agent_framework import AgentProtocol
from agent_framework import SupportsAgentRun
# Any class implementing the required methods is compatible
# No need to inherit from AgentProtocol or use any framework classes
# No need to inherit from SupportsAgentRun or use any framework classes
class CustomAgent:
def __init__(self):
self.id = "custom-agent-001"
@@ -218,7 +218,7 @@ class AgentProtocol(Protocol):
# Verify the instance satisfies the protocol
instance = CustomAgent()
assert isinstance(instance, AgentProtocol)
assert isinstance(instance, SupportsAgentRun)
"""
id: str
@@ -297,7 +297,7 @@ class BaseAgent(SerializationMixin):
Note:
BaseAgent cannot be instantiated directly as it doesn't implement the
``run()`` and other methods required by AgentProtocol.
``run()`` and other methods required by SupportsAgentRun.
Use a concrete implementation like ChatAgent or create a subclass.
Examples:
@@ -451,7 +451,7 @@ class BaseAgent(SerializationMixin):
A FunctionTool that can be used as a tool by other agents.
Raises:
TypeError: If the agent does not implement AgentProtocol.
TypeError: If the agent does not implement SupportsAgentRun.
ValueError: If the agent tool name cannot be determined.
Examples:
@@ -468,9 +468,9 @@ class BaseAgent(SerializationMixin):
# Use the tool with another agent
coordinator = ChatAgent(chat_client=client, name="coordinator", tools=research_tool)
"""
# Verify that self implements AgentProtocol
if not isinstance(self, AgentProtocol):
raise TypeError(f"Agent {self.__class__.__name__} must implement AgentProtocol to be used as a tool")
# Verify that self implements SupportsAgentRun
if not isinstance(self, SupportsAgentRun):
raise TypeError(f"Agent {self.__class__.__name__} must implement SupportsAgentRun to be used as a tool")
tool_name = name or _sanitize_agent_name(self.name)
if tool_name is None:
@@ -34,7 +34,7 @@ else:
if TYPE_CHECKING:
from pydantic import BaseModel
from ._agents import AgentProtocol
from ._agents import SupportsAgentRun
from ._clients import ChatClientProtocol
from ._threads import AgentThread
from ._tools import FunctionTool
@@ -64,7 +64,7 @@ __all__ = [
"function_middleware",
]
TAgent = TypeVar("TAgent", bound="AgentProtocol")
AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
TContext = TypeVar("TContext")
TUpdate = TypeVar("TUpdate")
@@ -154,7 +154,7 @@ class AgentContext:
def __init__(
self,
*,
agent: AgentProtocol,
agent: SupportsAgentRun,
messages: list[ChatMessage],
thread: AgentThread | None = None,
options: Mapping[str, Any] | None = None,
@@ -9,7 +9,7 @@ from typing_extensions import Never
from agent_framework import Content
from .._agents import AgentProtocol
from .._agents import SupportsAgentRun
from .._threads import AgentThread
from .._types import AgentResponse, AgentResponseUpdate, ChatMessage
from ._agent_utils import resolve_agent_id
@@ -80,7 +80,7 @@ class AgentExecutor(Executor):
def __init__(
self,
agent: AgentProtocol,
agent: SupportsAgentRun,
*,
agent_thread: AgentThread | None = None,
id: str | None = None,
@@ -1,9 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
from .._agents import AgentProtocol
from .._agents import SupportsAgentRun
def resolve_agent_id(agent: AgentProtocol) -> str:
def resolve_agent_id(agent: SupportsAgentRun) -> str:
"""Resolve the unique identifier for an agent.
Prefers the `.name` attribute if set; otherwise falls back to `.id`.
@@ -6,7 +6,7 @@ from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any
from .._agents import AgentProtocol
from .._agents import SupportsAgentRun
from .._threads import AgentThread
from ..observability import OtelAttr, capture_exception, create_workflow_span
from ._agent_executor import AgentExecutor
@@ -171,7 +171,7 @@ class WorkflowBuilder:
self._max_iterations: int = max_iterations
self._name: str | None = name
self._description: str | None = description
# Maps underlying AgentProtocol object id -> wrapped Executor so we reuse the same wrapper
# Maps underlying SupportsAgentRun object id -> wrapped Executor so we reuse the same wrapper
# across set_start_executor / add_edge calls. This avoids multiple AgentExecutor instances
# being created for the same agent.
self._agent_wrappers: dict[str, Executor] = {}
@@ -187,7 +187,7 @@ class WorkflowBuilder:
self._executor_registry: dict[str, Callable[[], Executor]] = {}
# Output executors filter; if set, only outputs from these executors are yielded
self._output_executors: list[Executor | AgentProtocol | str] = []
self._output_executors: list[Executor | SupportsAgentRun | str] = []
# Agents auto-wrapped by builder now always stream incremental updates.
@@ -208,8 +208,8 @@ class WorkflowBuilder:
return executor.id
def _maybe_wrap_agent(self, candidate: Executor | AgentProtocol) -> Executor:
"""If the provided object implements AgentProtocol, wrap it in an AgentExecutor.
def _maybe_wrap_agent(self, candidate: Executor | SupportsAgentRun) -> Executor:
"""If the provided object implements SupportsAgentRun, wrap it in an AgentExecutor.
This allows fluent builder APIs to directly accept agents instead of
requiring callers to manually instantiate AgentExecutor.
@@ -221,13 +221,13 @@ class WorkflowBuilder:
An Executor instance, wrapping the agent if necessary.
"""
try: # Local import to avoid hard dependency at import time
from agent_framework import AgentProtocol # type: ignore
from agent_framework import SupportsAgentRun # type: ignore
except Exception: # pragma: no cover - defensive
AgentProtocol = object # type: ignore
SupportsAgentRun = object # type: ignore
if isinstance(candidate, Executor): # Already an executor
return candidate
if isinstance(candidate, AgentProtocol): # type: ignore[arg-type]
if isinstance(candidate, SupportsAgentRun): # type: ignore[arg-type]
# Reuse existing wrapper for the same agent instance if present
agent_instance_id = str(id(candidate))
existing = self._agent_wrappers.get(agent_instance_id)
@@ -244,7 +244,7 @@ class WorkflowBuilder:
return wrapper
raise TypeError(
f"WorkflowBuilder expected an Executor or AgentProtocol instance; got {type(candidate).__name__}."
f"WorkflowBuilder expected an Executor or SupportsAgentRun instance; got {type(candidate).__name__}."
)
def register_executor(self, factory_func: Callable[[], Executor], name: str | list[str]) -> Self:
@@ -321,7 +321,7 @@ class WorkflowBuilder:
def register_agent(
self,
factory_func: Callable[[], AgentProtocol],
factory_func: Callable[[], SupportsAgentRun],
name: str,
agent_thread: AgentThread | None = None,
) -> Self:
@@ -332,7 +332,7 @@ class WorkflowBuilder:
enabling deferred initialization and potentially reducing startup time.
Args:
factory_func: A callable that returns an AgentProtocol instance when called.
factory_func: A callable that returns an SupportsAgentRun instance when called.
name: The name of the registered agent factory. This doesn't have to match
the agent's internal name. But it must be unique within the workflow.
agent_thread: The thread to use for running the agent. If None, a new thread will be created when
@@ -375,8 +375,8 @@ class WorkflowBuilder:
def add_edge(
self,
source: Executor | AgentProtocol | str,
target: Executor | AgentProtocol | str,
source: Executor | SupportsAgentRun | str,
target: Executor | SupportsAgentRun | str,
condition: EdgeCondition | None = None,
) -> Self:
"""Add a directed edge between two executors.
@@ -441,8 +441,8 @@ class WorkflowBuilder:
not isinstance(source, str) and isinstance(target, str)
):
raise ValueError(
"Both source and target must be either registered factory names (str) "
"or Executor/AgentProtocol instances."
"Both source and target must be either registered factory names (str) or "
"Executor/SupportsAgentRun instances."
)
if isinstance(source, str) and isinstance(target, str):
@@ -450,7 +450,7 @@ class WorkflowBuilder:
self._edge_registry.append(_EdgeRegistration(source=source, target=target, condition=condition))
return self
# Both are Executor/AgentProtocol instances; wrap and add now
# Both are Executor/SupportsAgentRun instances; wrap and add now
source_exec = self._maybe_wrap_agent(source) # type: ignore[arg-type]
target_exec = self._maybe_wrap_agent(target) # type: ignore[arg-type]
source_id = self._add_executor(source_exec)
@@ -460,8 +460,8 @@ class WorkflowBuilder:
def add_fan_out_edges(
self,
source: Executor | AgentProtocol | str,
targets: Sequence[Executor | AgentProtocol | str],
source: Executor | SupportsAgentRun | str,
targets: Sequence[Executor | SupportsAgentRun | str],
) -> Self:
"""Add multiple edges to the workflow where messages from the source will be sent to all targets.
@@ -520,8 +520,8 @@ class WorkflowBuilder:
not isinstance(source, str) and any(isinstance(t, str) for t in targets)
):
raise ValueError(
"Both source and targets must be either registered factory names (str) "
"or Executor/AgentProtocol instances."
"Both source and targets must be either registered factory names (str) or "
"Executor/SupportsAgentRun instances."
)
if isinstance(source, str) and all(isinstance(t, str) for t in targets):
@@ -529,7 +529,7 @@ class WorkflowBuilder:
self._edge_registry.append(_FanOutEdgeRegistration(source=source, targets=list(targets))) # type: ignore
return self
# Both are Executor/AgentProtocol instances; wrap and add now
# Both are Executor/SupportsAgentRun instances; wrap and add now
source_exec = self._maybe_wrap_agent(source) # type: ignore[arg-type]
target_execs = [self._maybe_wrap_agent(t) for t in targets] # type: ignore[arg-type]
source_id = self._add_executor(source_exec)
@@ -540,7 +540,7 @@ class WorkflowBuilder:
def add_switch_case_edge_group(
self,
source: Executor | AgentProtocol | str,
source: Executor | SupportsAgentRun | str,
cases: Sequence[Case | Default],
) -> Self:
"""Add an edge group that represents a switch-case statement.
@@ -620,7 +620,7 @@ class WorkflowBuilder:
):
raise ValueError(
"Both source and case targets must be either registered factory names (str) "
"or Executor/AgentProtocol instances."
"or Executor/SupportsAgentRun instances."
)
if isinstance(source, str) and all(isinstance(case.target, str) for case in cases):
@@ -628,7 +628,7 @@ class WorkflowBuilder:
self._edge_registry.append(_SwitchCaseEdgeGroupRegistration(source=source, cases=list(cases))) # type: ignore
return self
# Source is an Executor/AgentProtocol instance; wrap and add now
# Source is an Executor/SupportsAgentRun instance; wrap and add now
source_exec = self._maybe_wrap_agent(source) # type: ignore[arg-type]
source_id = self._add_executor(source_exec)
# Convert case data types to internal types that only uses target_id.
@@ -647,8 +647,8 @@ class WorkflowBuilder:
def add_multi_selection_edge_group(
self,
source: Executor | AgentProtocol | str,
targets: Sequence[Executor | AgentProtocol | str],
source: Executor | SupportsAgentRun | str,
targets: Sequence[Executor | SupportsAgentRun | str],
selection_func: Callable[[Any, list[str]], list[str]],
) -> Self:
"""Add an edge group that represents a multi-selection execution model.
@@ -731,8 +731,8 @@ class WorkflowBuilder:
not isinstance(source, str) and any(isinstance(t, str) for t in targets)
):
raise ValueError(
"Both source and targets must be either registered factory names (str) "
"or Executor/AgentProtocol instances."
"Both source and targets must be either registered factory names (str) or "
"Executor/SupportsAgentRun instances."
)
if isinstance(source, str) and all(isinstance(t, str) for t in targets):
@@ -746,7 +746,7 @@ class WorkflowBuilder:
)
return self
# Both are Executor/AgentProtocol instances; wrap and add now
# Both are Executor/SupportsAgentRun instances; wrap and add now
source_exec = self._maybe_wrap_agent(source) # type: ignore
target_execs = [self._maybe_wrap_agent(t) for t in targets] # type: ignore
source_id = self._add_executor(source_exec)
@@ -757,8 +757,8 @@ class WorkflowBuilder:
def add_fan_in_edges(
self,
sources: Sequence[Executor | AgentProtocol | str],
target: Executor | AgentProtocol | str,
sources: Sequence[Executor | SupportsAgentRun | str],
target: Executor | SupportsAgentRun | str,
) -> Self:
"""Add multiple edges from sources to a single target executor.
@@ -816,8 +816,8 @@ class WorkflowBuilder:
not all(isinstance(s, str) for s in sources) and isinstance(target, str)
):
raise ValueError(
"Both sources and target must be either registered factory names (str) "
"or Executor/AgentProtocol instances."
"Both sources and target must be either registered factory names (str) or "
"Executor/SupportsAgentRun instances."
)
if all(isinstance(s, str) for s in sources) and isinstance(target, str):
@@ -825,7 +825,7 @@ class WorkflowBuilder:
self._edge_registry.append(_FanInEdgeRegistration(sources=list(sources), target=target)) # type: ignore
return self
# Both are Executor/AgentProtocol instances; wrap and add now
# Both are Executor/SupportsAgentRun instances; wrap and add now
source_execs = [self._maybe_wrap_agent(s) for s in sources] # type: ignore
target_exec = self._maybe_wrap_agent(target) # type: ignore
source_ids = [self._add_executor(s) for s in source_execs]
@@ -834,7 +834,7 @@ class WorkflowBuilder:
return self
def add_chain(self, executors: Sequence[Executor | AgentProtocol | str]) -> Self:
def add_chain(self, executors: Sequence[Executor | SupportsAgentRun | str]) -> Self:
"""Add a chain of executors to the workflow.
The output of each executor in the chain will be sent to the next executor in the chain.
@@ -895,7 +895,7 @@ class WorkflowBuilder:
if not all(isinstance(e, str) for e in executors) and any(isinstance(e, str) for e in executors):
raise ValueError(
"All executors in the chain must be either registered factory names (str) "
"or Executor/AgentProtocol instances."
"or Executor/SupportsAgentRun instances."
)
if all(isinstance(e, str) for e in executors):
@@ -904,21 +904,21 @@ class WorkflowBuilder:
self.add_edge(executors[i], executors[i + 1])
return self
# All are Executor/AgentProtocol instances; wrap and add now
# All are Executor/SupportsAgentRun instances; wrap and add now
# Wrap each candidate first to ensure stable IDs before adding edges
wrapped: list[Executor] = [self._maybe_wrap_agent(e) for e in executors] # type: ignore[arg-type]
for i in range(len(wrapped) - 1):
self.add_edge(wrapped[i], wrapped[i + 1])
return self
def set_start_executor(self, executor: Executor | AgentProtocol | str) -> Self:
def set_start_executor(self, executor: Executor | SupportsAgentRun | str) -> Self:
"""Set the starting executor for the workflow.
The start executor is the entry point for the workflow. When the workflow is executed,
the initial message will be sent to this executor.
Args:
executor: The starting executor, which can be an Executor instance, AgentProtocol instance,
executor: The starting executor, which can be an Executor instance, SupportsAgentRun instance,
or the name of a registered executor factory.
Returns:
@@ -1067,7 +1067,7 @@ class WorkflowBuilder:
self._checkpoint_storage = checkpoint_storage
return self
def with_output_from(self, executors: list[Executor | AgentProtocol | str]) -> Self:
def with_output_from(self, executors: list[Executor | SupportsAgentRun | str]) -> Self:
"""Specify which executors' outputs should be collected as workflow outputs.
By default, outputs from all executors are collected. This method allows
@@ -1231,7 +1231,11 @@ class WorkflowBuilder:
if isinstance(factory_name, str)
]
+ [ex.id for ex in self._output_executors if isinstance(ex, Executor)]
+ [resolve_agent_id(agent) for agent in self._output_executors if isinstance(agent, AgentProtocol)]
+ [
resolve_agent_id(agent)
for agent in self._output_executors
if isinstance(agent, SupportsAgentRun)
]
)
# Perform validation before creating the workflow
@@ -38,7 +38,7 @@ if TYPE_CHECKING: # pragma: no cover
from opentelemetry.util._decorator import _AgnosticContextManager # type: ignore[reportPrivateUsage]
from pydantic import BaseModel
from ._agents import AgentProtocol
from ._agents import SupportsAgentRun
from ._clients import ChatClientProtocol
from ._threads import AgentThread
from ._tools import FunctionTool
@@ -70,7 +70,7 @@ __all__ = [
]
TAgent = TypeVar("TAgent", bound="AgentProtocol")
AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol[Any]")
+3 -3
View File
@@ -12,7 +12,6 @@ from pydantic import BaseModel
from pytest import fixture
from agent_framework import (
AgentProtocol,
AgentResponse,
AgentResponseUpdate,
AgentThread,
@@ -24,6 +23,7 @@ from agent_framework import (
Content,
FunctionInvocationLayer,
ResponseStream,
SupportsAgentRun,
ToolProtocol,
tool,
)
@@ -273,7 +273,7 @@ class MockAgentThread(AgentThread):
# Mock Agent implementation for testing
class MockAgent(AgentProtocol):
class MockAgent(SupportsAgentRun):
@property
def id(self) -> str:
return str(uuid4())
@@ -329,5 +329,5 @@ def agent_thread() -> AgentThread:
@fixture
def agent() -> AgentProtocol:
def agent() -> SupportsAgentRun:
return MockAgent()
@@ -10,7 +10,6 @@ import pytest
from pytest import raises
from agent_framework import (
AgentProtocol,
AgentResponse,
AgentResponseUpdate,
AgentThread,
@@ -24,6 +23,7 @@ from agent_framework import (
Context,
ContextProvider,
HostedCodeInterpreterTool,
SupportsAgentRun,
ToolProtocol,
tool,
)
@@ -36,17 +36,17 @@ def test_agent_thread_type(agent_thread: AgentThread) -> None:
assert isinstance(agent_thread, AgentThread)
def test_agent_type(agent: AgentProtocol) -> None:
assert isinstance(agent, AgentProtocol)
def test_agent_type(agent: SupportsAgentRun) -> None:
assert isinstance(agent, SupportsAgentRun)
async def test_agent_run(agent: AgentProtocol) -> None:
async def test_agent_run(agent: SupportsAgentRun) -> None:
response = await agent.run("test")
assert response.messages[0].role == "assistant"
assert response.messages[0].text == "Response"
async def test_agent_run_streaming(agent: AgentProtocol) -> None:
async def test_agent_run_streaming(agent: SupportsAgentRun) -> None:
async def collect_updates(updates: AsyncIterable[AgentResponseUpdate]) -> list[AgentResponseUpdate]:
return [u async for u in updates]
@@ -57,7 +57,7 @@ async def test_agent_run_streaming(agent: AgentProtocol) -> None:
def test_chat_client_agent_type(chat_client: ChatClientProtocol) -> None:
chat_client_agent = ChatAgent(chat_client=chat_client)
assert isinstance(chat_client_agent, AgentProtocol)
assert isinstance(chat_client_agent, SupportsAgentRun)
async def test_chat_client_agent_init(chat_client: ChatClientProtocol) -> None:
@@ -804,7 +804,7 @@ def test_sanitize_agent_name_replaces_invalid_chars():
# endregion
# region Test AgentProtocol.get_new_thread and deserialize_thread
# region Test SupportsAgentRun.get_new_thread and deserialize_thread
@pytest.mark.asyncio
@@ -8,7 +8,6 @@ import pytest
from pydantic import BaseModel, Field
from agent_framework import (
AgentProtocol,
AgentResponse,
AgentResponseUpdate,
ChatMessage,
@@ -16,6 +15,7 @@ from agent_framework import (
ChatResponseUpdate,
Content,
ResponseStream,
SupportsAgentRun,
)
from agent_framework._middleware import (
AgentContext,
@@ -35,7 +35,7 @@ from agent_framework._tools import FunctionTool
class TestAgentContext:
"""Test cases for AgentContext."""
def test_init_with_defaults(self, mock_agent: AgentProtocol) -> None:
def test_init_with_defaults(self, mock_agent: SupportsAgentRun) -> None:
"""Test AgentContext initialization with default values."""
messages = [ChatMessage(role="user", text="test")]
context = AgentContext(agent=mock_agent, messages=messages)
@@ -45,7 +45,7 @@ class TestAgentContext:
assert context.stream is False
assert context.metadata == {}
def test_init_with_custom_values(self, mock_agent: AgentProtocol) -> None:
def test_init_with_custom_values(self, mock_agent: SupportsAgentRun) -> None:
"""Test AgentContext initialization with custom values."""
messages = [ChatMessage(role="user", text="test")]
metadata = {"key": "value"}
@@ -56,7 +56,7 @@ class TestAgentContext:
assert context.stream is True
assert context.metadata == metadata
def test_init_with_thread(self, mock_agent: AgentProtocol) -> None:
def test_init_with_thread(self, mock_agent: SupportsAgentRun) -> None:
"""Test AgentContext initialization with thread parameter."""
from agent_framework import AgentThread
@@ -163,7 +163,7 @@ class TestAgentMiddlewarePipeline:
pipeline = AgentMiddlewarePipeline(test_middleware)
assert pipeline.has_middlewares
async def test_execute_no_middleware(self, mock_agent: AgentProtocol) -> None:
async def test_execute_no_middleware(self, mock_agent: SupportsAgentRun) -> None:
"""Test pipeline execution with no middleware."""
pipeline = AgentMiddlewarePipeline()
messages = [ChatMessage(role="user", text="test")]
@@ -177,7 +177,7 @@ class TestAgentMiddlewarePipeline:
result = await pipeline.execute(context, final_handler)
assert result == expected_response
async def test_execute_with_middleware(self, mock_agent: AgentProtocol) -> None:
async def test_execute_with_middleware(self, mock_agent: SupportsAgentRun) -> None:
"""Test pipeline execution with middleware."""
execution_order: list[str] = []
@@ -205,7 +205,7 @@ class TestAgentMiddlewarePipeline:
assert result == expected_response
assert execution_order == ["test_before", "handler", "test_after"]
async def test_execute_stream_no_middleware(self, mock_agent: AgentProtocol) -> None:
async def test_execute_stream_no_middleware(self, mock_agent: SupportsAgentRun) -> None:
"""Test pipeline streaming execution with no middleware."""
pipeline = AgentMiddlewarePipeline()
messages = [ChatMessage(role="user", text="test")]
@@ -228,7 +228,7 @@ class TestAgentMiddlewarePipeline:
assert updates[0].text == "chunk1"
assert updates[1].text == "chunk2"
async def test_execute_stream_with_middleware(self, mock_agent: AgentProtocol) -> None:
async def test_execute_stream_with_middleware(self, mock_agent: SupportsAgentRun) -> None:
"""Test pipeline streaming execution with middleware."""
execution_order: list[str] = []
@@ -265,7 +265,7 @@ class TestAgentMiddlewarePipeline:
assert updates[1].text == "chunk2"
assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"]
async def test_execute_with_pre_next_termination(self, mock_agent: AgentProtocol) -> None:
async def test_execute_with_pre_next_termination(self, mock_agent: SupportsAgentRun) -> None:
"""Test pipeline execution with termination before next()."""
middleware = self.PreNextTerminateMiddleware()
pipeline = AgentMiddlewarePipeline(middleware)
@@ -283,7 +283,7 @@ class TestAgentMiddlewarePipeline:
# Handler should not be called when terminated before next()
assert execution_order == []
async def test_execute_with_post_next_termination(self, mock_agent: AgentProtocol) -> None:
async def test_execute_with_post_next_termination(self, mock_agent: SupportsAgentRun) -> None:
"""Test pipeline execution with termination after next()."""
middleware = self.PostNextTerminateMiddleware()
pipeline = AgentMiddlewarePipeline(middleware)
@@ -301,7 +301,7 @@ class TestAgentMiddlewarePipeline:
assert response.messages[0].text == "response"
assert execution_order == ["handler"]
async def test_execute_stream_with_pre_next_termination(self, mock_agent: AgentProtocol) -> None:
async def test_execute_stream_with_pre_next_termination(self, mock_agent: SupportsAgentRun) -> None:
"""Test pipeline streaming execution with termination before next()."""
middleware = self.PreNextTerminateMiddleware()
pipeline = AgentMiddlewarePipeline(middleware)
@@ -329,7 +329,7 @@ class TestAgentMiddlewarePipeline:
assert execution_order == []
assert not updates
async def test_execute_stream_with_post_next_termination(self, mock_agent: AgentProtocol) -> None:
async def test_execute_stream_with_post_next_termination(self, mock_agent: SupportsAgentRun) -> None:
"""Test pipeline streaming execution with termination after next()."""
middleware = self.PostNextTerminateMiddleware()
pipeline = AgentMiddlewarePipeline(middleware)
@@ -356,7 +356,7 @@ class TestAgentMiddlewarePipeline:
assert updates[1].text == "chunk2"
assert execution_order == ["handler_start", "handler_end"]
async def test_execute_with_thread_in_context(self, mock_agent: AgentProtocol) -> None:
async def test_execute_with_thread_in_context(self, mock_agent: SupportsAgentRun) -> None:
"""Test pipeline execution properly passes thread to middleware."""
from agent_framework import AgentThread
@@ -383,7 +383,7 @@ class TestAgentMiddlewarePipeline:
assert result == expected_response
assert captured_thread is thread
async def test_execute_with_no_thread_in_context(self, mock_agent: AgentProtocol) -> None:
async def test_execute_with_no_thread_in_context(self, mock_agent: SupportsAgentRun) -> None:
"""Test pipeline execution when no thread is provided."""
captured_thread = "not_none" # Use string to distinguish from None
@@ -761,7 +761,7 @@ class TestChatMiddlewarePipeline:
class TestClassBasedMiddleware:
"""Test cases for class-based middleware implementations."""
async def test_agent_middleware_execution(self, mock_agent: AgentProtocol) -> None:
async def test_agent_middleware_execution(self, mock_agent: SupportsAgentRun) -> None:
"""Test class-based agent middleware execution."""
metadata_updates: list[str] = []
@@ -825,7 +825,7 @@ class TestClassBasedMiddleware:
class TestFunctionBasedMiddleware:
"""Test cases for function-based middleware implementations."""
async def test_agent_function_middleware(self, mock_agent: AgentProtocol) -> None:
async def test_agent_function_middleware(self, mock_agent: SupportsAgentRun) -> None:
"""Test function-based agent middleware."""
execution_order: list[str] = []
@@ -879,7 +879,7 @@ class TestFunctionBasedMiddleware:
class TestMixedMiddleware:
"""Test cases for mixed class and function-based middleware."""
async def test_mixed_agent_middleware(self, mock_agent: AgentProtocol) -> None:
async def test_mixed_agent_middleware(self, mock_agent: SupportsAgentRun) -> None:
"""Test mixed class and function-based agent middleware."""
execution_order: list[str] = []
@@ -976,7 +976,7 @@ class TestMixedMiddleware:
class TestMultipleMiddlewareOrdering:
"""Test cases for multiple middleware execution order."""
async def test_agent_middleware_execution_order(self, mock_agent: AgentProtocol) -> None:
async def test_agent_middleware_execution_order(self, mock_agent: SupportsAgentRun) -> None:
"""Test that multiple agent middleware execute in registration order."""
execution_order: list[str] = []
@@ -1110,7 +1110,7 @@ class TestMultipleMiddlewareOrdering:
class TestContextContentValidation:
"""Test cases for validating middleware context content."""
async def test_agent_context_validation(self, mock_agent: AgentProtocol) -> None:
async def test_agent_context_validation(self, mock_agent: SupportsAgentRun) -> None:
"""Test that agent context contains expected data."""
class ContextValidationMiddleware(AgentMiddleware):
@@ -1231,7 +1231,7 @@ class TestContextContentValidation:
class TestStreamingScenarios:
"""Test cases for streaming and non-streaming scenarios."""
async def test_streaming_flag_validation(self, mock_agent: AgentProtocol) -> None:
async def test_streaming_flag_validation(self, mock_agent: SupportsAgentRun) -> None:
"""Test that stream flag is correctly set for streaming calls."""
streaming_flags: list[bool] = []
@@ -1271,7 +1271,7 @@ class TestStreamingScenarios:
# Verify flags: [non-streaming middleware, non-streaming handler, streaming middleware, streaming handler]
assert streaming_flags == [False, False, True, True]
async def test_streaming_middleware_behavior(self, mock_agent: AgentProtocol) -> None:
async def test_streaming_middleware_behavior(self, mock_agent: SupportsAgentRun) -> None:
"""Test middleware behavior with streaming responses."""
chunks_processed: list[str] = []
@@ -1437,7 +1437,7 @@ class MockFunctionArgs(BaseModel):
class TestMiddlewareExecutionControl:
"""Test cases for middleware execution control (when next() is called vs not called)."""
async def test_agent_middleware_no_next_no_execution(self, mock_agent: AgentProtocol) -> None:
async def test_agent_middleware_no_next_no_execution(self, mock_agent: SupportsAgentRun) -> None:
"""Test that when agent middleware doesn't call next(), no execution happens."""
class NoNextMiddleware(AgentMiddleware):
@@ -1464,7 +1464,7 @@ class TestMiddlewareExecutionControl:
assert not handler_called
assert context.result is None
async def test_agent_middleware_no_next_no_streaming_execution(self, mock_agent: AgentProtocol) -> None:
async def test_agent_middleware_no_next_no_streaming_execution(self, mock_agent: SupportsAgentRun) -> None:
"""Test that when agent middleware doesn't call next(), no streaming execution happens."""
class NoNextStreamingMiddleware(AgentMiddleware):
@@ -1529,7 +1529,7 @@ class TestMiddlewareExecutionControl:
assert not handler_called
assert context.result is None
async def test_multiple_middlewares_early_stop(self, mock_agent: AgentProtocol) -> None:
async def test_multiple_middlewares_early_stop(self, mock_agent: SupportsAgentRun) -> None:
"""Test that when first middleware doesn't call next(), subsequent middleware are not called."""
execution_order: list[str] = []
@@ -1664,9 +1664,9 @@ class TestMiddlewareExecutionControl:
@pytest.fixture
def mock_agent() -> AgentProtocol:
def mock_agent() -> SupportsAgentRun:
"""Mock agent for testing."""
agent = MagicMock(spec=AgentProtocol)
agent = MagicMock(spec=SupportsAgentRun)
agent.name = "test_agent"
return agent
@@ -8,13 +8,13 @@ import pytest
from pydantic import BaseModel, Field
from agent_framework import (
AgentProtocol,
AgentResponse,
AgentResponseUpdate,
ChatAgent,
ChatMessage,
Content,
ResponseStream,
SupportsAgentRun,
)
from agent_framework._middleware import (
AgentContext,
@@ -38,7 +38,7 @@ class FunctionTestArgs(BaseModel):
class TestResultOverrideMiddleware:
"""Test cases for middleware result override functionality."""
async def test_agent_middleware_response_override_non_streaming(self, mock_agent: AgentProtocol) -> None:
async def test_agent_middleware_response_override_non_streaming(self, mock_agent: SupportsAgentRun) -> None:
"""Test that agent middleware can override response for non-streaming execution."""
override_response = AgentResponse(messages=[ChatMessage(role="assistant", text="overridden response")])
@@ -69,7 +69,7 @@ class TestResultOverrideMiddleware:
# Verify original handler was called since middleware called next()
assert handler_called
async def test_agent_middleware_response_override_streaming(self, mock_agent: AgentProtocol) -> None:
async def test_agent_middleware_response_override_streaming(self, mock_agent: SupportsAgentRun) -> None:
"""Test that agent middleware can override response for streaming execution."""
async def override_stream() -> AsyncIterable[AgentResponseUpdate]:
@@ -211,7 +211,7 @@ class TestResultOverrideMiddleware:
assert normal_updates[0].text == "test streaming response "
assert normal_updates[1].text == "another update"
async def test_agent_middleware_conditional_no_next(self, mock_agent: AgentProtocol) -> None:
async def test_agent_middleware_conditional_no_next(self, mock_agent: SupportsAgentRun) -> None:
"""Test that when agent middleware conditionally doesn't call next(), no execution happens."""
class ConditionalNoNextMiddleware(AgentMiddleware):
@@ -303,7 +303,7 @@ class TestResultOverrideMiddleware:
class TestResultObservability:
"""Test cases for middleware result observability functionality."""
async def test_agent_middleware_response_observability(self, mock_agent: AgentProtocol) -> None:
async def test_agent_middleware_response_observability(self, mock_agent: SupportsAgentRun) -> None:
"""Test that middleware can observe response after execution."""
observed_responses: list[AgentResponse] = []
@@ -370,7 +370,7 @@ class TestResultObservability:
assert observed_results[0] == "executed function result"
assert result == observed_results[0]
async def test_agent_middleware_post_execution_override(self, mock_agent: AgentProtocol) -> None:
async def test_agent_middleware_post_execution_override(self, mock_agent: SupportsAgentRun) -> None:
"""Test that middleware can override response after observing execution."""
class PostExecutionOverrideMiddleware(AgentMiddleware):
@@ -436,9 +436,9 @@ class TestResultObservability:
@pytest.fixture
def mock_agent() -> AgentProtocol:
def mock_agent() -> SupportsAgentRun:
"""Mock agent for testing."""
agent = MagicMock(spec=AgentProtocol)
agent = MagicMock(spec=SupportsAgentRun)
agent.name = "test_agent"
return agent
@@ -1853,13 +1853,13 @@ class TestChatAgentChatMiddleware:
# class TestMiddlewareWithProtocolOnlyAgent:
# """Test use_agent_middleware with agents implementing only AgentProtocol."""
# """Test use_agent_middleware with agents implementing only SupportsAgentRun."""
# async def test_middleware_with_protocol_only_agent(self) -> None:
# """Verify middleware works without BaseAgent inheritance for both run."""
# from collections.abc import AsyncIterable
# from agent_framework import AgentProtocol, AgentResponse, AgentResponseUpdate
# from agent_framework import SupportsAgentRun, AgentResponse, AgentResponseUpdate
# execution_order: list[str] = []
@@ -1873,7 +1873,7 @@ class TestChatAgentChatMiddleware:
# @use_agent_middleware
# class ProtocolOnlyAgent:
# """Minimal agent implementing only AgentProtocol, not inheriting from BaseAgent."""
# """Minimal agent implementing only SupportsAgentRun, not inheriting from BaseAgent."""
# def __init__(self):
# self.id = "protocol-only-agent"
@@ -1896,7 +1896,7 @@ class TestChatAgentChatMiddleware:
# return None
# agent = ProtocolOnlyAgent()
# assert isinstance(agent, AgentProtocol)
# assert isinstance(agent, SupportsAgentRun)
# # Test run (non-streaming)
# response = await agent.run("test message")
@@ -12,7 +12,6 @@ from opentelemetry.trace import StatusCode
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
AgentProtocol,
AgentResponse,
BaseChatClient,
ChatMessage,
@@ -20,6 +19,7 @@ from agent_framework import (
ChatResponseUpdate,
Content,
ResponseStream,
SupportsAgentRun,
UsageDetails,
prepend_agent_framework_to_user_agent,
tool,
@@ -473,7 +473,7 @@ def mock_chat_agent():
@pytest.mark.parametrize("enable_sensitive_data", [True, False], indirect=True)
async def test_agent_instrumentation_enabled(
mock_chat_agent: AgentProtocol, span_exporter: InMemorySpanExporter, enable_sensitive_data
mock_chat_agent: SupportsAgentRun, span_exporter: InMemorySpanExporter, enable_sensitive_data
):
"""Test that when agent diagnostics are enabled, telemetry is applied."""
@@ -499,7 +499,7 @@ async def test_agent_instrumentation_enabled(
@pytest.mark.parametrize("enable_sensitive_data", [True, False], indirect=True)
async def test_agent_streaming_response_with_diagnostics_enabled(
mock_chat_agent: AgentProtocol, span_exporter: InMemorySpanExporter, enable_sensitive_data
mock_chat_agent: SupportsAgentRun, span_exporter: InMemorySpanExporter, enable_sensitive_data
):
"""Test agent streaming telemetry through the agent telemetry mixin."""
agent = mock_chat_agent()
@@ -9,7 +9,6 @@ from typing_extensions import Never
from agent_framework import (
AgentExecutorRequest,
AgentProtocol,
AgentResponse,
AgentResponseUpdate,
AgentThread,
@@ -18,6 +17,7 @@ from agent_framework import (
Content,
Executor,
ResponseStream,
SupportsAgentRun,
UsageDetails,
WorkflowAgent,
WorkflowBuilder,
@@ -615,7 +615,7 @@ class TestWorkflowAgent:
async def test_agent_executor_output_response_false_filters_streaming_events(self):
"""Test that AgentExecutor with output_response=False does not surface streaming events."""
class MockAgent(AgentProtocol):
class MockAgent(SupportsAgentRun):
"""Mock agent for testing."""
def __init__(self, name: str, response_text: str) -> None:
@@ -705,7 +705,7 @@ class TestWorkflowAgent:
async def test_agent_executor_output_response_no_duplicate_from_workflow_output_event(self):
"""Test that AgentExecutor with output_response=True does not duplicate content."""
class MockAgent(AgentProtocol):
class MockAgent(SupportsAgentRun):
"""Mock agent for testing."""
def __init__(self, name: str, response_text: str) -> None:
@@ -422,7 +422,7 @@ def test_mixing_eager_and_lazy_initialization_error():
ValueError,
match=(
r"Both source and target must be either registered factory names \(str\) "
r"or Executor/AgentProtocol instances\."
r"or Executor/SupportsAgentRun instances\."
),
):
builder.add_edge(eager_executor, "Lazy")
@@ -17,8 +17,8 @@ from typing import Any, cast
import yaml
from agent_framework import (
AgentExecutor,
AgentProtocol,
CheckpointStorage,
SupportsAgentRun,
Workflow,
get_logger,
)
@@ -78,13 +78,13 @@ class WorkflowFactory:
workflow = factory.create_workflow_from_yaml_path("workflow.yaml")
"""
_agents: dict[str, AgentProtocol | AgentExecutor]
_agents: dict[str, SupportsAgentRun | AgentExecutor]
def __init__(
self,
*,
agent_factory: AgentFactory | None = None,
agents: Mapping[str, AgentProtocol | AgentExecutor] | None = None,
agents: Mapping[str, SupportsAgentRun | AgentExecutor] | None = None,
bindings: Mapping[str, Any] | None = None,
env_file: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
@@ -132,7 +132,7 @@ class WorkflowFactory:
)
"""
self._agent_factory = agent_factory or AgentFactory(env_file_path=env_file)
self._agents: dict[str, AgentProtocol | AgentExecutor] = dict(agents) if agents else {}
self._agents: dict[str, SupportsAgentRun | AgentExecutor] = dict(agents) if agents else {}
self._bindings: dict[str, Any] = dict(bindings) if bindings else {}
self._checkpoint_storage = checkpoint_storage
@@ -323,7 +323,7 @@ class WorkflowFactory:
description = workflow_def.get("description")
# Create agents from definitions
agents: dict[str, AgentProtocol | AgentExecutor] = dict(self._agents)
agents: dict[str, SupportsAgentRun | AgentExecutor] = dict(self._agents)
agent_defs = workflow_def.get("agents", {})
for agent_name, agent_def in agent_defs.items():
@@ -347,7 +347,7 @@ class WorkflowFactory:
workflow_def: dict[str, Any],
name: str,
description: str | None,
agents: dict[str, AgentProtocol | AgentExecutor],
agents: dict[str, SupportsAgentRun | AgentExecutor],
) -> Workflow:
"""Create workflow from definition.
@@ -506,7 +506,7 @@ class WorkflowFactory:
f"Invalid agent definition. Expected 'file', 'kind', or 'connection': {agent_def}"
)
def register_agent(self, name: str, agent: AgentProtocol | AgentExecutor) -> "WorkflowFactory":
def register_agent(self, name: str, agent: SupportsAgentRun | AgentExecutor) -> "WorkflowFactory":
"""Register an agent instance with the factory for use in workflows.
Registered agents are available to InvokeAzureAgent actions by name.
@@ -757,11 +757,11 @@ class EntityDiscovery:
True if object appears to be a valid agent
"""
try:
# Try to import AgentProtocol for proper type checking
# Try to import SupportsAgentRun for proper type checking
try:
from agent_framework import AgentProtocol
from agent_framework import SupportsAgentRun
if isinstance(obj, AgentProtocol):
if isinstance(obj, SupportsAgentRun):
return True
except ImportError:
pass
@@ -7,7 +7,7 @@ import logging
from collections.abc import AsyncGenerator
from typing import Any
from agent_framework import AgentProtocol, Content, Workflow
from agent_framework import Content, SupportsAgentRun, Workflow
from ._conversations import ConversationStore, InMemoryConversationStore
from ._discovery import EntityDiscovery
@@ -285,7 +285,7 @@ class AgentFrameworkExecutor:
yield {"type": "error", "message": str(e), "entity_id": entity_id}
async def _execute_agent(
self, agent: AgentProtocol, request: AgentFrameworkRequest, trace_collector: Any
self, agent: SupportsAgentRun, request: AgentFrameworkRequest, trace_collector: Any
) -> AsyncGenerator[Any, None]:
"""Execute Agent Framework agent with trace collection and optional thread support.
@@ -9,12 +9,12 @@ from datetime import datetime, timezone
from typing import Any, cast
from agent_framework import (
AgentProtocol,
AgentResponse,
AgentResponseUpdate,
ChatMessage,
Content,
ResponseStream,
SupportsAgentRun,
get_logger,
)
from durabletask.entities import DurableEntity
@@ -86,12 +86,12 @@ class AgentEntity:
This class encapsulates the core logic for executing an agent within a durable entity context.
"""
agent: AgentProtocol
agent: SupportsAgentRun
callback: AgentResponseCallbackProtocol | None
def __init__(
self,
agent: AgentProtocol,
agent: SupportsAgentRun,
callback: AgentResponseCallbackProtocol | None = None,
*,
state_provider: AgentEntityStateProviderMixin,
@@ -2,7 +2,7 @@
"""Durable Agent Shim for Durable Task Framework.
This module provides the DurableAIAgent shim that implements AgentProtocol
This module provides the DurableAIAgent shim that implements SupportsAgentRun
and provides a consistent interface for both Client and Orchestration contexts.
The actual execution is delegated to the context-specific providers.
"""
@@ -12,7 +12,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Generic, Literal, TypeVar
from agent_framework import AgentProtocol, AgentThread, ChatMessage
from agent_framework import AgentThread, ChatMessage, SupportsAgentRun
from ._executors import DurableAgentExecutor
from ._models import DurableAgentThread
@@ -47,11 +47,11 @@ class DurableAgentProvider(ABC, Generic[TaskT]):
raise NotImplementedError("Subclasses must implement get_agent()")
class DurableAIAgent(AgentProtocol, Generic[TaskT]):
class DurableAIAgent(SupportsAgentRun, Generic[TaskT]):
"""A durable agent proxy that delegates execution to the provider.
This class implements AgentProtocol but with one critical difference:
- AgentProtocol.run() returns a Coroutine (async, must await)
This class implements SupportsAgentRun but with one critical difference:
- SupportsAgentRun.run() returns a Coroutine (async, must await)
- DurableAIAgent.run() returns TaskT (sync Task object - must yield
or the AgentResponse directly in the case of TaskHubGrpcClient)
@@ -104,8 +104,8 @@ class DurableAIAgent(AgentProtocol, Generic[TaskT]):
Additional keys are forwarded to the agent execution.
Note:
This method overrides AgentProtocol.run() with a different return type:
- AgentProtocol.run() returns Coroutine[Any, Any, AgentResponse] (async)
This method overrides SupportsAgentRun.run() with a different return type:
- SupportsAgentRun.run() returns Coroutine[Any, Any, AgentResponse] (async)
- DurableAIAgent.run() returns TaskT (Task object for yielding)
This is intentional to support orchestration contexts that use yield patterns
@@ -11,7 +11,7 @@ from __future__ import annotations
import asyncio
from typing import Any
from agent_framework import AgentProtocol, get_logger
from agent_framework import SupportsAgentRun, get_logger
from durabletask.worker import TaskHubGrpcWorker
from ._callbacks import AgentResponseCallbackProtocol
@@ -60,12 +60,12 @@ class DurableAIAgentWorker:
"""
self._worker = worker
self._callback = callback
self._registered_agents: dict[str, AgentProtocol] = {}
self._registered_agents: dict[str, SupportsAgentRun] = {}
logger.debug("[DurableAIAgentWorker] Initialized with worker type: %s", type(worker).__name__)
def add_agent(
self,
agent: AgentProtocol,
agent: SupportsAgentRun,
callback: AgentResponseCallbackProtocol | None = None,
) -> None:
"""Register an agent with the worker.
@@ -139,7 +139,7 @@ class DurableAIAgentWorker:
def __create_agent_entity(
self,
agent: AgentProtocol,
agent: SupportsAgentRun,
callback: AgentResponseCallbackProtocol | None = None,
) -> type[DurableTaskEntityStateProvider]:
"""Factory function to create a DurableEntity class configured with an agent.
@@ -9,7 +9,7 @@ Run with: pytest tests/test_client.py -v
from unittest.mock import Mock
import pytest
from agent_framework import AgentProtocol
from agent_framework import SupportsAgentRun
from agent_framework_durabletask import DurableAgentThread, DurableAIAgentClient
from agent_framework_durabletask._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS
@@ -46,7 +46,7 @@ class TestDurableAIAgentClientGetAgent:
agent = agent_client.get_agent("assistant")
assert isinstance(agent, DurableAIAgent)
assert isinstance(agent, AgentProtocol)
assert isinstance(agent, SupportsAgentRun)
def test_get_agent_shim_has_correct_name(self, agent_client: DurableAIAgentClient) -> None:
"""Verify retrieved agent has the correct name."""
@@ -9,7 +9,7 @@ Run with: pytest tests/test_orchestration_context.py -v
from unittest.mock import Mock
import pytest
from agent_framework import AgentProtocol
from agent_framework import SupportsAgentRun
from agent_framework_durabletask import DurableAgentThread
from agent_framework_durabletask._orchestration_context import DurableAIAgentOrchestrationContext
@@ -36,7 +36,7 @@ class TestDurableAIAgentOrchestrationContextGetAgent:
agent = agent_context.get_agent("assistant")
assert isinstance(agent, DurableAIAgent)
assert isinstance(agent, AgentProtocol)
assert isinstance(agent, SupportsAgentRun)
def test_get_agent_shim_has_correct_name(self, agent_context: DurableAIAgentOrchestrationContext) -> None:
"""Verify retrieved agent has the correct name."""
@@ -10,7 +10,7 @@ from typing import Any
from unittest.mock import Mock
import pytest
from agent_framework import AgentProtocol, ChatMessage
from agent_framework import ChatMessage, SupportsAgentRun
from pydantic import BaseModel
from agent_framework_durabletask import DurableAgentThread
@@ -142,15 +142,15 @@ class TestDurableAIAgentParameterFlow:
assert kwargs["run_request"].response_format == ResponseFormatModel
class TestDurableAIAgentProtocolCompliance:
"""Test that DurableAIAgent implements AgentProtocol correctly."""
class TestDurableAISupportsAgentRunCompliance:
"""Test that DurableAIAgent implements SupportsAgentRun correctly."""
def test_agent_implements_protocol(self, test_agent: DurableAIAgent[Any]) -> None:
"""Verify DurableAIAgent implements AgentProtocol."""
assert isinstance(test_agent, AgentProtocol)
"""Verify DurableAIAgent implements SupportsAgentRun."""
assert isinstance(test_agent, SupportsAgentRun)
def test_agent_has_required_properties(self, test_agent: DurableAIAgent[Any]) -> None:
"""Verify DurableAIAgent has all required AgentProtocol properties."""
"""Verify DurableAIAgent has all required SupportsAgentRun properties."""
assert hasattr(test_agent, "id")
assert hasattr(test_agent, "name")
assert hasattr(test_agent, "display_name")
@@ -6,7 +6,7 @@ import logging
from collections.abc import Callable, Sequence
from typing import Any
from agent_framework import AgentProtocol, ChatMessage
from agent_framework import ChatMessage, SupportsAgentRun
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
from agent_framework._workflows._agent_utils import resolve_agent_id
from agent_framework._workflows._checkpoint import CheckpointStorage
@@ -29,8 +29,8 @@ parallel workflow with:
- a default aggregator that combines all agent conversations and completes the workflow
Notes:
- Participants can be provided as AgentProtocol or Executor instances via `.participants()`,
or as factories returning AgentProtocol or Executor via `.register_participants()`.
- Participants can be provided as SupportsAgentRun or Executor instances via `.participants()`,
or as factories returning SupportsAgentRun or Executor via `.register_participants()`.
- A custom aggregator can be provided as:
- an Executor instance (it should handle list[AgentExecutorResponse],
yield output), or
@@ -186,8 +186,8 @@ class _CallbackAggregator(Executor):
class ConcurrentBuilder:
r"""High-level builder for concurrent agent workflows.
- `participants([...])` accepts a list of AgentProtocol (recommended) or Executor.
- `register_participants([...])` accepts a list of factories for AgentProtocol (recommended)
- `participants([...])` accepts a list of SupportsAgentRun (recommended) or Executor.
- `register_participants([...])` accepts a list of factories for SupportsAgentRun (recommended)
or Executor factories
- `build()` wires: dispatcher -> fan-out -> participants -> fan-in -> aggregator.
- `with_aggregator(...)` overrides the default aggregator with an Executor or callback.
@@ -238,8 +238,8 @@ class ConcurrentBuilder:
"""
def __init__(self) -> None:
self._participants: list[AgentProtocol | Executor] = []
self._participant_factories: list[Callable[[], AgentProtocol | Executor]] = []
self._participants: list[SupportsAgentRun | Executor] = []
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
self._aggregator: Executor | None = None
self._aggregator_factory: Callable[[], Executor] | None = None
self._checkpoint_storage: CheckpointStorage | None = None
@@ -249,16 +249,16 @@ class ConcurrentBuilder:
def register_participants(
self,
participant_factories: Sequence[Callable[[], AgentProtocol | Executor]],
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
) -> "ConcurrentBuilder":
r"""Define the parallel participants for this concurrent workflow.
Accepts factories (callables) that return AgentProtocol instances (e.g., created
Accepts factories (callables) that return SupportsAgentRun instances (e.g., created
by a chat client) or Executor instances. Each participant created by a factory
is wired as a parallel branch using fan-out edges from an internal dispatcher.
Args:
participant_factories: Sequence of callables returning AgentProtocol or Executor instances
participant_factories: Sequence of callables returning SupportsAgentRun or Executor instances
Raises:
ValueError: if `participant_factories` is empty or `.participants()`
@@ -300,20 +300,20 @@ class ConcurrentBuilder:
self._participant_factories = list(participant_factories)
return self
def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "ConcurrentBuilder":
def participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> "ConcurrentBuilder":
r"""Define the parallel participants for this concurrent workflow.
Accepts AgentProtocol instances (e.g., created by a chat client) or Executor
Accepts SupportsAgentRun instances (e.g., created by a chat client) or Executor
instances. Each participant is wired as a parallel branch using fan-out edges
from an internal dispatcher.
Args:
participants: Sequence of AgentProtocol or Executor instances
participants: Sequence of SupportsAgentRun or Executor instances
Raises:
ValueError: if `participants` is empty, contains duplicates, or `.register_participants()`
or `.participants()` were already called
TypeError: if any entry is not AgentProtocol or Executor
TypeError: if any entry is not SupportsAgentRun or Executor
Example:
@@ -341,13 +341,13 @@ class ConcurrentBuilder:
if p.id in seen_executor_ids:
raise ValueError(f"Duplicate executor participant detected: id '{p.id}'")
seen_executor_ids.add(p.id)
elif isinstance(p, AgentProtocol):
elif isinstance(p, SupportsAgentRun):
pid = id(p)
if pid in seen_agent_ids:
raise ValueError("Duplicate agent participant detected (same agent instance provided twice)")
seen_agent_ids.add(pid)
else:
raise TypeError(f"participants must be AgentProtocol or Executor instances; got {type(p).__name__}")
raise TypeError(f"participants must be SupportsAgentRun or Executor instances; got {type(p).__name__}")
self._participants = list(participants)
return self
@@ -459,7 +459,7 @@ class ConcurrentBuilder:
def with_request_info(
self,
*,
agents: Sequence[str | AgentProtocol] | None = None,
agents: Sequence[str | SupportsAgentRun] | None = None,
) -> "ConcurrentBuilder":
"""Enable request info after agent participant responses.
@@ -508,7 +508,7 @@ class ConcurrentBuilder:
raise ValueError("No participants provided. Call .participants() or .register_participants() first.")
# We don't need to check if both are set since that is handled in the respective methods
participants: list[Executor | AgentProtocol] = []
participants: list[Executor | SupportsAgentRun] = []
if self._participant_factories:
# Resolve the participant factories now. This doesn't break the factory pattern
# since the Sequential builder still creates new instances per workflow build.
@@ -522,7 +522,7 @@ class ConcurrentBuilder:
for p in participants:
if isinstance(p, Executor):
executors.append(p)
elif isinstance(p, AgentProtocol):
elif isinstance(p, SupportsAgentRun):
if self._request_info_enabled and (
not self._request_info_filter or resolve_agent_id(p) in self._request_info_filter
):
@@ -531,7 +531,7 @@ class ConcurrentBuilder:
else:
executors.append(AgentExecutor(p))
else:
raise TypeError(f"Participants must be AgentProtocol or Executor instances. Got {type(p).__name__}.")
raise TypeError(f"Participants must be SupportsAgentRun or Executor instances. Got {type(p).__name__}.")
return executors
@@ -26,7 +26,7 @@ from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass
from typing import Any, ClassVar, cast, overload
from agent_framework import AgentProtocol, ChatAgent
from agent_framework import ChatAgent, SupportsAgentRun
from agent_framework._threads import AgentThread
from agent_framework._types import ChatMessage
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
@@ -523,8 +523,8 @@ class GroupChatBuilder:
def __init__(self) -> None:
"""Initialize the GroupChatBuilder."""
self._participants: dict[str, AgentProtocol | Executor] = {}
self._participant_factories: list[Callable[[], AgentProtocol | Executor]] = []
self._participants: dict[str, SupportsAgentRun | Executor] = {}
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
# Orchestrator related members
self._orchestrator: BaseGroupChatOrchestrator | None = None
@@ -683,13 +683,13 @@ class GroupChatBuilder:
def register_participants(
self,
participant_factories: Sequence[Callable[[], AgentProtocol | Executor]],
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
) -> "GroupChatBuilder":
"""Register participant factories for this group chat workflow.
Args:
participant_factories: Sequence of callables that produce participant definitions
when invoked. Each callable should return either an AgentProtocol instance
when invoked. Each callable should return either an SupportsAgentRun instance
(auto-wrapped as AgentExecutor) or an Executor instance.
Returns:
@@ -711,10 +711,10 @@ class GroupChatBuilder:
self._participant_factories = list(participant_factories)
return self
def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "GroupChatBuilder":
def participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> "GroupChatBuilder":
"""Define participants for this group chat workflow.
Accepts AgentProtocol instances (auto-wrapped as AgentExecutor) or Executor instances.
Accepts SupportsAgentRun instances (auto-wrapped as AgentExecutor) or Executor instances.
Args:
participants: Sequence of participant definitions
@@ -725,7 +725,7 @@ class GroupChatBuilder:
Raises:
ValueError: If participants are empty, names are duplicated, or participants
or participant factories are already set
TypeError: If any participant is not AgentProtocol or Executor instance
TypeError: If any participant is not SupportsAgentRun or Executor instance
Example:
@@ -750,17 +750,17 @@ class GroupChatBuilder:
raise ValueError("participants cannot be empty.")
# Name of the executor mapped to participant instance
named: dict[str, AgentProtocol | Executor] = {}
named: dict[str, SupportsAgentRun | Executor] = {}
for participant in participants:
if isinstance(participant, Executor):
identifier = participant.id
elif isinstance(participant, AgentProtocol):
elif isinstance(participant, SupportsAgentRun):
if not participant.name:
raise ValueError("AgentProtocol participants must have a non-empty name.")
raise ValueError("SupportsAgentRun participants must have a non-empty name.")
identifier = participant.name
else:
raise TypeError(
f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}."
f"Participants must be SupportsAgentRun or Executor instances. Got {type(participant).__name__}."
)
if identifier in named:
@@ -861,7 +861,7 @@ class GroupChatBuilder:
self._checkpoint_storage = checkpoint_storage
return self
def with_request_info(self, *, agents: Sequence[str | AgentProtocol] | None = None) -> "GroupChatBuilder":
def with_request_info(self, *, agents: Sequence[str | SupportsAgentRun] | None = None) -> "GroupChatBuilder":
"""Enable request info after agent participant responses.
This enables human-in-the-loop (HIL) scenarios for the group chat orchestration.
@@ -962,7 +962,7 @@ class GroupChatBuilder:
raise ValueError("No participants provided. Call .participants() or .register_participants() first.")
# We don't need to check if both are set since that is handled in the respective methods
participants: list[Executor | AgentProtocol] = []
participants: list[Executor | SupportsAgentRun] = []
if self._participant_factories:
for factory in self._participant_factories:
participant = factory()
@@ -974,7 +974,7 @@ class GroupChatBuilder:
for participant in participants:
if isinstance(participant, Executor):
executors.append(participant)
elif isinstance(participant, AgentProtocol):
elif isinstance(participant, SupportsAgentRun):
if self._request_info_enabled and (
not self._request_info_filter or resolve_agent_id(participant) in self._request_info_filter
):
@@ -984,7 +984,7 @@ class GroupChatBuilder:
executors.append(AgentExecutor(participant))
else:
raise TypeError(
f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}."
f"Participants must be SupportsAgentRun or Executor instances. Got {type(participant).__name__}."
)
return executors
@@ -36,7 +36,7 @@ from collections.abc import Awaitable, Callable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any, cast
from agent_framework import AgentProtocol, ChatAgent
from agent_framework import ChatAgent, SupportsAgentRun
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware
from agent_framework._threads import AgentThread
from agent_framework._tools import FunctionTool, tool
@@ -89,14 +89,14 @@ class HandoffConfiguration:
target_id: str
description: str | None = None
def __init__(self, *, target: str | AgentProtocol, description: str | None = None) -> None:
def __init__(self, *, target: str | SupportsAgentRun, description: str | None = None) -> None:
"""Initialize HandoffConfiguration.
Args:
target: Target agent identifier or AgentProtocol instance
target: Target agent identifier or SupportsAgentRun instance
description: Optional human-readable description of the handoff
"""
self.target_id = resolve_agent_id(target) if isinstance(target, AgentProtocol) else target
self.target_id = resolve_agent_id(target) if isinstance(target, SupportsAgentRun) else target
self.description = description
def __eq__(self, other: Any) -> bool:
@@ -193,7 +193,7 @@ class HandoffAgentExecutor(AgentExecutor):
def __init__(
self,
agent: AgentProtocol,
agent: SupportsAgentRun,
handoffs: Sequence[HandoffConfiguration],
*,
agent_thread: AgentThread | None = None,
@@ -236,9 +236,9 @@ class HandoffAgentExecutor(AgentExecutor):
def _prepare_agent_with_handoffs(
self,
agent: AgentProtocol,
agent: SupportsAgentRun,
handoffs: Sequence[HandoffConfiguration],
) -> AgentProtocol:
) -> SupportsAgentRun:
"""Prepare an agent by adding handoff tools for the specified target agents.
Args:
@@ -574,8 +574,8 @@ class HandoffBuilder:
self,
*,
name: str | None = None,
participants: Sequence[AgentProtocol] | None = None,
participant_factories: Mapping[str, Callable[[], AgentProtocol]] | None = None,
participants: Sequence[SupportsAgentRun] | None = None,
participant_factories: Mapping[str, Callable[[], SupportsAgentRun]] | None = None,
description: str | None = None,
) -> None:
r"""Initialize a HandoffBuilder for creating conversational handoff workflows.
@@ -604,8 +604,8 @@ class HandoffBuilder:
self._description = description
# Participant related members
self._participants: dict[str, AgentProtocol] = {}
self._participant_factories: dict[str, Callable[[], AgentProtocol]] = {}
self._participants: dict[str, SupportsAgentRun] = {}
self._participant_factories: dict[str, Callable[[], SupportsAgentRun]] = {}
self._start_id: str | None = None
if participant_factories:
self.register_participants(participant_factories)
@@ -629,16 +629,16 @@ class HandoffBuilder:
self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None
def register_participants(
self, participant_factories: Mapping[str, Callable[[], AgentProtocol]]
self, participant_factories: Mapping[str, Callable[[], SupportsAgentRun]]
) -> "HandoffBuilder":
"""Register factories that produce agents for the handoff workflow.
Each factory is a callable that returns an AgentProtocol instance.
Each factory is a callable that returns an SupportsAgentRun instance.
Factories are invoked when building the workflow, allowing for lazy instantiation
and state isolation per workflow instance.
Args:
participant_factories: Mapping of factory names to callables that return AgentProtocol
participant_factories: Mapping of factory names to callables that return SupportsAgentRun
instances. Each produced participant must have a unique identifier
(`.name` is preferred if set, otherwise `.id` is used).
@@ -690,11 +690,11 @@ class HandoffBuilder:
self._participant_factories = dict(participant_factories)
return self
def participants(self, participants: Sequence[AgentProtocol]) -> "HandoffBuilder":
def participants(self, participants: Sequence[SupportsAgentRun]) -> "HandoffBuilder":
"""Register the agents that will participate in the handoff workflow.
Args:
participants: Sequence of AgentProtocol instances. Each must have a unique identifier.
participants: Sequence of SupportsAgentRun instances. Each must have a unique identifier.
(`.name` is preferred if set, otherwise `.id` is used).
Returns:
@@ -703,7 +703,7 @@ class HandoffBuilder:
Raises:
ValueError: If participants is empty, contains duplicates, or `.participants()` or
`.register_participants()` has already been called.
TypeError: If participants are not AgentProtocol instances.
TypeError: If participants are not SupportsAgentRun instances.
Example:
@@ -729,13 +729,13 @@ class HandoffBuilder:
if not participants:
raise ValueError("participants cannot be empty")
named: dict[str, AgentProtocol] = {}
named: dict[str, SupportsAgentRun] = {}
for participant in participants:
if isinstance(participant, AgentProtocol):
if isinstance(participant, SupportsAgentRun):
resolved_id = self._resolve_to_id(participant)
else:
raise TypeError(
f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}."
f"Participants must be SupportsAgentRun or Executor instances. Got {type(participant).__name__}."
)
if resolved_id in named:
@@ -748,8 +748,8 @@ class HandoffBuilder:
def add_handoff(
self,
source: str | AgentProtocol,
targets: Sequence[str] | Sequence[AgentProtocol],
source: str | SupportsAgentRun,
targets: Sequence[str] | Sequence[SupportsAgentRun],
*,
description: str | None = None,
) -> "HandoffBuilder":
@@ -763,11 +763,11 @@ class HandoffBuilder:
Args:
source: The agent that can initiate the handoff. Can be:
- Factory name (str): If using participant factories
- AgentProtocol instance: The actual agent object
- SupportsAgentRun instance: The actual agent object
- Cannot mix factory names and instances across source and targets
targets: One or more target agents that the source can hand off to. Can be:
- Factory name (str): If using participant factories
- AgentProtocol instance: The actual agent object
- SupportsAgentRun instance: The actual agent object
- Single target: ["billing_agent"] or [agent_instance]
- Multiple targets: ["billing_agent", "support_agent"] or [agent1, agent2]
- Cannot mix factory names and instances across source and targets
@@ -786,7 +786,7 @@ class HandoffBuilder:
participants(...) hasn't been called yet.
2) If source or targets are factory names (str) but participant_factories(...)
hasn't been called yet, or if they are not in the participant_factories list.
TypeError: If mixing factory names (str) and AgentProtocol/Executor instances
TypeError: If mixing factory names (str) and SupportsAgentRun/Executor instances
Examples:
Single target (using factory name):
@@ -848,7 +848,7 @@ class HandoffBuilder:
self._handoff_config[source].add(HandoffConfiguration(target=t, description=description))
return self
if isinstance(source, (AgentProtocol)) and all(isinstance(t, AgentProtocol) for t in targets):
if isinstance(source, (SupportsAgentRun)) and all(isinstance(t, SupportsAgentRun) for t in targets):
# Both source and targets are instances
if not self._participants:
raise ValueError("Call participants(...) before add_handoff(...)")
@@ -881,10 +881,10 @@ class HandoffBuilder:
return self
raise TypeError(
"Cannot mix factory names (str) and AgentProtocol instances across source and targets in add_handoff()"
"Cannot mix factory names (str) and SupportsAgentRun instances across source and targets in add_handoff()"
)
def with_start_agent(self, agent: str | AgentProtocol) -> "HandoffBuilder":
def with_start_agent(self, agent: str | SupportsAgentRun) -> "HandoffBuilder":
"""Set the agent that will initiate the handoff workflow.
If not specified, the first registered participant will be used as the starting agent.
@@ -892,7 +892,7 @@ class HandoffBuilder:
Args:
agent: The agent that will start the workflow. Can be:
- Factory name (str): If using participant factories
- AgentProtocol instance: The actual agent object
- SupportsAgentRun instance: The actual agent object
Returns:
Self for method chaining.
"""
@@ -903,7 +903,7 @@ class HandoffBuilder:
else:
raise ValueError("Call register_participants(...) before with_start_agent(...)")
self._start_id = agent
elif isinstance(agent, AgentProtocol):
elif isinstance(agent, SupportsAgentRun):
resolved_id = self._resolve_to_id(agent)
if self._participants:
if resolved_id not in self._participants:
@@ -912,14 +912,14 @@ class HandoffBuilder:
raise ValueError("Call participants(...) before with_start_agent(...)")
self._start_id = resolved_id
else:
raise TypeError("Start agent must be a factory name (str) or an AgentProtocol instance")
raise TypeError("Start agent must be a factory name (str) or an SupportsAgentRun instance")
return self
def with_autonomous_mode(
self,
*,
agents: Sequence[AgentProtocol] | Sequence[str] | None = None,
agents: Sequence[SupportsAgentRun] | Sequence[str] | None = None,
prompts: dict[str, str] | None = None,
turn_limits: dict[str, int] | None = None,
) -> "HandoffBuilder":
@@ -933,7 +933,7 @@ class HandoffBuilder:
Args:
agents: Optional list of agents to enable autonomous mode for. Can be:
- Factory names (str): If using participant factories
- AgentProtocol instances: The actual agent objects
- SupportsAgentRun instances: The actual agent objects
- If not provided, all agents will operate in autonomous mode.
prompts: Optional mapping of agent identifiers/factory names to custom prompts to use when continuing
in autonomous mode. If not provided, a default prompt will be used.
@@ -1084,7 +1084,7 @@ class HandoffBuilder:
# region Internal Helper Methods
def _resolve_agents(self) -> dict[str, AgentProtocol]:
def _resolve_agents(self) -> dict[str, SupportsAgentRun]:
"""Resolve participant factories into agent instances.
If agent instances were provided directly via participants(...), those are
@@ -1092,7 +1092,7 @@ class HandoffBuilder:
those are invoked to create the agent instances.
Returns:
Map of executor IDs or factory names to `AgentProtocol` instances
Map of executor IDs or factory names to `SupportsAgentRun` instances
"""
if not self._participants and not self._participant_factories:
raise ValueError("No participants provided. Call .participants() or .register_participants() first.")
@@ -1103,13 +1103,13 @@ class HandoffBuilder:
if self._participant_factories:
# Invoke each factory to create participant instances
factory_names_to_agents: dict[str, AgentProtocol] = {}
factory_names_to_agents: dict[str, SupportsAgentRun] = {}
for factory_name, factory in self._participant_factories.items():
instance = factory()
if isinstance(instance, AgentProtocol):
if isinstance(instance, SupportsAgentRun):
resolved_id = self._resolve_to_id(instance)
else:
raise TypeError(f"Participants must be AgentProtocol instances. Got {type(instance).__name__}.")
raise TypeError(f"Participants must be SupportsAgentRun instances. Got {type(instance).__name__}.")
if resolved_id in factory_names_to_agents:
raise ValueError(f"Duplicate participant name '{resolved_id}' detected")
@@ -1122,11 +1122,11 @@ class HandoffBuilder:
raise ValueError("No executors or participant_factories have been configured")
def _resolve_handoffs(self, agents: Mapping[str, AgentProtocol]) -> dict[str, list[HandoffConfiguration]]:
def _resolve_handoffs(self, agents: Mapping[str, SupportsAgentRun]) -> dict[str, list[HandoffConfiguration]]:
"""Handoffs may be specified using factory names or instances; resolve to executor IDs.
Args:
agents: Map of agent IDs or factory names to `AgentProtocol` instances
agents: Map of agent IDs or factory names to `SupportsAgentRun` instances
Returns:
Map of executor IDs to list of HandoffConfiguration instances
@@ -1173,13 +1173,13 @@ class HandoffBuilder:
def _resolve_executors(
self,
agents: dict[str, AgentProtocol],
agents: dict[str, SupportsAgentRun],
handoffs: dict[str, list[HandoffConfiguration]],
) -> dict[str, HandoffAgentExecutor]:
"""Resolve agents into HandoffAgentExecutors.
Args:
agents: Map of agent IDs or factory names to `AgentProtocol` instances
agents: Map of agent IDs or factory names to `SupportsAgentRun` instances
handoffs: Map of executor IDs to list of HandoffConfiguration instances
Returns:
@@ -1213,9 +1213,9 @@ class HandoffBuilder:
return executors
def _resolve_to_id(self, candidate: str | AgentProtocol) -> str:
def _resolve_to_id(self, candidate: str | SupportsAgentRun) -> str:
"""Resolve a participant reference into a concrete executor identifier."""
if isinstance(candidate, AgentProtocol):
if isinstance(candidate, SupportsAgentRun):
return resolve_agent_id(candidate)
if isinstance(candidate, str):
return candidate
@@ -13,9 +13,9 @@ from enum import Enum
from typing import Any, ClassVar, TypeVar, cast, overload
from agent_framework import (
AgentProtocol,
AgentResponse,
ChatMessage,
SupportsAgentRun,
)
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
from agent_framework._workflows._checkpoint import CheckpointStorage
@@ -521,7 +521,7 @@ class StandardMagenticManager(MagenticManagerBase):
def __init__(
self,
agent: AgentProtocol,
agent: SupportsAgentRun,
task_ledger: _MagenticTaskLedger | None = None,
*,
task_ledger_facts_prompt: str | None = None,
@@ -562,7 +562,7 @@ class StandardMagenticManager(MagenticManagerBase):
max_round_count=max_round_count,
)
self._agent: AgentProtocol = agent
self._agent: SupportsAgentRun = agent
self.task_ledger: _MagenticTaskLedger | None = task_ledger
# Prompts may be overridden if needed
@@ -1311,10 +1311,10 @@ class MagenticOrchestrator(BaseGroupChatOrchestrator):
class MagenticAgentExecutor(AgentExecutor):
"""Specialized AgentExecutor for Magentic agent participants."""
def __init__(self, agent: AgentProtocol) -> None:
def __init__(self, agent: SupportsAgentRun) -> None:
"""Initialize a Magentic Agent Executor.
This executor wraps an AgentProtocol instance to be used as a participant
This executor wraps an SupportsAgentRun instance to be used as a participant
in a Magentic One workflow.
Args:
@@ -1377,13 +1377,13 @@ class MagenticBuilder:
def __init__(self) -> None:
"""Initialize the Magentic workflow builder."""
self._participants: dict[str, AgentProtocol | Executor] = {}
self._participant_factories: list[Callable[[], AgentProtocol | Executor]] = []
self._participants: dict[str, SupportsAgentRun | Executor] = {}
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
# Manager related members
self._manager: MagenticManagerBase | None = None
self._manager_factory: Callable[[], MagenticManagerBase] | None = None
self._manager_agent_factory: Callable[[], AgentProtocol] | None = None
self._manager_agent_factory: Callable[[], SupportsAgentRun] | None = None
self._standard_manager_options: dict[str, Any] = {}
self._enable_plan_review: bool = False
@@ -1394,12 +1394,12 @@ class MagenticBuilder:
def register_participants(
self,
participant_factories: Sequence[Callable[[], AgentProtocol | Executor]],
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
) -> "MagenticBuilder":
"""Register participant factories for this Magentic workflow.
Args:
participant_factories: Sequence of callables that return AgentProtocol or Executor instances.
participant_factories: Sequence of callables that return SupportsAgentRun or Executor instances.
Returns:
Self for method chaining
@@ -1420,10 +1420,10 @@ class MagenticBuilder:
self._participant_factories = list(participant_factories)
return self
def participants(self, participants: Sequence[AgentProtocol | Executor]) -> Self:
def participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> Self:
"""Define participants for this Magentic workflow.
Accepts AgentProtocol instances (auto-wrapped as AgentExecutor) or Executor instances.
Accepts SupportsAgentRun instances (auto-wrapped as AgentExecutor) or Executor instances.
Args:
participants: Sequence of participant definitions
@@ -1434,7 +1434,7 @@ class MagenticBuilder:
Raises:
ValueError: If participants are empty, names are duplicated, or participants
or participant factories are already set
TypeError: If any participant is not AgentProtocol or Executor instance
TypeError: If any participant is not SupportsAgentRun or Executor instance
Example:
@@ -1462,17 +1462,17 @@ class MagenticBuilder:
raise ValueError("participants cannot be empty.")
# Name of the executor mapped to participant instance
named: dict[str, AgentProtocol | Executor] = {}
named: dict[str, SupportsAgentRun | Executor] = {}
for participant in participants:
if isinstance(participant, Executor):
identifier = participant.id
elif isinstance(participant, AgentProtocol):
elif isinstance(participant, SupportsAgentRun):
if not participant.name:
raise ValueError("AgentProtocol participants must have a non-empty name.")
raise ValueError("SupportsAgentRun participants must have a non-empty name.")
identifier = participant.name
else:
raise TypeError(
f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}."
f"Participants must be SupportsAgentRun or Executor instances. Got {type(participant).__name__}."
)
if identifier in named:
@@ -1608,7 +1608,7 @@ class MagenticBuilder:
def with_manager(
self,
*,
agent: AgentProtocol,
agent: SupportsAgentRun,
task_ledger: _MagenticTaskLedger | None = None,
# Prompt overrides
task_ledger_facts_prompt: str | None = None,
@@ -1628,7 +1628,7 @@ class MagenticBuilder:
This will create a StandardMagenticManager using the provided agent.
Args:
agent: AgentProtocol instance for the standard magentic manager
agent: SupportsAgentRun instance for the standard magentic manager
(`StandardMagenticManager`)
task_ledger: Optional custom task ledger implementation for specialized
prompting or structured output requirements
@@ -1661,7 +1661,7 @@ class MagenticBuilder:
def with_manager(
self,
*,
agent_factory: Callable[[], AgentProtocol],
agent_factory: Callable[[], SupportsAgentRun],
task_ledger: _MagenticTaskLedger | None = None,
# Prompt overrides
task_ledger_facts_prompt: str | None = None,
@@ -1681,7 +1681,7 @@ class MagenticBuilder:
This will create a StandardMagenticManager using the provided agent factory.
Args:
agent_factory: Callable that returns a new AgentProtocol instance for the standard
agent_factory: Callable that returns a new SupportsAgentRun instance for the standard
magentic manager (`StandardMagenticManager`)
task_ledger: Optional custom task ledger implementation for specialized
prompting or structured output requirements
@@ -1715,9 +1715,9 @@ class MagenticBuilder:
*,
manager: MagenticManagerBase | None = None,
manager_factory: Callable[[], MagenticManagerBase] | None = None,
agent_factory: Callable[[], AgentProtocol] | None = None,
agent_factory: Callable[[], SupportsAgentRun] | None = None,
# Constructor args for StandardMagenticManager when manager is not provided
agent: AgentProtocol | None = None,
agent: SupportsAgentRun | None = None,
task_ledger: _MagenticTaskLedger | None = None,
# Prompt overrides
task_ledger_facts_prompt: str | None = None,
@@ -1956,7 +1956,7 @@ class MagenticBuilder:
raise ValueError("No participants provided. Call .participants() or .register_participants() first.")
# We don't need to check if both are set since that is handled in the respective methods
participants: list[Executor | AgentProtocol] = []
participants: list[Executor | SupportsAgentRun] = []
if self._participant_factories:
for factory in self._participant_factories:
participant = factory()
@@ -1968,11 +1968,11 @@ class MagenticBuilder:
for participant in participants:
if isinstance(participant, Executor):
executors.append(participant)
elif isinstance(participant, AgentProtocol):
elif isinstance(participant, SupportsAgentRun):
executors.append(MagenticAgentExecutor(participant))
else:
raise TypeError(
f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}."
f"Participants must be SupportsAgentRun or Executor instances. Got {type(participant).__name__}."
)
return executors
@@ -2,7 +2,7 @@
from dataclasses import dataclass
from agent_framework._agents import AgentProtocol
from agent_framework._agents import SupportsAgentRun
from agent_framework._types import ChatMessage
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
from agent_framework._workflows._agent_utils import resolve_agent_id
@@ -14,11 +14,11 @@ from agent_framework._workflows._workflow_context import WorkflowContext
from agent_framework._workflows._workflow_executor import WorkflowExecutor
def resolve_request_info_filter(agents: list[str | AgentProtocol] | None) -> set[str]:
def resolve_request_info_filter(agents: list[str | SupportsAgentRun] | None) -> set[str]:
"""Resolve a list of agent/executor references to a set of IDs for filtering.
Args:
agents: List of agent names (str), AgentProtocol instances, or Executor instances.
agents: List of agent names (str), SupportsAgentRun instances, or Executor instances.
If None, returns None (meaning no filtering - pause for all).
Returns:
@@ -31,7 +31,7 @@ def resolve_request_info_filter(agents: list[str | AgentProtocol] | None) -> set
for agent in agents:
if isinstance(agent, str):
result.add(agent)
elif isinstance(agent, AgentProtocol):
elif isinstance(agent, SupportsAgentRun):
result.add(resolve_agent_id(agent))
else:
raise TypeError(f"Unsupported type for request_info filter: {type(agent).__name__}")
@@ -117,7 +117,7 @@ class AgentApprovalExecutor(WorkflowExecutor):
agent's output or send the final response to down stream executors in the orchestration.
"""
def __init__(self, agent: AgentProtocol) -> None:
def __init__(self, agent: SupportsAgentRun) -> None:
"""Initialize the AgentApprovalExecutor.
Args:
@@ -126,7 +126,7 @@ class AgentApprovalExecutor(WorkflowExecutor):
super().__init__(workflow=self._build_workflow(agent), id=resolve_agent_id(agent), propagate_request=True)
self._description = agent.description
def _build_workflow(self, agent: AgentProtocol) -> Workflow:
def _build_workflow(self, agent: SupportsAgentRun) -> Workflow:
"""Build the internal workflow for the AgentApprovalExecutor."""
agent_executor = AgentExecutor(agent)
request_info_executor = AgentRequestInfoExecutor(id="agent_request_info_executor")
@@ -4,8 +4,8 @@
This module provides a high-level, agent-focused API to assemble a sequential
workflow where:
- Participants can be provided as AgentProtocol or Executor instances via `.participants()`,
or as factories returning AgentProtocol or Executor via `.register_participants()`
- Participants can be provided as SupportsAgentRun or Executor instances via `.participants()`,
or as factories returning SupportsAgentRun or Executor via `.register_participants()`
- A shared conversation context (list[ChatMessage]) is passed along the chain
- Agents append their assistant messages to the context
- Custom executors can transform or summarize and return a refined context
@@ -15,7 +15,7 @@ Typical wiring:
input -> _InputToConversation -> participant1 -> (agent? -> _ResponseToConversation) -> ... -> participantN -> _EndWithConversation
Notes:
- Participants can mix AgentProtocol and Executor objects
- Participants can mix SupportsAgentRun and Executor objects
- Agents are auto-wrapped by WorkflowBuilder as AgentExecutor (unless already wrapped)
- AgentExecutor produces AgentExecutorResponse; _ResponseToConversation converts this to list[ChatMessage]
- Non-agent executors must define a handler that consumes `list[ChatMessage]` and sends back
@@ -41,7 +41,7 @@ import logging
from collections.abc import Callable, Sequence
from typing import Any
from agent_framework import AgentProtocol, ChatMessage
from agent_framework import ChatMessage, SupportsAgentRun
from agent_framework._workflows._agent_executor import (
AgentExecutor,
AgentExecutorResponse,
@@ -109,8 +109,8 @@ class _EndWithConversation(Executor):
class SequentialBuilder:
r"""High-level builder for sequential agent/executor workflows with shared context.
- `participants([...])` accepts a list of AgentProtocol (recommended) or Executor instances
- `register_participants([...])` accepts a list of factories for AgentProtocol (recommended)
- `participants([...])` accepts a list of SupportsAgentRun (recommended) or Executor instances
- `register_participants([...])` accepts a list of factories for SupportsAgentRun (recommended)
or Executor factories
- Executors must define a handler that consumes list[ChatMessage] and sends out a list[ChatMessage]
- The workflow wires participants in order, passing a list[ChatMessage] down the chain
@@ -148,8 +148,8 @@ class SequentialBuilder:
"""
def __init__(self) -> None:
self._participants: list[AgentProtocol | Executor] = []
self._participant_factories: list[Callable[[], AgentProtocol | Executor]] = []
self._participants: list[SupportsAgentRun | Executor] = []
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
self._checkpoint_storage: CheckpointStorage | None = None
self._request_info_enabled: bool = False
self._request_info_filter: set[str] | None = None
@@ -157,7 +157,7 @@ class SequentialBuilder:
def register_participants(
self,
participant_factories: Sequence[Callable[[], AgentProtocol | Executor]],
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
) -> "SequentialBuilder":
"""Register participant factories for this sequential workflow."""
if self._participants:
@@ -172,10 +172,10 @@ class SequentialBuilder:
self._participant_factories = list(participant_factories)
return self
def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "SequentialBuilder":
def participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> "SequentialBuilder":
"""Define the ordered participants for this sequential workflow.
Accepts AgentProtocol instances (auto-wrapped as AgentExecutor) or Executor instances.
Accepts SupportsAgentRun instances (auto-wrapped as AgentExecutor) or Executor instances.
Raises if empty or duplicates are provided for clarity.
"""
if self._participant_factories:
@@ -196,7 +196,7 @@ class SequentialBuilder:
raise ValueError(f"Duplicate executor participant detected: id '{p.id}'")
seen_executor_ids.add(p.id)
else:
# Treat non-Executor as agent-like (AgentProtocol). Structural checks can be brittle at runtime.
# Treat non-Executor as agent-like (SupportsAgentRun). Structural checks can be brittle at runtime.
pid = id(p)
if pid in seen_agent_ids:
raise ValueError("Duplicate agent participant detected (same agent instance provided twice)")
@@ -213,7 +213,7 @@ class SequentialBuilder:
def with_request_info(
self,
*,
agents: Sequence[str | AgentProtocol] | None = None,
agents: Sequence[str | SupportsAgentRun] | None = None,
) -> "SequentialBuilder":
"""Enable request info after agent participant responses.
@@ -262,7 +262,7 @@ class SequentialBuilder:
raise ValueError("No participants provided. Call .participants() or .register_participants() first.")
# We don't need to check if both are set since that is handled in the respective methods
participants: list[Executor | AgentProtocol] = []
participants: list[Executor | SupportsAgentRun] = []
if self._participant_factories:
# Resolve the participant factories now. This doesn't break the factory pattern
# since the Sequential builder still creates new instances per workflow build.
@@ -276,7 +276,7 @@ class SequentialBuilder:
for p in participants:
if isinstance(p, Executor):
executors.append(p)
elif isinstance(p, AgentProtocol):
elif isinstance(p, SupportsAgentRun):
if self._request_info_enabled and (
not self._request_info_filter or resolve_agent_id(p) in self._request_info_filter
):
@@ -285,7 +285,7 @@ class SequentialBuilder:
else:
executors.append(AgentExecutor(p))
else:
raise TypeError(f"Participants must be AgentProtocol or Executor instances. Got {type(p).__name__}.")
raise TypeError(f"Participants must be SupportsAgentRun or Executor instances. Got {type(p).__name__}.")
return executors
@@ -312,7 +312,7 @@ class SequentialBuilder:
builder.set_start_executor(input_conv)
# Start of the chain is the input normalizer
prior: Executor | AgentProtocol = input_conv
prior: Executor | SupportsAgentRun = input_conv
for p in participants:
builder.add_edge(prior, p)
prior = p
@@ -320,7 +320,7 @@ class TestGroupChatBuilder:
builder = GroupChatBuilder().with_orchestrator(selection_func=selector)
with pytest.raises(ValueError, match="AgentProtocol participants must have a non-empty name"):
with pytest.raises(ValueError, match="SupportsAgentRun participants must have a non-empty name"):
builder.participants([agent])
def test_empty_participant_name_raises_error(self) -> None:
@@ -332,7 +332,7 @@ class TestGroupChatBuilder:
builder = GroupChatBuilder().with_orchestrator(selection_func=selector)
with pytest.raises(ValueError, match="AgentProtocol participants must have a non-empty name"):
with pytest.raises(ValueError, match="SupportsAgentRun participants must have a non-empty name"):
builder.participants([agent])
@@ -420,7 +420,7 @@ def test_handoff_builder_rejects_mixed_types_in_add_handoff_source():
triage = MockHandoffAgent(name="triage")
specialist = MockHandoffAgent(name="specialist")
with pytest.raises(TypeError, match="Cannot mix factory names \\(str\\) and AgentProtocol.*instances"):
with pytest.raises(TypeError, match="Cannot mix factory names \\(str\\) and SupportsAgentRun.*instances"):
(
HandoffBuilder(participants=[triage, specialist])
.with_start_agent(triage)
@@ -7,7 +7,6 @@ from typing import Any, ClassVar, cast
import pytest
from agent_framework import (
AgentProtocol,
AgentResponse,
AgentResponseUpdate,
AgentThread,
@@ -15,6 +14,7 @@ from agent_framework import (
ChatMessage,
Content,
Executor,
SupportsAgentRun,
Workflow,
WorkflowCheckpoint,
WorkflowCheckpointException,
@@ -576,7 +576,7 @@ class StubAssistantsAgent(BaseAgent):
)
async def _collect_agent_responses_setup(participant: AgentProtocol) -> list[ChatMessage]:
async def _collect_agent_responses_setup(participant: SupportsAgentRun) -> list[ChatMessage]:
captured: list[ChatMessage] = []
wf = (
@@ -1121,10 +1121,10 @@ async def test_magentic_with_agent_factory():
"""Test workflow creation using agent_factory for StandardMagenticManager."""
factory_call_count = 0
def agent_factory() -> AgentProtocol:
def agent_factory() -> SupportsAgentRun:
nonlocal factory_call_count
factory_call_count += 1
return cast(AgentProtocol, StubManagerAgent())
return cast(SupportsAgentRun, StubManagerAgent())
participant = StubAgent("agentA", "reply from agentA")
workflow = (
@@ -1239,10 +1239,10 @@ def test_magentic_agent_factory_with_standard_manager_options():
"""Test that agent_factory properly passes through standard manager options."""
factory_call_count = 0
def agent_factory() -> AgentProtocol:
def agent_factory() -> SupportsAgentRun:
nonlocal factory_call_count
factory_call_count += 1
return cast(AgentProtocol, StubManagerAgent())
return cast(SupportsAgentRun, StubManagerAgent())
# Custom options to verify they are passed through
custom_max_stall_count = 5
@@ -8,11 +8,11 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from agent_framework import (
AgentProtocol,
AgentResponse,
AgentResponseUpdate,
AgentThread,
ChatMessage,
SupportsAgentRun,
)
from agent_framework._workflows._agent_executor import AgentExecutorRequest, AgentExecutorResponse
from agent_framework._workflows._workflow_context import WorkflowContext
@@ -44,10 +44,10 @@ class TestResolveRequestInfoFilter:
assert result == {"agent1", "agent2"}
def test_resolves_agent_display_names(self):
"""Test resolving AgentProtocol instances by name attribute."""
agent1 = MagicMock(spec=AgentProtocol)
"""Test resolving SupportsAgentRun instances by name attribute."""
agent1 = MagicMock(spec=SupportsAgentRun)
agent1.name = "writer"
agent2 = MagicMock(spec=AgentProtocol)
agent2 = MagicMock(spec=SupportsAgentRun)
agent2.name = "reviewer"
result = resolve_request_info_filter([agent1, agent2])
@@ -55,7 +55,7 @@ class TestResolveRequestInfoFilter:
def test_mixed_types(self):
"""Test resolving a mix of strings and agents."""
agent = MagicMock(spec=AgentProtocol)
agent = MagicMock(spec=SupportsAgentRun)
agent.name = "writer"
result = resolve_request_info_filter(["manual_name", agent])
+1 -1
View File
@@ -131,7 +131,7 @@ sequenceDiagram
| Field | Type | Description |
|-------|------|-------------|
| `agent` | `AgentProtocol` | The agent being invoked |
| `agent` | `SupportsAgentRun` | The agent being invoked |
| `messages` | `list[ChatMessage]` | Input messages (mutable) |
| `thread` | `AgentThread \| None` | Conversation thread |
| `options` | `Mapping[str, Any]` | Chat options dict |
@@ -3,7 +3,7 @@
import asyncio
from typing import Any
from agent_framework import AgentProtocol, AgentResponse, AgentThread, ChatMessage, HostedMCPTool
from agent_framework import SupportsAgentRun, AgentResponse, AgentThread, ChatMessage, HostedMCPTool
from agent_framework.azure import AzureAIProjectAgentProvider
from azure.identity.aio import AzureCliCredential
@@ -14,7 +14,7 @@ This sample demonstrates integrating hosted Model Context Protocol (MCP) tools w
"""
async def handle_approvals_without_thread(query: str, agent: "AgentProtocol") -> AgentResponse:
async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun") -> AgentResponse:
"""When we don't have a thread, we need to ensure we return with the input, approval request and approval."""
result = await agent.run(query, store=False)
@@ -35,7 +35,7 @@ async def handle_approvals_without_thread(query: str, agent: "AgentProtocol") ->
return result
async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread") -> AgentResponse:
async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread") -> AgentResponse:
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
result = await agent.run(query, thread=thread)
@@ -3,7 +3,7 @@
import asyncio
from typing import Any
from agent_framework import AgentProtocol, AgentResponse, AgentThread, HostedMCPTool
from agent_framework import SupportsAgentRun, AgentResponse, AgentThread, HostedMCPTool
from agent_framework.azure import AzureAIAgentsProvider
from azure.identity.aio import AzureCliCredential
@@ -15,7 +15,7 @@ servers, including user approval workflows for function call security.
"""
async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread") -> AgentResponse:
async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread") -> AgentResponse:
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
from agent_framework import ChatMessage
@@ -5,7 +5,7 @@ from datetime import datetime, timezone
from typing import Any
from agent_framework import (
AgentProtocol,
SupportsAgentRun,
AgentThread,
HostedMCPTool,
HostedWebSearchTool,
@@ -43,7 +43,7 @@ def get_time() -> str:
return f"The current UTC time is {current_time.strftime('%Y-%m-%d %H:%M:%S')}."
async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread"):
async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread"):
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
from agent_framework import ChatMessage
@@ -15,10 +15,10 @@ Azure OpenAI Responses Client, including user approval workflows for function ca
"""
if TYPE_CHECKING:
from agent_framework import AgentProtocol, AgentThread
from agent_framework import SupportsAgentRun, AgentThread
async def handle_approvals_without_thread(query: str, agent: "AgentProtocol"):
async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun"):
"""When we don't have a thread, we need to ensure we return with the input, approval request and approval."""
from agent_framework import ChatMessage
@@ -40,7 +40,7 @@ async def handle_approvals_without_thread(query: str, agent: "AgentProtocol"):
return result
async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread"):
async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread"):
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
from agent_framework import ChatMessage
@@ -63,7 +63,7 @@ async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", threa
return result
async def handle_approvals_with_thread_streaming(query: str, agent: "AgentProtocol", thread: "AgentThread"):
async def handle_approvals_with_thread_streaming(query: str, agent: "SupportsAgentRun", thread: "AgentThread"):
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
from agent_framework import ChatMessage
@@ -14,10 +14,10 @@ OpenAI Responses Client, including user approval workflows for function call sec
"""
if TYPE_CHECKING:
from agent_framework import AgentProtocol, AgentThread
from agent_framework import SupportsAgentRun, AgentThread
async def handle_approvals_without_thread(query: str, agent: "AgentProtocol"):
async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun"):
"""When we don't have a thread, we need to ensure we return with the input, approval request and approval."""
from agent_framework import ChatMessage
@@ -39,7 +39,7 @@ async def handle_approvals_without_thread(query: str, agent: "AgentProtocol"):
return result
async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread"):
async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread"):
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
from agent_framework import ChatMessage
@@ -62,7 +62,7 @@ async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", threa
return result
async def handle_approvals_with_thread_streaming(query: str, agent: "AgentProtocol", thread: "AgentThread"):
async def handle_approvals_with_thread_streaming(query: str, agent: "SupportsAgentRun", thread: "AgentThread"):
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
from agent_framework import ChatMessage
@@ -56,7 +56,7 @@ async def main() -> None:
)
# 2) Build a concurrent workflow
# Participants are either Agents (type of AgentProtocol) or Executors
# Participants are either Agents (type of SupportsAgentRun) or Executors
workflow = ConcurrentBuilder().participants([researcher, marketer, legal]).build()
# 3) Run with a single prompt and pretty-print the final combined messages
@@ -80,7 +80,7 @@ async def main() -> None:
return response.messages[-1].text if response.messages else ""
# Build with a custom aggregator callback function
# - participants([...]) accepts AgentProtocol (agents) or Executor instances.
# - participants([...]) accepts SupportsAgentRun (agents) or Executor instances.
# Each participant becomes a parallel branch (fan-out) from an internal dispatcher.
# - with_aggregator(...) overrides the default aggregator:
# • Default aggregator -> returns list[ChatMessage] (one user + one assistant per agent)
@@ -122,7 +122,7 @@ async def run_workflow(workflow: Workflow, query: str) -> None:
async def main() -> None:
# Create a concurrent builder with participant factories and a custom aggregator
# - register_participants([...]) accepts factory functions that return
# AgentProtocol (agents) or Executor instances.
# SupportsAgentRun (agents) or Executor instances.
# - register_aggregator(...) takes a factory function that returns an Executor instance.
concurrent_builder = (
ConcurrentBuilder()
@@ -51,8 +51,6 @@ async def main() -> None:
"You are a Researcher. You find information without additional computation or quantitative analysis."
),
# This agent requires the gpt-4o-search-preview model to perform web searches.
# Feel free to explore with other agents that support web search, for example,
# the `OpenAIResponseAgent` or `AzureAgentProtocol` with bing grounding.
chat_client=OpenAIChatClient(model_id="gpt-4o-search-preview"),
)
@@ -8,7 +8,7 @@ from agent_framework import AgentResponse, ChatAgent, ChatMessage, tool
from agent_framework.openai import OpenAIResponsesClient
if TYPE_CHECKING:
from agent_framework import AgentProtocol
from agent_framework import SupportsAgentRun
"""
Demonstration of a tool with approvals.
@@ -40,7 +40,7 @@ def get_weather_detail(location: Annotated[str, "The city and state, e.g. San Fr
)
async def handle_approvals(query: str, agent: "AgentProtocol") -> AgentResponse:
async def handle_approvals(query: str, agent: "SupportsAgentRun") -> AgentResponse:
"""Handle function call approvals.
When we don't have a thread, we need to ensure we include the original query,
@@ -75,7 +75,7 @@ async def handle_approvals(query: str, agent: "AgentProtocol") -> AgentResponse:
return result
async def handle_approvals_streaming(query: str, agent: "AgentProtocol") -> None:
async def handle_approvals_streaming(query: str, agent: "SupportsAgentRun") -> None:
"""Handle function call approvals with streaming responses.
When we don't have a thread, we need to ensure we include the original query,
@@ -29,8 +29,6 @@ async def main() -> None:
"You are a Researcher. You find information without additional computation or quantitative analysis."
),
# This agent requires the gpt-4o-search-preview model to perform web searches.
# Feel free to explore with other agents that support web search, for example,
# the `OpenAIResponseAgent` or `AzureAgentProtocol` with bing grounding.
chat_client=OpenAIChatClient(model_id="gpt-4o-search-preview"),
)