mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Renamed AgentProtocol to SupportsAgentRun (#3717)
* Renamed AgentProtocol to AgentLike * Resolved comments * Renamed AgentLike to SupportsAgentRun * Resolved comments
This commit is contained in:
committed by
GitHub
Unverified
parent
ac17adb595
commit
15256bb616
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
|
||||
from ag_ui.core import BaseEvent
|
||||
from agent_framework import AgentProtocol
|
||||
from agent_framework import SupportsAgentRun
|
||||
|
||||
from ._run import run_agent_stream
|
||||
|
||||
@@ -65,13 +65,13 @@ class AgentConfig:
|
||||
class AgentFrameworkAgent:
|
||||
"""Wraps Agent Framework agents for AG-UI protocol compatibility.
|
||||
|
||||
Translates between Agent Framework's AgentProtocol and AG-UI's event-based
|
||||
Translates between Agent Framework's SupportsAgentRun and AG-UI's event-based
|
||||
protocol. Follows a simple linear flow: RunStarted -> content events -> RunFinished.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
state_schema: Any | None = None,
|
||||
|
||||
@@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import Any
|
||||
|
||||
from ag_ui.encoder import EventEncoder
|
||||
from agent_framework import AgentProtocol
|
||||
from agent_framework import SupportsAgentRun
|
||||
from fastapi import FastAPI
|
||||
from fastapi.params import Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def add_agent_framework_fastapi_endpoint(
|
||||
app: FastAPI,
|
||||
agent: AgentProtocol | AgentFrameworkAgent,
|
||||
agent: SupportsAgentRun | AgentFrameworkAgent,
|
||||
path: str = "/",
|
||||
state_schema: Any | None = None,
|
||||
predict_state_config: dict[str, dict[str, str]] | None = None,
|
||||
@@ -34,7 +34,7 @@ def add_agent_framework_fastapi_endpoint(
|
||||
|
||||
Args:
|
||||
app: The FastAPI application
|
||||
agent: The agent to expose (can be raw AgentProtocol or wrapped)
|
||||
agent: The agent to expose (can be raw SupportsAgentRun or wrapped)
|
||||
path: The endpoint path
|
||||
state_schema: Optional state schema for shared state management; accepts dict or Pydantic model/class
|
||||
predict_state_config: Optional predictive state update configuration.
|
||||
@@ -47,7 +47,7 @@ def add_agent_framework_fastapi_endpoint(
|
||||
authentication checks, rate limiting, or other middleware-like behavior.
|
||||
Example: `dependencies=[Depends(verify_api_key)]`
|
||||
"""
|
||||
if isinstance(agent, AgentProtocol):
|
||||
if isinstance(agent, SupportsAgentRun):
|
||||
wrapped_agent = AgentFrameworkAgent(
|
||||
agent=agent,
|
||||
state_schema=state_schema,
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from agent_framework import BaseChatClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import AgentProtocol
|
||||
from agent_framework import SupportsAgentRun
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,7 +29,7 @@ def _collect_mcp_tool_functions(mcp_tools: list[Any]) -> list[Any]:
|
||||
return functions
|
||||
|
||||
|
||||
def collect_server_tools(agent: "AgentProtocol") -> list[Any]:
|
||||
def collect_server_tools(agent: "SupportsAgentRun") -> list[Any]:
|
||||
"""Collect server tools from an agent.
|
||||
|
||||
This includes both regular tools from default_options and MCP tools.
|
||||
@@ -64,7 +64,7 @@ def collect_server_tools(agent: "AgentProtocol") -> list[Any]:
|
||||
return server_tools
|
||||
|
||||
|
||||
def register_additional_client_tools(agent: "AgentProtocol", client_tools: list[Any] | None) -> None:
|
||||
def register_additional_client_tools(agent: "SupportsAgentRun", client_tools: list[Any] | None) -> None:
|
||||
"""Register client tools as additional declaration-only tools to avoid server execution.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -25,10 +25,10 @@ from ag_ui.core import (
|
||||
ToolCallStartEvent,
|
||||
)
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentThread,
|
||||
ChatMessage,
|
||||
Content,
|
||||
SupportsAgentRun,
|
||||
prepare_function_call_results,
|
||||
)
|
||||
from agent_framework._middleware import FunctionMiddlewarePipeline
|
||||
@@ -579,7 +579,7 @@ def _handle_step_based_approval(messages: list[Any]) -> list[BaseEvent]:
|
||||
async def _resolve_approval_responses(
|
||||
messages: list[Any],
|
||||
tools: list[Any],
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
run_kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
"""Execute approved function calls and replace approval content with results.
|
||||
@@ -741,7 +741,7 @@ def _build_messages_snapshot(
|
||||
|
||||
async def run_agent_stream(
|
||||
input_data: dict[str, Any],
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
config: "AgentConfig",
|
||||
) -> "AsyncGenerator[BaseEvent, None]":
|
||||
"""Run agent and yield AG-UI events.
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Any, Generic, Literal, cast, overload
|
||||
|
||||
import pytest
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentThread,
|
||||
@@ -20,6 +19,7 @@ from agent_framework import (
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
SupportsAgentRun,
|
||||
)
|
||||
from agent_framework._clients import TOptions_co
|
||||
from agent_framework._middleware import ChatMiddlewareLayer
|
||||
@@ -149,8 +149,8 @@ def stream_from_updates(updates: list[ChatResponseUpdate]) -> StreamFn:
|
||||
return _stream
|
||||
|
||||
|
||||
class StubAgent(AgentProtocol):
|
||||
"""Minimal AgentProtocol stub for orchestrator tests."""
|
||||
class StubAgent(SupportsAgentRun):
|
||||
"""Minimal SupportsAgentRun stub for orchestrator tests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -238,6 +238,6 @@ def stream_from_updates_fixture() -> Callable[[list[ChatResponseUpdate]], Stream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stub_agent() -> type[AgentProtocol]:
|
||||
def stub_agent() -> type[SupportsAgentRun]:
|
||||
"""Return the StubAgent class for creating test instances."""
|
||||
return StubAgent # type: ignore[return-value]
|
||||
|
||||
@@ -26,7 +26,7 @@ def build_chat_client(streaming_chat_client_stub, stream_from_updates_fixture):
|
||||
|
||||
|
||||
async def test_add_endpoint_with_agent_protocol(build_chat_client):
|
||||
"""Test adding endpoint with raw AgentProtocol."""
|
||||
"""Test adding endpoint with raw SupportsAgentRun."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
|
||||
import azure.durable_functions as df
|
||||
import azure.functions as func
|
||||
from agent_framework import AgentProtocol, get_logger
|
||||
from agent_framework import SupportsAgentRun, get_logger
|
||||
from agent_framework_durabletask import (
|
||||
DEFAULT_MAX_POLL_RETRIES,
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
@@ -51,12 +51,12 @@ class AgentMetadata:
|
||||
"""Metadata for a registered agent.
|
||||
|
||||
Attributes:
|
||||
agent: The agent instance implementing AgentProtocol
|
||||
agent: The agent instance implementing SupportsAgentRun
|
||||
http_endpoint_enabled: Whether HTTP endpoint is enabled for this agent
|
||||
mcp_tool_enabled: Whether MCP tool endpoint is enabled for this agent
|
||||
"""
|
||||
|
||||
agent: AgentProtocol
|
||||
agent: SupportsAgentRun
|
||||
http_endpoint_enabled: bool
|
||||
mcp_tool_enabled: bool
|
||||
|
||||
@@ -145,7 +145,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
- Full access to all Azure Functions capabilities
|
||||
|
||||
Attributes:
|
||||
agents: Dictionary of agent name to AgentProtocol instance
|
||||
agents: Dictionary of agent name to SupportsAgentRun instance
|
||||
enable_health_check: Whether health check endpoint is enabled
|
||||
enable_http_endpoints: Whether HTTP endpoints are created for agents
|
||||
enable_mcp_tool_trigger: Whether MCP tool triggers are created for agents
|
||||
@@ -160,7 +160,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agents: list[AgentProtocol] | None = None,
|
||||
agents: list[SupportsAgentRun] | None = None,
|
||||
http_auth_level: func.AuthLevel = func.AuthLevel.FUNCTION,
|
||||
enable_health_check: bool = True,
|
||||
enable_http_endpoints: bool = True,
|
||||
@@ -222,17 +222,17 @@ class AgentFunctionApp(DFAppBase):
|
||||
logger.debug("[AgentFunctionApp] Initialization complete")
|
||||
|
||||
@property
|
||||
def agents(self) -> dict[str, AgentProtocol]:
|
||||
def agents(self) -> dict[str, SupportsAgentRun]:
|
||||
"""Returns dict of agent names to agent instances.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping agent names to their AgentProtocol instances.
|
||||
Dictionary mapping agent names to their SupportsAgentRun instances.
|
||||
"""
|
||||
return {name: metadata.agent for name, metadata in self._agent_metadata.items()}
|
||||
|
||||
def add_agent(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
callback: AgentResponseCallbackProtocol | None = None,
|
||||
enable_http_endpoint: bool | None = None,
|
||||
enable_mcp_tool_trigger: bool | None = None,
|
||||
@@ -240,7 +240,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
"""Add an agent to the function app after initialization.
|
||||
|
||||
Args:
|
||||
agent: The Microsoft Agent Framework agent instance (must implement AgentProtocol)
|
||||
agent: The Microsoft Agent Framework agent instance (must implement SupportsAgentRun)
|
||||
The agent must have a 'name' attribute.
|
||||
callback: Optional callback invoked during agent execution
|
||||
enable_http_endpoint: Optional flag to enable/disable HTTP endpoint for this agent.
|
||||
@@ -322,7 +322,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
|
||||
def _setup_agent_functions(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
agent_name: str,
|
||||
callback: AgentResponseCallbackProtocol | None,
|
||||
enable_http_endpoint: bool,
|
||||
@@ -484,7 +484,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
|
||||
def _setup_agent_entity(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
agent_name: str,
|
||||
callback: AgentResponseCallbackProtocol | None,
|
||||
) -> None:
|
||||
|
||||
@@ -12,7 +12,7 @@ from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import azure.durable_functions as df
|
||||
from agent_framework import AgentProtocol, get_logger
|
||||
from agent_framework import SupportsAgentRun, get_logger
|
||||
from agent_framework_durabletask import (
|
||||
AgentEntity,
|
||||
AgentEntityStateProviderMixin,
|
||||
@@ -46,13 +46,13 @@ class AzureFunctionEntityStateProvider(AgentEntityStateProviderMixin):
|
||||
|
||||
|
||||
def create_agent_entity(
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
callback: AgentResponseCallbackProtocol | None = None,
|
||||
) -> Callable[[df.DurableEntityContext], None]:
|
||||
"""Factory function to create an agent entity class.
|
||||
|
||||
Args:
|
||||
agent: The Microsoft Agent Framework agent instance (must implement AgentProtocol)
|
||||
agent: The Microsoft Agent Framework agent instance (must implement SupportsAgentRun)
|
||||
callback: Optional callback invoked during streaming and final responses
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -5,7 +5,7 @@ and [OpenAI ChatKit (Python)](https://github.com/openai/chatkit-python/).
|
||||
Specifically, it mirrors the [Agent SDK integration](https://github.com/openai/chatkit-python/blob/main/docs/server.md#agents-sdk-integration), and provides the following helpers:
|
||||
|
||||
- `stream_agent_response`: A helper to convert a streamed `AgentResponseUpdate`
|
||||
from a Microsoft Agent Framework agent that implements `AgentProtocol` to ChatKit events.
|
||||
from a Microsoft Agent Framework agent that implements `SupportsAgentRun` to ChatKit events.
|
||||
- `ThreadItemConverter`: A extendable helper class to convert ChatKit thread items to
|
||||
`ChatMessage` objects that can be consumed by an Agent Framework agent.
|
||||
- `simple_to_agent_input`: A helper function that uses the default implementation
|
||||
|
||||
@@ -25,7 +25,7 @@ agent_framework/
|
||||
|
||||
### Agents (`_agents.py`)
|
||||
|
||||
- **`AgentProtocol`** - Protocol defining the agent interface
|
||||
- **`SupportsAgentRun`** - Protocol defining the agent interface
|
||||
- **`BaseAgent`** - Abstract base class for agents
|
||||
- **`ChatAgent`** - Main agent class wrapping a chat client with tools, instructions, and middleware
|
||||
|
||||
|
||||
@@ -163,14 +163,14 @@ class _RunContext(TypedDict):
|
||||
finalize_kwargs: dict[str, Any]
|
||||
|
||||
|
||||
__all__ = ["AgentProtocol", "BareAgent", "BaseAgent", "ChatAgent", "RawChatAgent"]
|
||||
__all__ = ["BareAgent", "BaseAgent", "ChatAgent", "RawChatAgent", "SupportsAgentRun"]
|
||||
|
||||
|
||||
# region Agent Protocol
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentProtocol(Protocol):
|
||||
class SupportsAgentRun(Protocol):
|
||||
"""A protocol for an agent that can be invoked.
|
||||
|
||||
This protocol defines the interface that all agents must implement,
|
||||
@@ -185,11 +185,11 @@ class AgentProtocol(Protocol):
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework import AgentProtocol
|
||||
from agent_framework import SupportsAgentRun
|
||||
|
||||
|
||||
# Any class implementing the required methods is compatible
|
||||
# No need to inherit from AgentProtocol or use any framework classes
|
||||
# No need to inherit from SupportsAgentRun or use any framework classes
|
||||
class CustomAgent:
|
||||
def __init__(self):
|
||||
self.id = "custom-agent-001"
|
||||
@@ -218,7 +218,7 @@ class AgentProtocol(Protocol):
|
||||
|
||||
# Verify the instance satisfies the protocol
|
||||
instance = CustomAgent()
|
||||
assert isinstance(instance, AgentProtocol)
|
||||
assert isinstance(instance, SupportsAgentRun)
|
||||
"""
|
||||
|
||||
id: str
|
||||
@@ -297,7 +297,7 @@ class BaseAgent(SerializationMixin):
|
||||
|
||||
Note:
|
||||
BaseAgent cannot be instantiated directly as it doesn't implement the
|
||||
``run()`` and other methods required by AgentProtocol.
|
||||
``run()`` and other methods required by SupportsAgentRun.
|
||||
Use a concrete implementation like ChatAgent or create a subclass.
|
||||
|
||||
Examples:
|
||||
@@ -451,7 +451,7 @@ class BaseAgent(SerializationMixin):
|
||||
A FunctionTool that can be used as a tool by other agents.
|
||||
|
||||
Raises:
|
||||
TypeError: If the agent does not implement AgentProtocol.
|
||||
TypeError: If the agent does not implement SupportsAgentRun.
|
||||
ValueError: If the agent tool name cannot be determined.
|
||||
|
||||
Examples:
|
||||
@@ -468,9 +468,9 @@ class BaseAgent(SerializationMixin):
|
||||
# Use the tool with another agent
|
||||
coordinator = ChatAgent(chat_client=client, name="coordinator", tools=research_tool)
|
||||
"""
|
||||
# Verify that self implements AgentProtocol
|
||||
if not isinstance(self, AgentProtocol):
|
||||
raise TypeError(f"Agent {self.__class__.__name__} must implement AgentProtocol to be used as a tool")
|
||||
# Verify that self implements SupportsAgentRun
|
||||
if not isinstance(self, SupportsAgentRun):
|
||||
raise TypeError(f"Agent {self.__class__.__name__} must implement SupportsAgentRun to be used as a tool")
|
||||
|
||||
tool_name = name or _sanitize_agent_name(self.name)
|
||||
if tool_name is None:
|
||||
|
||||
@@ -34,7 +34,7 @@ else:
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._agents import AgentProtocol
|
||||
from ._agents import SupportsAgentRun
|
||||
from ._clients import ChatClientProtocol
|
||||
from ._threads import AgentThread
|
||||
from ._tools import FunctionTool
|
||||
@@ -64,7 +64,7 @@ __all__ = [
|
||||
"function_middleware",
|
||||
]
|
||||
|
||||
TAgent = TypeVar("TAgent", bound="AgentProtocol")
|
||||
AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
|
||||
TContext = TypeVar("TContext")
|
||||
TUpdate = TypeVar("TUpdate")
|
||||
|
||||
@@ -154,7 +154,7 @@ class AgentContext:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
messages: list[ChatMessage],
|
||||
thread: AgentThread | None = None,
|
||||
options: Mapping[str, Any] | None = None,
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing_extensions import Never
|
||||
|
||||
from agent_framework import Content
|
||||
|
||||
from .._agents import AgentProtocol
|
||||
from .._agents import SupportsAgentRun
|
||||
from .._threads import AgentThread
|
||||
from .._types import AgentResponse, AgentResponseUpdate, ChatMessage
|
||||
from ._agent_utils import resolve_agent_id
|
||||
@@ -80,7 +80,7 @@ class AgentExecutor(Executor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
*,
|
||||
agent_thread: AgentThread | None = None,
|
||||
id: str | None = None,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from .._agents import AgentProtocol
|
||||
from .._agents import SupportsAgentRun
|
||||
|
||||
|
||||
def resolve_agent_id(agent: AgentProtocol) -> str:
|
||||
def resolve_agent_id(agent: SupportsAgentRun) -> str:
|
||||
"""Resolve the unique identifier for an agent.
|
||||
|
||||
Prefers the `.name` attribute if set; otherwise falls back to `.id`.
|
||||
|
||||
@@ -6,7 +6,7 @@ from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .._agents import AgentProtocol
|
||||
from .._agents import SupportsAgentRun
|
||||
from .._threads import AgentThread
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._agent_executor import AgentExecutor
|
||||
@@ -171,7 +171,7 @@ class WorkflowBuilder:
|
||||
self._max_iterations: int = max_iterations
|
||||
self._name: str | None = name
|
||||
self._description: str | None = description
|
||||
# Maps underlying AgentProtocol object id -> wrapped Executor so we reuse the same wrapper
|
||||
# Maps underlying SupportsAgentRun object id -> wrapped Executor so we reuse the same wrapper
|
||||
# across set_start_executor / add_edge calls. This avoids multiple AgentExecutor instances
|
||||
# being created for the same agent.
|
||||
self._agent_wrappers: dict[str, Executor] = {}
|
||||
@@ -187,7 +187,7 @@ class WorkflowBuilder:
|
||||
self._executor_registry: dict[str, Callable[[], Executor]] = {}
|
||||
|
||||
# Output executors filter; if set, only outputs from these executors are yielded
|
||||
self._output_executors: list[Executor | AgentProtocol | str] = []
|
||||
self._output_executors: list[Executor | SupportsAgentRun | str] = []
|
||||
|
||||
# Agents auto-wrapped by builder now always stream incremental updates.
|
||||
|
||||
@@ -208,8 +208,8 @@ class WorkflowBuilder:
|
||||
|
||||
return executor.id
|
||||
|
||||
def _maybe_wrap_agent(self, candidate: Executor | AgentProtocol) -> Executor:
|
||||
"""If the provided object implements AgentProtocol, wrap it in an AgentExecutor.
|
||||
def _maybe_wrap_agent(self, candidate: Executor | SupportsAgentRun) -> Executor:
|
||||
"""If the provided object implements SupportsAgentRun, wrap it in an AgentExecutor.
|
||||
|
||||
This allows fluent builder APIs to directly accept agents instead of
|
||||
requiring callers to manually instantiate AgentExecutor.
|
||||
@@ -221,13 +221,13 @@ class WorkflowBuilder:
|
||||
An Executor instance, wrapping the agent if necessary.
|
||||
"""
|
||||
try: # Local import to avoid hard dependency at import time
|
||||
from agent_framework import AgentProtocol # type: ignore
|
||||
from agent_framework import SupportsAgentRun # type: ignore
|
||||
except Exception: # pragma: no cover - defensive
|
||||
AgentProtocol = object # type: ignore
|
||||
SupportsAgentRun = object # type: ignore
|
||||
|
||||
if isinstance(candidate, Executor): # Already an executor
|
||||
return candidate
|
||||
if isinstance(candidate, AgentProtocol): # type: ignore[arg-type]
|
||||
if isinstance(candidate, SupportsAgentRun): # type: ignore[arg-type]
|
||||
# Reuse existing wrapper for the same agent instance if present
|
||||
agent_instance_id = str(id(candidate))
|
||||
existing = self._agent_wrappers.get(agent_instance_id)
|
||||
@@ -244,7 +244,7 @@ class WorkflowBuilder:
|
||||
return wrapper
|
||||
|
||||
raise TypeError(
|
||||
f"WorkflowBuilder expected an Executor or AgentProtocol instance; got {type(candidate).__name__}."
|
||||
f"WorkflowBuilder expected an Executor or SupportsAgentRun instance; got {type(candidate).__name__}."
|
||||
)
|
||||
|
||||
def register_executor(self, factory_func: Callable[[], Executor], name: str | list[str]) -> Self:
|
||||
@@ -321,7 +321,7 @@ class WorkflowBuilder:
|
||||
|
||||
def register_agent(
|
||||
self,
|
||||
factory_func: Callable[[], AgentProtocol],
|
||||
factory_func: Callable[[], SupportsAgentRun],
|
||||
name: str,
|
||||
agent_thread: AgentThread | None = None,
|
||||
) -> Self:
|
||||
@@ -332,7 +332,7 @@ class WorkflowBuilder:
|
||||
enabling deferred initialization and potentially reducing startup time.
|
||||
|
||||
Args:
|
||||
factory_func: A callable that returns an AgentProtocol instance when called.
|
||||
factory_func: A callable that returns an SupportsAgentRun instance when called.
|
||||
name: The name of the registered agent factory. This doesn't have to match
|
||||
the agent's internal name. But it must be unique within the workflow.
|
||||
agent_thread: The thread to use for running the agent. If None, a new thread will be created when
|
||||
@@ -375,8 +375,8 @@ class WorkflowBuilder:
|
||||
|
||||
def add_edge(
|
||||
self,
|
||||
source: Executor | AgentProtocol | str,
|
||||
target: Executor | AgentProtocol | str,
|
||||
source: Executor | SupportsAgentRun | str,
|
||||
target: Executor | SupportsAgentRun | str,
|
||||
condition: EdgeCondition | None = None,
|
||||
) -> Self:
|
||||
"""Add a directed edge between two executors.
|
||||
@@ -441,8 +441,8 @@ class WorkflowBuilder:
|
||||
not isinstance(source, str) and isinstance(target, str)
|
||||
):
|
||||
raise ValueError(
|
||||
"Both source and target must be either registered factory names (str) "
|
||||
"or Executor/AgentProtocol instances."
|
||||
"Both source and target must be either registered factory names (str) or "
|
||||
"Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if isinstance(source, str) and isinstance(target, str):
|
||||
@@ -450,7 +450,7 @@ class WorkflowBuilder:
|
||||
self._edge_registry.append(_EdgeRegistration(source=source, target=target, condition=condition))
|
||||
return self
|
||||
|
||||
# Both are Executor/AgentProtocol instances; wrap and add now
|
||||
# Both are Executor/SupportsAgentRun instances; wrap and add now
|
||||
source_exec = self._maybe_wrap_agent(source) # type: ignore[arg-type]
|
||||
target_exec = self._maybe_wrap_agent(target) # type: ignore[arg-type]
|
||||
source_id = self._add_executor(source_exec)
|
||||
@@ -460,8 +460,8 @@ class WorkflowBuilder:
|
||||
|
||||
def add_fan_out_edges(
|
||||
self,
|
||||
source: Executor | AgentProtocol | str,
|
||||
targets: Sequence[Executor | AgentProtocol | str],
|
||||
source: Executor | SupportsAgentRun | str,
|
||||
targets: Sequence[Executor | SupportsAgentRun | str],
|
||||
) -> Self:
|
||||
"""Add multiple edges to the workflow where messages from the source will be sent to all targets.
|
||||
|
||||
@@ -520,8 +520,8 @@ class WorkflowBuilder:
|
||||
not isinstance(source, str) and any(isinstance(t, str) for t in targets)
|
||||
):
|
||||
raise ValueError(
|
||||
"Both source and targets must be either registered factory names (str) "
|
||||
"or Executor/AgentProtocol instances."
|
||||
"Both source and targets must be either registered factory names (str) or "
|
||||
"Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if isinstance(source, str) and all(isinstance(t, str) for t in targets):
|
||||
@@ -529,7 +529,7 @@ class WorkflowBuilder:
|
||||
self._edge_registry.append(_FanOutEdgeRegistration(source=source, targets=list(targets))) # type: ignore
|
||||
return self
|
||||
|
||||
# Both are Executor/AgentProtocol instances; wrap and add now
|
||||
# Both are Executor/SupportsAgentRun instances; wrap and add now
|
||||
source_exec = self._maybe_wrap_agent(source) # type: ignore[arg-type]
|
||||
target_execs = [self._maybe_wrap_agent(t) for t in targets] # type: ignore[arg-type]
|
||||
source_id = self._add_executor(source_exec)
|
||||
@@ -540,7 +540,7 @@ class WorkflowBuilder:
|
||||
|
||||
def add_switch_case_edge_group(
|
||||
self,
|
||||
source: Executor | AgentProtocol | str,
|
||||
source: Executor | SupportsAgentRun | str,
|
||||
cases: Sequence[Case | Default],
|
||||
) -> Self:
|
||||
"""Add an edge group that represents a switch-case statement.
|
||||
@@ -620,7 +620,7 @@ class WorkflowBuilder:
|
||||
):
|
||||
raise ValueError(
|
||||
"Both source and case targets must be either registered factory names (str) "
|
||||
"or Executor/AgentProtocol instances."
|
||||
"or Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if isinstance(source, str) and all(isinstance(case.target, str) for case in cases):
|
||||
@@ -628,7 +628,7 @@ class WorkflowBuilder:
|
||||
self._edge_registry.append(_SwitchCaseEdgeGroupRegistration(source=source, cases=list(cases))) # type: ignore
|
||||
return self
|
||||
|
||||
# Source is an Executor/AgentProtocol instance; wrap and add now
|
||||
# Source is an Executor/SupportsAgentRun instance; wrap and add now
|
||||
source_exec = self._maybe_wrap_agent(source) # type: ignore[arg-type]
|
||||
source_id = self._add_executor(source_exec)
|
||||
# Convert case data types to internal types that only uses target_id.
|
||||
@@ -647,8 +647,8 @@ class WorkflowBuilder:
|
||||
|
||||
def add_multi_selection_edge_group(
|
||||
self,
|
||||
source: Executor | AgentProtocol | str,
|
||||
targets: Sequence[Executor | AgentProtocol | str],
|
||||
source: Executor | SupportsAgentRun | str,
|
||||
targets: Sequence[Executor | SupportsAgentRun | str],
|
||||
selection_func: Callable[[Any, list[str]], list[str]],
|
||||
) -> Self:
|
||||
"""Add an edge group that represents a multi-selection execution model.
|
||||
@@ -731,8 +731,8 @@ class WorkflowBuilder:
|
||||
not isinstance(source, str) and any(isinstance(t, str) for t in targets)
|
||||
):
|
||||
raise ValueError(
|
||||
"Both source and targets must be either registered factory names (str) "
|
||||
"or Executor/AgentProtocol instances."
|
||||
"Both source and targets must be either registered factory names (str) or "
|
||||
"Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if isinstance(source, str) and all(isinstance(t, str) for t in targets):
|
||||
@@ -746,7 +746,7 @@ class WorkflowBuilder:
|
||||
)
|
||||
return self
|
||||
|
||||
# Both are Executor/AgentProtocol instances; wrap and add now
|
||||
# Both are Executor/SupportsAgentRun instances; wrap and add now
|
||||
source_exec = self._maybe_wrap_agent(source) # type: ignore
|
||||
target_execs = [self._maybe_wrap_agent(t) for t in targets] # type: ignore
|
||||
source_id = self._add_executor(source_exec)
|
||||
@@ -757,8 +757,8 @@ class WorkflowBuilder:
|
||||
|
||||
def add_fan_in_edges(
|
||||
self,
|
||||
sources: Sequence[Executor | AgentProtocol | str],
|
||||
target: Executor | AgentProtocol | str,
|
||||
sources: Sequence[Executor | SupportsAgentRun | str],
|
||||
target: Executor | SupportsAgentRun | str,
|
||||
) -> Self:
|
||||
"""Add multiple edges from sources to a single target executor.
|
||||
|
||||
@@ -816,8 +816,8 @@ class WorkflowBuilder:
|
||||
not all(isinstance(s, str) for s in sources) and isinstance(target, str)
|
||||
):
|
||||
raise ValueError(
|
||||
"Both sources and target must be either registered factory names (str) "
|
||||
"or Executor/AgentProtocol instances."
|
||||
"Both sources and target must be either registered factory names (str) or "
|
||||
"Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if all(isinstance(s, str) for s in sources) and isinstance(target, str):
|
||||
@@ -825,7 +825,7 @@ class WorkflowBuilder:
|
||||
self._edge_registry.append(_FanInEdgeRegistration(sources=list(sources), target=target)) # type: ignore
|
||||
return self
|
||||
|
||||
# Both are Executor/AgentProtocol instances; wrap and add now
|
||||
# Both are Executor/SupportsAgentRun instances; wrap and add now
|
||||
source_execs = [self._maybe_wrap_agent(s) for s in sources] # type: ignore
|
||||
target_exec = self._maybe_wrap_agent(target) # type: ignore
|
||||
source_ids = [self._add_executor(s) for s in source_execs]
|
||||
@@ -834,7 +834,7 @@ class WorkflowBuilder:
|
||||
|
||||
return self
|
||||
|
||||
def add_chain(self, executors: Sequence[Executor | AgentProtocol | str]) -> Self:
|
||||
def add_chain(self, executors: Sequence[Executor | SupportsAgentRun | str]) -> Self:
|
||||
"""Add a chain of executors to the workflow.
|
||||
|
||||
The output of each executor in the chain will be sent to the next executor in the chain.
|
||||
@@ -895,7 +895,7 @@ class WorkflowBuilder:
|
||||
if not all(isinstance(e, str) for e in executors) and any(isinstance(e, str) for e in executors):
|
||||
raise ValueError(
|
||||
"All executors in the chain must be either registered factory names (str) "
|
||||
"or Executor/AgentProtocol instances."
|
||||
"or Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if all(isinstance(e, str) for e in executors):
|
||||
@@ -904,21 +904,21 @@ class WorkflowBuilder:
|
||||
self.add_edge(executors[i], executors[i + 1])
|
||||
return self
|
||||
|
||||
# All are Executor/AgentProtocol instances; wrap and add now
|
||||
# All are Executor/SupportsAgentRun instances; wrap and add now
|
||||
# Wrap each candidate first to ensure stable IDs before adding edges
|
||||
wrapped: list[Executor] = [self._maybe_wrap_agent(e) for e in executors] # type: ignore[arg-type]
|
||||
for i in range(len(wrapped) - 1):
|
||||
self.add_edge(wrapped[i], wrapped[i + 1])
|
||||
return self
|
||||
|
||||
def set_start_executor(self, executor: Executor | AgentProtocol | str) -> Self:
|
||||
def set_start_executor(self, executor: Executor | SupportsAgentRun | str) -> Self:
|
||||
"""Set the starting executor for the workflow.
|
||||
|
||||
The start executor is the entry point for the workflow. When the workflow is executed,
|
||||
the initial message will be sent to this executor.
|
||||
|
||||
Args:
|
||||
executor: The starting executor, which can be an Executor instance, AgentProtocol instance,
|
||||
executor: The starting executor, which can be an Executor instance, SupportsAgentRun instance,
|
||||
or the name of a registered executor factory.
|
||||
|
||||
Returns:
|
||||
@@ -1067,7 +1067,7 @@ class WorkflowBuilder:
|
||||
self._checkpoint_storage = checkpoint_storage
|
||||
return self
|
||||
|
||||
def with_output_from(self, executors: list[Executor | AgentProtocol | str]) -> Self:
|
||||
def with_output_from(self, executors: list[Executor | SupportsAgentRun | str]) -> Self:
|
||||
"""Specify which executors' outputs should be collected as workflow outputs.
|
||||
|
||||
By default, outputs from all executors are collected. This method allows
|
||||
@@ -1231,7 +1231,11 @@ class WorkflowBuilder:
|
||||
if isinstance(factory_name, str)
|
||||
]
|
||||
+ [ex.id for ex in self._output_executors if isinstance(ex, Executor)]
|
||||
+ [resolve_agent_id(agent) for agent in self._output_executors if isinstance(agent, AgentProtocol)]
|
||||
+ [
|
||||
resolve_agent_id(agent)
|
||||
for agent in self._output_executors
|
||||
if isinstance(agent, SupportsAgentRun)
|
||||
]
|
||||
)
|
||||
|
||||
# Perform validation before creating the workflow
|
||||
|
||||
@@ -38,7 +38,7 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
from opentelemetry.util._decorator import _AgnosticContextManager # type: ignore[reportPrivateUsage]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._agents import AgentProtocol
|
||||
from ._agents import SupportsAgentRun
|
||||
from ._clients import ChatClientProtocol
|
||||
from ._threads import AgentThread
|
||||
from ._tools import FunctionTool
|
||||
@@ -70,7 +70,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
TAgent = TypeVar("TAgent", bound="AgentProtocol")
|
||||
AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
|
||||
TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol[Any]")
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from pydantic import BaseModel
|
||||
from pytest import fixture
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentThread,
|
||||
@@ -24,6 +23,7 @@ from agent_framework import (
|
||||
Content,
|
||||
FunctionInvocationLayer,
|
||||
ResponseStream,
|
||||
SupportsAgentRun,
|
||||
ToolProtocol,
|
||||
tool,
|
||||
)
|
||||
@@ -273,7 +273,7 @@ class MockAgentThread(AgentThread):
|
||||
|
||||
|
||||
# Mock Agent implementation for testing
|
||||
class MockAgent(AgentProtocol):
|
||||
class MockAgent(SupportsAgentRun):
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return str(uuid4())
|
||||
@@ -329,5 +329,5 @@ def agent_thread() -> AgentThread:
|
||||
|
||||
|
||||
@fixture
|
||||
def agent() -> AgentProtocol:
|
||||
def agent() -> SupportsAgentRun:
|
||||
return MockAgent()
|
||||
|
||||
@@ -10,7 +10,6 @@ import pytest
|
||||
from pytest import raises
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentThread,
|
||||
@@ -24,6 +23,7 @@ from agent_framework import (
|
||||
Context,
|
||||
ContextProvider,
|
||||
HostedCodeInterpreterTool,
|
||||
SupportsAgentRun,
|
||||
ToolProtocol,
|
||||
tool,
|
||||
)
|
||||
@@ -36,17 +36,17 @@ def test_agent_thread_type(agent_thread: AgentThread) -> None:
|
||||
assert isinstance(agent_thread, AgentThread)
|
||||
|
||||
|
||||
def test_agent_type(agent: AgentProtocol) -> None:
|
||||
assert isinstance(agent, AgentProtocol)
|
||||
def test_agent_type(agent: SupportsAgentRun) -> None:
|
||||
assert isinstance(agent, SupportsAgentRun)
|
||||
|
||||
|
||||
async def test_agent_run(agent: AgentProtocol) -> None:
|
||||
async def test_agent_run(agent: SupportsAgentRun) -> None:
|
||||
response = await agent.run("test")
|
||||
assert response.messages[0].role == "assistant"
|
||||
assert response.messages[0].text == "Response"
|
||||
|
||||
|
||||
async def test_agent_run_streaming(agent: AgentProtocol) -> None:
|
||||
async def test_agent_run_streaming(agent: SupportsAgentRun) -> None:
|
||||
async def collect_updates(updates: AsyncIterable[AgentResponseUpdate]) -> list[AgentResponseUpdate]:
|
||||
return [u async for u in updates]
|
||||
|
||||
@@ -57,7 +57,7 @@ async def test_agent_run_streaming(agent: AgentProtocol) -> None:
|
||||
|
||||
def test_chat_client_agent_type(chat_client: ChatClientProtocol) -> None:
|
||||
chat_client_agent = ChatAgent(chat_client=chat_client)
|
||||
assert isinstance(chat_client_agent, AgentProtocol)
|
||||
assert isinstance(chat_client_agent, SupportsAgentRun)
|
||||
|
||||
|
||||
async def test_chat_client_agent_init(chat_client: ChatClientProtocol) -> None:
|
||||
@@ -804,7 +804,7 @@ def test_sanitize_agent_name_replaces_invalid_chars():
|
||||
# endregion
|
||||
|
||||
|
||||
# region Test AgentProtocol.get_new_thread and deserialize_thread
|
||||
# region Test SupportsAgentRun.get_new_thread and deserialize_thread
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -8,7 +8,6 @@ import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
ChatMessage,
|
||||
@@ -16,6 +15,7 @@ from agent_framework import (
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
ResponseStream,
|
||||
SupportsAgentRun,
|
||||
)
|
||||
from agent_framework._middleware import (
|
||||
AgentContext,
|
||||
@@ -35,7 +35,7 @@ from agent_framework._tools import FunctionTool
|
||||
class TestAgentContext:
|
||||
"""Test cases for AgentContext."""
|
||||
|
||||
def test_init_with_defaults(self, mock_agent: AgentProtocol) -> None:
|
||||
def test_init_with_defaults(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test AgentContext initialization with default values."""
|
||||
messages = [ChatMessage(role="user", text="test")]
|
||||
context = AgentContext(agent=mock_agent, messages=messages)
|
||||
@@ -45,7 +45,7 @@ class TestAgentContext:
|
||||
assert context.stream is False
|
||||
assert context.metadata == {}
|
||||
|
||||
def test_init_with_custom_values(self, mock_agent: AgentProtocol) -> None:
|
||||
def test_init_with_custom_values(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test AgentContext initialization with custom values."""
|
||||
messages = [ChatMessage(role="user", text="test")]
|
||||
metadata = {"key": "value"}
|
||||
@@ -56,7 +56,7 @@ class TestAgentContext:
|
||||
assert context.stream is True
|
||||
assert context.metadata == metadata
|
||||
|
||||
def test_init_with_thread(self, mock_agent: AgentProtocol) -> None:
|
||||
def test_init_with_thread(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test AgentContext initialization with thread parameter."""
|
||||
from agent_framework import AgentThread
|
||||
|
||||
@@ -163,7 +163,7 @@ class TestAgentMiddlewarePipeline:
|
||||
pipeline = AgentMiddlewarePipeline(test_middleware)
|
||||
assert pipeline.has_middlewares
|
||||
|
||||
async def test_execute_no_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_execute_no_middleware(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test pipeline execution with no middleware."""
|
||||
pipeline = AgentMiddlewarePipeline()
|
||||
messages = [ChatMessage(role="user", text="test")]
|
||||
@@ -177,7 +177,7 @@ class TestAgentMiddlewarePipeline:
|
||||
result = await pipeline.execute(context, final_handler)
|
||||
assert result == expected_response
|
||||
|
||||
async def test_execute_with_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_execute_with_middleware(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test pipeline execution with middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@@ -205,7 +205,7 @@ class TestAgentMiddlewarePipeline:
|
||||
assert result == expected_response
|
||||
assert execution_order == ["test_before", "handler", "test_after"]
|
||||
|
||||
async def test_execute_stream_no_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_execute_stream_no_middleware(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test pipeline streaming execution with no middleware."""
|
||||
pipeline = AgentMiddlewarePipeline()
|
||||
messages = [ChatMessage(role="user", text="test")]
|
||||
@@ -228,7 +228,7 @@ class TestAgentMiddlewarePipeline:
|
||||
assert updates[0].text == "chunk1"
|
||||
assert updates[1].text == "chunk2"
|
||||
|
||||
async def test_execute_stream_with_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_execute_stream_with_middleware(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test pipeline streaming execution with middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@@ -265,7 +265,7 @@ class TestAgentMiddlewarePipeline:
|
||||
assert updates[1].text == "chunk2"
|
||||
assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"]
|
||||
|
||||
async def test_execute_with_pre_next_termination(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_execute_with_pre_next_termination(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test pipeline execution with termination before next()."""
|
||||
middleware = self.PreNextTerminateMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline(middleware)
|
||||
@@ -283,7 +283,7 @@ class TestAgentMiddlewarePipeline:
|
||||
# Handler should not be called when terminated before next()
|
||||
assert execution_order == []
|
||||
|
||||
async def test_execute_with_post_next_termination(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_execute_with_post_next_termination(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test pipeline execution with termination after next()."""
|
||||
middleware = self.PostNextTerminateMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline(middleware)
|
||||
@@ -301,7 +301,7 @@ class TestAgentMiddlewarePipeline:
|
||||
assert response.messages[0].text == "response"
|
||||
assert execution_order == ["handler"]
|
||||
|
||||
async def test_execute_stream_with_pre_next_termination(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_execute_stream_with_pre_next_termination(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test pipeline streaming execution with termination before next()."""
|
||||
middleware = self.PreNextTerminateMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline(middleware)
|
||||
@@ -329,7 +329,7 @@ class TestAgentMiddlewarePipeline:
|
||||
assert execution_order == []
|
||||
assert not updates
|
||||
|
||||
async def test_execute_stream_with_post_next_termination(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_execute_stream_with_post_next_termination(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test pipeline streaming execution with termination after next()."""
|
||||
middleware = self.PostNextTerminateMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline(middleware)
|
||||
@@ -356,7 +356,7 @@ class TestAgentMiddlewarePipeline:
|
||||
assert updates[1].text == "chunk2"
|
||||
assert execution_order == ["handler_start", "handler_end"]
|
||||
|
||||
async def test_execute_with_thread_in_context(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_execute_with_thread_in_context(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test pipeline execution properly passes thread to middleware."""
|
||||
from agent_framework import AgentThread
|
||||
|
||||
@@ -383,7 +383,7 @@ class TestAgentMiddlewarePipeline:
|
||||
assert result == expected_response
|
||||
assert captured_thread is thread
|
||||
|
||||
async def test_execute_with_no_thread_in_context(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_execute_with_no_thread_in_context(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test pipeline execution when no thread is provided."""
|
||||
captured_thread = "not_none" # Use string to distinguish from None
|
||||
|
||||
@@ -761,7 +761,7 @@ class TestChatMiddlewarePipeline:
|
||||
class TestClassBasedMiddleware:
|
||||
"""Test cases for class-based middleware implementations."""
|
||||
|
||||
async def test_agent_middleware_execution(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_agent_middleware_execution(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test class-based agent middleware execution."""
|
||||
metadata_updates: list[str] = []
|
||||
|
||||
@@ -825,7 +825,7 @@ class TestClassBasedMiddleware:
|
||||
class TestFunctionBasedMiddleware:
|
||||
"""Test cases for function-based middleware implementations."""
|
||||
|
||||
async def test_agent_function_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_agent_function_middleware(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test function-based agent middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@@ -879,7 +879,7 @@ class TestFunctionBasedMiddleware:
|
||||
class TestMixedMiddleware:
|
||||
"""Test cases for mixed class and function-based middleware."""
|
||||
|
||||
async def test_mixed_agent_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_mixed_agent_middleware(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test mixed class and function-based agent middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@@ -976,7 +976,7 @@ class TestMixedMiddleware:
|
||||
class TestMultipleMiddlewareOrdering:
|
||||
"""Test cases for multiple middleware execution order."""
|
||||
|
||||
async def test_agent_middleware_execution_order(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_agent_middleware_execution_order(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test that multiple agent middleware execute in registration order."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@@ -1110,7 +1110,7 @@ class TestMultipleMiddlewareOrdering:
|
||||
class TestContextContentValidation:
|
||||
"""Test cases for validating middleware context content."""
|
||||
|
||||
async def test_agent_context_validation(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_agent_context_validation(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test that agent context contains expected data."""
|
||||
|
||||
class ContextValidationMiddleware(AgentMiddleware):
|
||||
@@ -1231,7 +1231,7 @@ class TestContextContentValidation:
|
||||
class TestStreamingScenarios:
|
||||
"""Test cases for streaming and non-streaming scenarios."""
|
||||
|
||||
async def test_streaming_flag_validation(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_streaming_flag_validation(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test that stream flag is correctly set for streaming calls."""
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
@@ -1271,7 +1271,7 @@ class TestStreamingScenarios:
|
||||
# Verify flags: [non-streaming middleware, non-streaming handler, streaming middleware, streaming handler]
|
||||
assert streaming_flags == [False, False, True, True]
|
||||
|
||||
async def test_streaming_middleware_behavior(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_streaming_middleware_behavior(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test middleware behavior with streaming responses."""
|
||||
chunks_processed: list[str] = []
|
||||
|
||||
@@ -1437,7 +1437,7 @@ class MockFunctionArgs(BaseModel):
|
||||
class TestMiddlewareExecutionControl:
|
||||
"""Test cases for middleware execution control (when next() is called vs not called)."""
|
||||
|
||||
async def test_agent_middleware_no_next_no_execution(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_agent_middleware_no_next_no_execution(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test that when agent middleware doesn't call next(), no execution happens."""
|
||||
|
||||
class NoNextMiddleware(AgentMiddleware):
|
||||
@@ -1464,7 +1464,7 @@ class TestMiddlewareExecutionControl:
|
||||
assert not handler_called
|
||||
assert context.result is None
|
||||
|
||||
async def test_agent_middleware_no_next_no_streaming_execution(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_agent_middleware_no_next_no_streaming_execution(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test that when agent middleware doesn't call next(), no streaming execution happens."""
|
||||
|
||||
class NoNextStreamingMiddleware(AgentMiddleware):
|
||||
@@ -1529,7 +1529,7 @@ class TestMiddlewareExecutionControl:
|
||||
assert not handler_called
|
||||
assert context.result is None
|
||||
|
||||
async def test_multiple_middlewares_early_stop(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_multiple_middlewares_early_stop(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test that when first middleware doesn't call next(), subsequent middleware are not called."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@@ -1664,9 +1664,9 @@ class TestMiddlewareExecutionControl:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent() -> AgentProtocol:
|
||||
def mock_agent() -> SupportsAgentRun:
|
||||
"""Mock agent for testing."""
|
||||
agent = MagicMock(spec=AgentProtocol)
|
||||
agent = MagicMock(spec=SupportsAgentRun)
|
||||
agent.name = "test_agent"
|
||||
return agent
|
||||
|
||||
|
||||
@@ -8,13 +8,13 @@ import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
ChatAgent,
|
||||
ChatMessage,
|
||||
Content,
|
||||
ResponseStream,
|
||||
SupportsAgentRun,
|
||||
)
|
||||
from agent_framework._middleware import (
|
||||
AgentContext,
|
||||
@@ -38,7 +38,7 @@ class FunctionTestArgs(BaseModel):
|
||||
class TestResultOverrideMiddleware:
|
||||
"""Test cases for middleware result override functionality."""
|
||||
|
||||
async def test_agent_middleware_response_override_non_streaming(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_agent_middleware_response_override_non_streaming(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test that agent middleware can override response for non-streaming execution."""
|
||||
override_response = AgentResponse(messages=[ChatMessage(role="assistant", text="overridden response")])
|
||||
|
||||
@@ -69,7 +69,7 @@ class TestResultOverrideMiddleware:
|
||||
# Verify original handler was called since middleware called next()
|
||||
assert handler_called
|
||||
|
||||
async def test_agent_middleware_response_override_streaming(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_agent_middleware_response_override_streaming(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test that agent middleware can override response for streaming execution."""
|
||||
|
||||
async def override_stream() -> AsyncIterable[AgentResponseUpdate]:
|
||||
@@ -211,7 +211,7 @@ class TestResultOverrideMiddleware:
|
||||
assert normal_updates[0].text == "test streaming response "
|
||||
assert normal_updates[1].text == "another update"
|
||||
|
||||
async def test_agent_middleware_conditional_no_next(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_agent_middleware_conditional_no_next(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test that when agent middleware conditionally doesn't call next(), no execution happens."""
|
||||
|
||||
class ConditionalNoNextMiddleware(AgentMiddleware):
|
||||
@@ -303,7 +303,7 @@ class TestResultOverrideMiddleware:
|
||||
class TestResultObservability:
|
||||
"""Test cases for middleware result observability functionality."""
|
||||
|
||||
async def test_agent_middleware_response_observability(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_agent_middleware_response_observability(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test that middleware can observe response after execution."""
|
||||
observed_responses: list[AgentResponse] = []
|
||||
|
||||
@@ -370,7 +370,7 @@ class TestResultObservability:
|
||||
assert observed_results[0] == "executed function result"
|
||||
assert result == observed_results[0]
|
||||
|
||||
async def test_agent_middleware_post_execution_override(self, mock_agent: AgentProtocol) -> None:
|
||||
async def test_agent_middleware_post_execution_override(self, mock_agent: SupportsAgentRun) -> None:
|
||||
"""Test that middleware can override response after observing execution."""
|
||||
|
||||
class PostExecutionOverrideMiddleware(AgentMiddleware):
|
||||
@@ -436,9 +436,9 @@ class TestResultObservability:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent() -> AgentProtocol:
|
||||
def mock_agent() -> SupportsAgentRun:
|
||||
"""Mock agent for testing."""
|
||||
agent = MagicMock(spec=AgentProtocol)
|
||||
agent = MagicMock(spec=SupportsAgentRun)
|
||||
agent.name = "test_agent"
|
||||
return agent
|
||||
|
||||
|
||||
@@ -1853,13 +1853,13 @@ class TestChatAgentChatMiddleware:
|
||||
|
||||
|
||||
# class TestMiddlewareWithProtocolOnlyAgent:
|
||||
# """Test use_agent_middleware with agents implementing only AgentProtocol."""
|
||||
# """Test use_agent_middleware with agents implementing only SupportsAgentRun."""
|
||||
|
||||
# async def test_middleware_with_protocol_only_agent(self) -> None:
|
||||
# """Verify middleware works without BaseAgent inheritance for both run."""
|
||||
# from collections.abc import AsyncIterable
|
||||
|
||||
# from agent_framework import AgentProtocol, AgentResponse, AgentResponseUpdate
|
||||
# from agent_framework import SupportsAgentRun, AgentResponse, AgentResponseUpdate
|
||||
|
||||
# execution_order: list[str] = []
|
||||
|
||||
@@ -1873,7 +1873,7 @@ class TestChatAgentChatMiddleware:
|
||||
|
||||
# @use_agent_middleware
|
||||
# class ProtocolOnlyAgent:
|
||||
# """Minimal agent implementing only AgentProtocol, not inheriting from BaseAgent."""
|
||||
# """Minimal agent implementing only SupportsAgentRun, not inheriting from BaseAgent."""
|
||||
|
||||
# def __init__(self):
|
||||
# self.id = "protocol-only-agent"
|
||||
@@ -1896,7 +1896,7 @@ class TestChatAgentChatMiddleware:
|
||||
# return None
|
||||
|
||||
# agent = ProtocolOnlyAgent()
|
||||
# assert isinstance(agent, AgentProtocol)
|
||||
# assert isinstance(agent, SupportsAgentRun)
|
||||
|
||||
# # Test run (non-streaming)
|
||||
# response = await agent.run("test message")
|
||||
|
||||
@@ -12,7 +12,6 @@ from opentelemetry.trace import StatusCode
|
||||
|
||||
from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
AgentProtocol,
|
||||
AgentResponse,
|
||||
BaseChatClient,
|
||||
ChatMessage,
|
||||
@@ -20,6 +19,7 @@ from agent_framework import (
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
ResponseStream,
|
||||
SupportsAgentRun,
|
||||
UsageDetails,
|
||||
prepend_agent_framework_to_user_agent,
|
||||
tool,
|
||||
@@ -473,7 +473,7 @@ def mock_chat_agent():
|
||||
|
||||
@pytest.mark.parametrize("enable_sensitive_data", [True, False], indirect=True)
|
||||
async def test_agent_instrumentation_enabled(
|
||||
mock_chat_agent: AgentProtocol, span_exporter: InMemorySpanExporter, enable_sensitive_data
|
||||
mock_chat_agent: SupportsAgentRun, span_exporter: InMemorySpanExporter, enable_sensitive_data
|
||||
):
|
||||
"""Test that when agent diagnostics are enabled, telemetry is applied."""
|
||||
|
||||
@@ -499,7 +499,7 @@ async def test_agent_instrumentation_enabled(
|
||||
|
||||
@pytest.mark.parametrize("enable_sensitive_data", [True, False], indirect=True)
|
||||
async def test_agent_streaming_response_with_diagnostics_enabled(
|
||||
mock_chat_agent: AgentProtocol, span_exporter: InMemorySpanExporter, enable_sensitive_data
|
||||
mock_chat_agent: SupportsAgentRun, span_exporter: InMemorySpanExporter, enable_sensitive_data
|
||||
):
|
||||
"""Test agent streaming telemetry through the agent telemetry mixin."""
|
||||
agent = mock_chat_agent()
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing_extensions import Never
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutorRequest,
|
||||
AgentProtocol,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentThread,
|
||||
@@ -18,6 +17,7 @@ from agent_framework import (
|
||||
Content,
|
||||
Executor,
|
||||
ResponseStream,
|
||||
SupportsAgentRun,
|
||||
UsageDetails,
|
||||
WorkflowAgent,
|
||||
WorkflowBuilder,
|
||||
@@ -615,7 +615,7 @@ class TestWorkflowAgent:
|
||||
async def test_agent_executor_output_response_false_filters_streaming_events(self):
|
||||
"""Test that AgentExecutor with output_response=False does not surface streaming events."""
|
||||
|
||||
class MockAgent(AgentProtocol):
|
||||
class MockAgent(SupportsAgentRun):
|
||||
"""Mock agent for testing."""
|
||||
|
||||
def __init__(self, name: str, response_text: str) -> None:
|
||||
@@ -705,7 +705,7 @@ class TestWorkflowAgent:
|
||||
async def test_agent_executor_output_response_no_duplicate_from_workflow_output_event(self):
|
||||
"""Test that AgentExecutor with output_response=True does not duplicate content."""
|
||||
|
||||
class MockAgent(AgentProtocol):
|
||||
class MockAgent(SupportsAgentRun):
|
||||
"""Mock agent for testing."""
|
||||
|
||||
def __init__(self, name: str, response_text: str) -> None:
|
||||
|
||||
@@ -422,7 +422,7 @@ def test_mixing_eager_and_lazy_initialization_error():
|
||||
ValueError,
|
||||
match=(
|
||||
r"Both source and target must be either registered factory names \(str\) "
|
||||
r"or Executor/AgentProtocol instances\."
|
||||
r"or Executor/SupportsAgentRun instances\."
|
||||
),
|
||||
):
|
||||
builder.add_edge(eager_executor, "Lazy")
|
||||
|
||||
@@ -17,8 +17,8 @@ from typing import Any, cast
|
||||
import yaml
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentProtocol,
|
||||
CheckpointStorage,
|
||||
SupportsAgentRun,
|
||||
Workflow,
|
||||
get_logger,
|
||||
)
|
||||
@@ -78,13 +78,13 @@ class WorkflowFactory:
|
||||
workflow = factory.create_workflow_from_yaml_path("workflow.yaml")
|
||||
"""
|
||||
|
||||
_agents: dict[str, AgentProtocol | AgentExecutor]
|
||||
_agents: dict[str, SupportsAgentRun | AgentExecutor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent_factory: AgentFactory | None = None,
|
||||
agents: Mapping[str, AgentProtocol | AgentExecutor] | None = None,
|
||||
agents: Mapping[str, SupportsAgentRun | AgentExecutor] | None = None,
|
||||
bindings: Mapping[str, Any] | None = None,
|
||||
env_file: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
@@ -132,7 +132,7 @@ class WorkflowFactory:
|
||||
)
|
||||
"""
|
||||
self._agent_factory = agent_factory or AgentFactory(env_file_path=env_file)
|
||||
self._agents: dict[str, AgentProtocol | AgentExecutor] = dict(agents) if agents else {}
|
||||
self._agents: dict[str, SupportsAgentRun | AgentExecutor] = dict(agents) if agents else {}
|
||||
self._bindings: dict[str, Any] = dict(bindings) if bindings else {}
|
||||
self._checkpoint_storage = checkpoint_storage
|
||||
|
||||
@@ -323,7 +323,7 @@ class WorkflowFactory:
|
||||
description = workflow_def.get("description")
|
||||
|
||||
# Create agents from definitions
|
||||
agents: dict[str, AgentProtocol | AgentExecutor] = dict(self._agents)
|
||||
agents: dict[str, SupportsAgentRun | AgentExecutor] = dict(self._agents)
|
||||
agent_defs = workflow_def.get("agents", {})
|
||||
|
||||
for agent_name, agent_def in agent_defs.items():
|
||||
@@ -347,7 +347,7 @@ class WorkflowFactory:
|
||||
workflow_def: dict[str, Any],
|
||||
name: str,
|
||||
description: str | None,
|
||||
agents: dict[str, AgentProtocol | AgentExecutor],
|
||||
agents: dict[str, SupportsAgentRun | AgentExecutor],
|
||||
) -> Workflow:
|
||||
"""Create workflow from definition.
|
||||
|
||||
@@ -506,7 +506,7 @@ class WorkflowFactory:
|
||||
f"Invalid agent definition. Expected 'file', 'kind', or 'connection': {agent_def}"
|
||||
)
|
||||
|
||||
def register_agent(self, name: str, agent: AgentProtocol | AgentExecutor) -> "WorkflowFactory":
|
||||
def register_agent(self, name: str, agent: SupportsAgentRun | AgentExecutor) -> "WorkflowFactory":
|
||||
"""Register an agent instance with the factory for use in workflows.
|
||||
|
||||
Registered agents are available to InvokeAzureAgent actions by name.
|
||||
|
||||
@@ -757,11 +757,11 @@ class EntityDiscovery:
|
||||
True if object appears to be a valid agent
|
||||
"""
|
||||
try:
|
||||
# Try to import AgentProtocol for proper type checking
|
||||
# Try to import SupportsAgentRun for proper type checking
|
||||
try:
|
||||
from agent_framework import AgentProtocol
|
||||
from agent_framework import SupportsAgentRun
|
||||
|
||||
if isinstance(obj, AgentProtocol):
|
||||
if isinstance(obj, SupportsAgentRun):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -7,7 +7,7 @@ import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import AgentProtocol, Content, Workflow
|
||||
from agent_framework import Content, SupportsAgentRun, Workflow
|
||||
|
||||
from ._conversations import ConversationStore, InMemoryConversationStore
|
||||
from ._discovery import EntityDiscovery
|
||||
@@ -285,7 +285,7 @@ class AgentFrameworkExecutor:
|
||||
yield {"type": "error", "message": str(e), "entity_id": entity_id}
|
||||
|
||||
async def _execute_agent(
|
||||
self, agent: AgentProtocol, request: AgentFrameworkRequest, trace_collector: Any
|
||||
self, agent: SupportsAgentRun, request: AgentFrameworkRequest, trace_collector: Any
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
"""Execute Agent Framework agent with trace collection and optional thread support.
|
||||
|
||||
|
||||
@@ -9,12 +9,12 @@ from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
ChatMessage,
|
||||
Content,
|
||||
ResponseStream,
|
||||
SupportsAgentRun,
|
||||
get_logger,
|
||||
)
|
||||
from durabletask.entities import DurableEntity
|
||||
@@ -86,12 +86,12 @@ class AgentEntity:
|
||||
This class encapsulates the core logic for executing an agent within a durable entity context.
|
||||
"""
|
||||
|
||||
agent: AgentProtocol
|
||||
agent: SupportsAgentRun
|
||||
callback: AgentResponseCallbackProtocol | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
callback: AgentResponseCallbackProtocol | None = None,
|
||||
*,
|
||||
state_provider: AgentEntityStateProviderMixin,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
"""Durable Agent Shim for Durable Task Framework.
|
||||
|
||||
This module provides the DurableAIAgent shim that implements AgentProtocol
|
||||
This module provides the DurableAIAgent shim that implements SupportsAgentRun
|
||||
and provides a consistent interface for both Client and Orchestration contexts.
|
||||
The actual execution is delegated to the context-specific providers.
|
||||
"""
|
||||
@@ -12,7 +12,7 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, Literal, TypeVar
|
||||
|
||||
from agent_framework import AgentProtocol, AgentThread, ChatMessage
|
||||
from agent_framework import AgentThread, ChatMessage, SupportsAgentRun
|
||||
|
||||
from ._executors import DurableAgentExecutor
|
||||
from ._models import DurableAgentThread
|
||||
@@ -47,11 +47,11 @@ class DurableAgentProvider(ABC, Generic[TaskT]):
|
||||
raise NotImplementedError("Subclasses must implement get_agent()")
|
||||
|
||||
|
||||
class DurableAIAgent(AgentProtocol, Generic[TaskT]):
|
||||
class DurableAIAgent(SupportsAgentRun, Generic[TaskT]):
|
||||
"""A durable agent proxy that delegates execution to the provider.
|
||||
|
||||
This class implements AgentProtocol but with one critical difference:
|
||||
- AgentProtocol.run() returns a Coroutine (async, must await)
|
||||
This class implements SupportsAgentRun but with one critical difference:
|
||||
- SupportsAgentRun.run() returns a Coroutine (async, must await)
|
||||
- DurableAIAgent.run() returns TaskT (sync Task object - must yield
|
||||
or the AgentResponse directly in the case of TaskHubGrpcClient)
|
||||
|
||||
@@ -104,8 +104,8 @@ class DurableAIAgent(AgentProtocol, Generic[TaskT]):
|
||||
Additional keys are forwarded to the agent execution.
|
||||
|
||||
Note:
|
||||
This method overrides AgentProtocol.run() with a different return type:
|
||||
- AgentProtocol.run() returns Coroutine[Any, Any, AgentResponse] (async)
|
||||
This method overrides SupportsAgentRun.run() with a different return type:
|
||||
- SupportsAgentRun.run() returns Coroutine[Any, Any, AgentResponse] (async)
|
||||
- DurableAIAgent.run() returns TaskT (Task object for yielding)
|
||||
|
||||
This is intentional to support orchestration contexts that use yield patterns
|
||||
|
||||
@@ -11,7 +11,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import AgentProtocol, get_logger
|
||||
from agent_framework import SupportsAgentRun, get_logger
|
||||
from durabletask.worker import TaskHubGrpcWorker
|
||||
|
||||
from ._callbacks import AgentResponseCallbackProtocol
|
||||
@@ -60,12 +60,12 @@ class DurableAIAgentWorker:
|
||||
"""
|
||||
self._worker = worker
|
||||
self._callback = callback
|
||||
self._registered_agents: dict[str, AgentProtocol] = {}
|
||||
self._registered_agents: dict[str, SupportsAgentRun] = {}
|
||||
logger.debug("[DurableAIAgentWorker] Initialized with worker type: %s", type(worker).__name__)
|
||||
|
||||
def add_agent(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
callback: AgentResponseCallbackProtocol | None = None,
|
||||
) -> None:
|
||||
"""Register an agent with the worker.
|
||||
@@ -139,7 +139,7 @@ class DurableAIAgentWorker:
|
||||
|
||||
def __create_agent_entity(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
callback: AgentResponseCallbackProtocol | None = None,
|
||||
) -> type[DurableTaskEntityStateProvider]:
|
||||
"""Factory function to create a DurableEntity class configured with an agent.
|
||||
|
||||
@@ -9,7 +9,7 @@ Run with: pytest tests/test_client.py -v
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentProtocol
|
||||
from agent_framework import SupportsAgentRun
|
||||
|
||||
from agent_framework_durabletask import DurableAgentThread, DurableAIAgentClient
|
||||
from agent_framework_durabletask._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS
|
||||
@@ -46,7 +46,7 @@ class TestDurableAIAgentClientGetAgent:
|
||||
agent = agent_client.get_agent("assistant")
|
||||
|
||||
assert isinstance(agent, DurableAIAgent)
|
||||
assert isinstance(agent, AgentProtocol)
|
||||
assert isinstance(agent, SupportsAgentRun)
|
||||
|
||||
def test_get_agent_shim_has_correct_name(self, agent_client: DurableAIAgentClient) -> None:
|
||||
"""Verify retrieved agent has the correct name."""
|
||||
|
||||
@@ -9,7 +9,7 @@ Run with: pytest tests/test_orchestration_context.py -v
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentProtocol
|
||||
from agent_framework import SupportsAgentRun
|
||||
|
||||
from agent_framework_durabletask import DurableAgentThread
|
||||
from agent_framework_durabletask._orchestration_context import DurableAIAgentOrchestrationContext
|
||||
@@ -36,7 +36,7 @@ class TestDurableAIAgentOrchestrationContextGetAgent:
|
||||
agent = agent_context.get_agent("assistant")
|
||||
|
||||
assert isinstance(agent, DurableAIAgent)
|
||||
assert isinstance(agent, AgentProtocol)
|
||||
assert isinstance(agent, SupportsAgentRun)
|
||||
|
||||
def test_get_agent_shim_has_correct_name(self, agent_context: DurableAIAgentOrchestrationContext) -> None:
|
||||
"""Verify retrieved agent has the correct name."""
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentProtocol, ChatMessage
|
||||
from agent_framework import ChatMessage, SupportsAgentRun
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework_durabletask import DurableAgentThread
|
||||
@@ -142,15 +142,15 @@ class TestDurableAIAgentParameterFlow:
|
||||
assert kwargs["run_request"].response_format == ResponseFormatModel
|
||||
|
||||
|
||||
class TestDurableAIAgentProtocolCompliance:
|
||||
"""Test that DurableAIAgent implements AgentProtocol correctly."""
|
||||
class TestDurableAISupportsAgentRunCompliance:
|
||||
"""Test that DurableAIAgent implements SupportsAgentRun correctly."""
|
||||
|
||||
def test_agent_implements_protocol(self, test_agent: DurableAIAgent[Any]) -> None:
|
||||
"""Verify DurableAIAgent implements AgentProtocol."""
|
||||
assert isinstance(test_agent, AgentProtocol)
|
||||
"""Verify DurableAIAgent implements SupportsAgentRun."""
|
||||
assert isinstance(test_agent, SupportsAgentRun)
|
||||
|
||||
def test_agent_has_required_properties(self, test_agent: DurableAIAgent[Any]) -> None:
|
||||
"""Verify DurableAIAgent has all required AgentProtocol properties."""
|
||||
"""Verify DurableAIAgent has all required SupportsAgentRun properties."""
|
||||
assert hasattr(test_agent, "id")
|
||||
assert hasattr(test_agent, "name")
|
||||
assert hasattr(test_agent, "display_name")
|
||||
|
||||
@@ -6,7 +6,7 @@ import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import AgentProtocol, ChatMessage
|
||||
from agent_framework import ChatMessage, SupportsAgentRun
|
||||
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
|
||||
from agent_framework._workflows._agent_utils import resolve_agent_id
|
||||
from agent_framework._workflows._checkpoint import CheckpointStorage
|
||||
@@ -29,8 +29,8 @@ parallel workflow with:
|
||||
- a default aggregator that combines all agent conversations and completes the workflow
|
||||
|
||||
Notes:
|
||||
- Participants can be provided as AgentProtocol or Executor instances via `.participants()`,
|
||||
or as factories returning AgentProtocol or Executor via `.register_participants()`.
|
||||
- Participants can be provided as SupportsAgentRun or Executor instances via `.participants()`,
|
||||
or as factories returning SupportsAgentRun or Executor via `.register_participants()`.
|
||||
- A custom aggregator can be provided as:
|
||||
- an Executor instance (it should handle list[AgentExecutorResponse],
|
||||
yield output), or
|
||||
@@ -186,8 +186,8 @@ class _CallbackAggregator(Executor):
|
||||
class ConcurrentBuilder:
|
||||
r"""High-level builder for concurrent agent workflows.
|
||||
|
||||
- `participants([...])` accepts a list of AgentProtocol (recommended) or Executor.
|
||||
- `register_participants([...])` accepts a list of factories for AgentProtocol (recommended)
|
||||
- `participants([...])` accepts a list of SupportsAgentRun (recommended) or Executor.
|
||||
- `register_participants([...])` accepts a list of factories for SupportsAgentRun (recommended)
|
||||
or Executor factories
|
||||
- `build()` wires: dispatcher -> fan-out -> participants -> fan-in -> aggregator.
|
||||
- `with_aggregator(...)` overrides the default aggregator with an Executor or callback.
|
||||
@@ -238,8 +238,8 @@ class ConcurrentBuilder:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._participants: list[AgentProtocol | Executor] = []
|
||||
self._participant_factories: list[Callable[[], AgentProtocol | Executor]] = []
|
||||
self._participants: list[SupportsAgentRun | Executor] = []
|
||||
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
|
||||
self._aggregator: Executor | None = None
|
||||
self._aggregator_factory: Callable[[], Executor] | None = None
|
||||
self._checkpoint_storage: CheckpointStorage | None = None
|
||||
@@ -249,16 +249,16 @@ class ConcurrentBuilder:
|
||||
|
||||
def register_participants(
|
||||
self,
|
||||
participant_factories: Sequence[Callable[[], AgentProtocol | Executor]],
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
|
||||
) -> "ConcurrentBuilder":
|
||||
r"""Define the parallel participants for this concurrent workflow.
|
||||
|
||||
Accepts factories (callables) that return AgentProtocol instances (e.g., created
|
||||
Accepts factories (callables) that return SupportsAgentRun instances (e.g., created
|
||||
by a chat client) or Executor instances. Each participant created by a factory
|
||||
is wired as a parallel branch using fan-out edges from an internal dispatcher.
|
||||
|
||||
Args:
|
||||
participant_factories: Sequence of callables returning AgentProtocol or Executor instances
|
||||
participant_factories: Sequence of callables returning SupportsAgentRun or Executor instances
|
||||
|
||||
Raises:
|
||||
ValueError: if `participant_factories` is empty or `.participants()`
|
||||
@@ -300,20 +300,20 @@ class ConcurrentBuilder:
|
||||
self._participant_factories = list(participant_factories)
|
||||
return self
|
||||
|
||||
def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "ConcurrentBuilder":
|
||||
def participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> "ConcurrentBuilder":
|
||||
r"""Define the parallel participants for this concurrent workflow.
|
||||
|
||||
Accepts AgentProtocol instances (e.g., created by a chat client) or Executor
|
||||
Accepts SupportsAgentRun instances (e.g., created by a chat client) or Executor
|
||||
instances. Each participant is wired as a parallel branch using fan-out edges
|
||||
from an internal dispatcher.
|
||||
|
||||
Args:
|
||||
participants: Sequence of AgentProtocol or Executor instances
|
||||
participants: Sequence of SupportsAgentRun or Executor instances
|
||||
|
||||
Raises:
|
||||
ValueError: if `participants` is empty, contains duplicates, or `.register_participants()`
|
||||
or `.participants()` were already called
|
||||
TypeError: if any entry is not AgentProtocol or Executor
|
||||
TypeError: if any entry is not SupportsAgentRun or Executor
|
||||
|
||||
Example:
|
||||
|
||||
@@ -341,13 +341,13 @@ class ConcurrentBuilder:
|
||||
if p.id in seen_executor_ids:
|
||||
raise ValueError(f"Duplicate executor participant detected: id '{p.id}'")
|
||||
seen_executor_ids.add(p.id)
|
||||
elif isinstance(p, AgentProtocol):
|
||||
elif isinstance(p, SupportsAgentRun):
|
||||
pid = id(p)
|
||||
if pid in seen_agent_ids:
|
||||
raise ValueError("Duplicate agent participant detected (same agent instance provided twice)")
|
||||
seen_agent_ids.add(pid)
|
||||
else:
|
||||
raise TypeError(f"participants must be AgentProtocol or Executor instances; got {type(p).__name__}")
|
||||
raise TypeError(f"participants must be SupportsAgentRun or Executor instances; got {type(p).__name__}")
|
||||
|
||||
self._participants = list(participants)
|
||||
return self
|
||||
@@ -459,7 +459,7 @@ class ConcurrentBuilder:
|
||||
def with_request_info(
|
||||
self,
|
||||
*,
|
||||
agents: Sequence[str | AgentProtocol] | None = None,
|
||||
agents: Sequence[str | SupportsAgentRun] | None = None,
|
||||
) -> "ConcurrentBuilder":
|
||||
"""Enable request info after agent participant responses.
|
||||
|
||||
@@ -508,7 +508,7 @@ class ConcurrentBuilder:
|
||||
raise ValueError("No participants provided. Call .participants() or .register_participants() first.")
|
||||
# We don't need to check if both are set since that is handled in the respective methods
|
||||
|
||||
participants: list[Executor | AgentProtocol] = []
|
||||
participants: list[Executor | SupportsAgentRun] = []
|
||||
if self._participant_factories:
|
||||
# Resolve the participant factories now. This doesn't break the factory pattern
|
||||
# since the Sequential builder still creates new instances per workflow build.
|
||||
@@ -522,7 +522,7 @@ class ConcurrentBuilder:
|
||||
for p in participants:
|
||||
if isinstance(p, Executor):
|
||||
executors.append(p)
|
||||
elif isinstance(p, AgentProtocol):
|
||||
elif isinstance(p, SupportsAgentRun):
|
||||
if self._request_info_enabled and (
|
||||
not self._request_info_filter or resolve_agent_id(p) in self._request_info_filter
|
||||
):
|
||||
@@ -531,7 +531,7 @@ class ConcurrentBuilder:
|
||||
else:
|
||||
executors.append(AgentExecutor(p))
|
||||
else:
|
||||
raise TypeError(f"Participants must be AgentProtocol or Executor instances. Got {type(p).__name__}.")
|
||||
raise TypeError(f"Participants must be SupportsAgentRun or Executor instances. Got {type(p).__name__}.")
|
||||
|
||||
return executors
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from collections.abc import Awaitable, Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, cast, overload
|
||||
|
||||
from agent_framework import AgentProtocol, ChatAgent
|
||||
from agent_framework import ChatAgent, SupportsAgentRun
|
||||
from agent_framework._threads import AgentThread
|
||||
from agent_framework._types import ChatMessage
|
||||
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
|
||||
@@ -523,8 +523,8 @@ class GroupChatBuilder:
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the GroupChatBuilder."""
|
||||
self._participants: dict[str, AgentProtocol | Executor] = {}
|
||||
self._participant_factories: list[Callable[[], AgentProtocol | Executor]] = []
|
||||
self._participants: dict[str, SupportsAgentRun | Executor] = {}
|
||||
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
|
||||
|
||||
# Orchestrator related members
|
||||
self._orchestrator: BaseGroupChatOrchestrator | None = None
|
||||
@@ -683,13 +683,13 @@ class GroupChatBuilder:
|
||||
|
||||
def register_participants(
|
||||
self,
|
||||
participant_factories: Sequence[Callable[[], AgentProtocol | Executor]],
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
|
||||
) -> "GroupChatBuilder":
|
||||
"""Register participant factories for this group chat workflow.
|
||||
|
||||
Args:
|
||||
participant_factories: Sequence of callables that produce participant definitions
|
||||
when invoked. Each callable should return either an AgentProtocol instance
|
||||
when invoked. Each callable should return either an SupportsAgentRun instance
|
||||
(auto-wrapped as AgentExecutor) or an Executor instance.
|
||||
|
||||
Returns:
|
||||
@@ -711,10 +711,10 @@ class GroupChatBuilder:
|
||||
self._participant_factories = list(participant_factories)
|
||||
return self
|
||||
|
||||
def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "GroupChatBuilder":
|
||||
def participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> "GroupChatBuilder":
|
||||
"""Define participants for this group chat workflow.
|
||||
|
||||
Accepts AgentProtocol instances (auto-wrapped as AgentExecutor) or Executor instances.
|
||||
Accepts SupportsAgentRun instances (auto-wrapped as AgentExecutor) or Executor instances.
|
||||
|
||||
Args:
|
||||
participants: Sequence of participant definitions
|
||||
@@ -725,7 +725,7 @@ class GroupChatBuilder:
|
||||
Raises:
|
||||
ValueError: If participants are empty, names are duplicated, or participants
|
||||
or participant factories are already set
|
||||
TypeError: If any participant is not AgentProtocol or Executor instance
|
||||
TypeError: If any participant is not SupportsAgentRun or Executor instance
|
||||
|
||||
Example:
|
||||
|
||||
@@ -750,17 +750,17 @@ class GroupChatBuilder:
|
||||
raise ValueError("participants cannot be empty.")
|
||||
|
||||
# Name of the executor mapped to participant instance
|
||||
named: dict[str, AgentProtocol | Executor] = {}
|
||||
named: dict[str, SupportsAgentRun | Executor] = {}
|
||||
for participant in participants:
|
||||
if isinstance(participant, Executor):
|
||||
identifier = participant.id
|
||||
elif isinstance(participant, AgentProtocol):
|
||||
elif isinstance(participant, SupportsAgentRun):
|
||||
if not participant.name:
|
||||
raise ValueError("AgentProtocol participants must have a non-empty name.")
|
||||
raise ValueError("SupportsAgentRun participants must have a non-empty name.")
|
||||
identifier = participant.name
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}."
|
||||
f"Participants must be SupportsAgentRun or Executor instances. Got {type(participant).__name__}."
|
||||
)
|
||||
|
||||
if identifier in named:
|
||||
@@ -861,7 +861,7 @@ class GroupChatBuilder:
|
||||
self._checkpoint_storage = checkpoint_storage
|
||||
return self
|
||||
|
||||
def with_request_info(self, *, agents: Sequence[str | AgentProtocol] | None = None) -> "GroupChatBuilder":
|
||||
def with_request_info(self, *, agents: Sequence[str | SupportsAgentRun] | None = None) -> "GroupChatBuilder":
|
||||
"""Enable request info after agent participant responses.
|
||||
|
||||
This enables human-in-the-loop (HIL) scenarios for the group chat orchestration.
|
||||
@@ -962,7 +962,7 @@ class GroupChatBuilder:
|
||||
raise ValueError("No participants provided. Call .participants() or .register_participants() first.")
|
||||
# We don't need to check if both are set since that is handled in the respective methods
|
||||
|
||||
participants: list[Executor | AgentProtocol] = []
|
||||
participants: list[Executor | SupportsAgentRun] = []
|
||||
if self._participant_factories:
|
||||
for factory in self._participant_factories:
|
||||
participant = factory()
|
||||
@@ -974,7 +974,7 @@ class GroupChatBuilder:
|
||||
for participant in participants:
|
||||
if isinstance(participant, Executor):
|
||||
executors.append(participant)
|
||||
elif isinstance(participant, AgentProtocol):
|
||||
elif isinstance(participant, SupportsAgentRun):
|
||||
if self._request_info_enabled and (
|
||||
not self._request_info_filter or resolve_agent_id(participant) in self._request_info_filter
|
||||
):
|
||||
@@ -984,7 +984,7 @@ class GroupChatBuilder:
|
||||
executors.append(AgentExecutor(participant))
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}."
|
||||
f"Participants must be SupportsAgentRun or Executor instances. Got {type(participant).__name__}."
|
||||
)
|
||||
|
||||
return executors
|
||||
|
||||
@@ -36,7 +36,7 @@ from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import AgentProtocol, ChatAgent
|
||||
from agent_framework import ChatAgent, SupportsAgentRun
|
||||
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware
|
||||
from agent_framework._threads import AgentThread
|
||||
from agent_framework._tools import FunctionTool, tool
|
||||
@@ -89,14 +89,14 @@ class HandoffConfiguration:
|
||||
target_id: str
|
||||
description: str | None = None
|
||||
|
||||
def __init__(self, *, target: str | AgentProtocol, description: str | None = None) -> None:
|
||||
def __init__(self, *, target: str | SupportsAgentRun, description: str | None = None) -> None:
|
||||
"""Initialize HandoffConfiguration.
|
||||
|
||||
Args:
|
||||
target: Target agent identifier or AgentProtocol instance
|
||||
target: Target agent identifier or SupportsAgentRun instance
|
||||
description: Optional human-readable description of the handoff
|
||||
"""
|
||||
self.target_id = resolve_agent_id(target) if isinstance(target, AgentProtocol) else target
|
||||
self.target_id = resolve_agent_id(target) if isinstance(target, SupportsAgentRun) else target
|
||||
self.description = description
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
@@ -193,7 +193,7 @@ class HandoffAgentExecutor(AgentExecutor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
handoffs: Sequence[HandoffConfiguration],
|
||||
*,
|
||||
agent_thread: AgentThread | None = None,
|
||||
@@ -236,9 +236,9 @@ class HandoffAgentExecutor(AgentExecutor):
|
||||
|
||||
def _prepare_agent_with_handoffs(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
handoffs: Sequence[HandoffConfiguration],
|
||||
) -> AgentProtocol:
|
||||
) -> SupportsAgentRun:
|
||||
"""Prepare an agent by adding handoff tools for the specified target agents.
|
||||
|
||||
Args:
|
||||
@@ -574,8 +574,8 @@ class HandoffBuilder:
|
||||
self,
|
||||
*,
|
||||
name: str | None = None,
|
||||
participants: Sequence[AgentProtocol] | None = None,
|
||||
participant_factories: Mapping[str, Callable[[], AgentProtocol]] | None = None,
|
||||
participants: Sequence[SupportsAgentRun] | None = None,
|
||||
participant_factories: Mapping[str, Callable[[], SupportsAgentRun]] | None = None,
|
||||
description: str | None = None,
|
||||
) -> None:
|
||||
r"""Initialize a HandoffBuilder for creating conversational handoff workflows.
|
||||
@@ -604,8 +604,8 @@ class HandoffBuilder:
|
||||
self._description = description
|
||||
|
||||
# Participant related members
|
||||
self._participants: dict[str, AgentProtocol] = {}
|
||||
self._participant_factories: dict[str, Callable[[], AgentProtocol]] = {}
|
||||
self._participants: dict[str, SupportsAgentRun] = {}
|
||||
self._participant_factories: dict[str, Callable[[], SupportsAgentRun]] = {}
|
||||
self._start_id: str | None = None
|
||||
if participant_factories:
|
||||
self.register_participants(participant_factories)
|
||||
@@ -629,16 +629,16 @@ class HandoffBuilder:
|
||||
self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None
|
||||
|
||||
def register_participants(
|
||||
self, participant_factories: Mapping[str, Callable[[], AgentProtocol]]
|
||||
self, participant_factories: Mapping[str, Callable[[], SupportsAgentRun]]
|
||||
) -> "HandoffBuilder":
|
||||
"""Register factories that produce agents for the handoff workflow.
|
||||
|
||||
Each factory is a callable that returns an AgentProtocol instance.
|
||||
Each factory is a callable that returns an SupportsAgentRun instance.
|
||||
Factories are invoked when building the workflow, allowing for lazy instantiation
|
||||
and state isolation per workflow instance.
|
||||
|
||||
Args:
|
||||
participant_factories: Mapping of factory names to callables that return AgentProtocol
|
||||
participant_factories: Mapping of factory names to callables that return SupportsAgentRun
|
||||
instances. Each produced participant must have a unique identifier
|
||||
(`.name` is preferred if set, otherwise `.id` is used).
|
||||
|
||||
@@ -690,11 +690,11 @@ class HandoffBuilder:
|
||||
self._participant_factories = dict(participant_factories)
|
||||
return self
|
||||
|
||||
def participants(self, participants: Sequence[AgentProtocol]) -> "HandoffBuilder":
|
||||
def participants(self, participants: Sequence[SupportsAgentRun]) -> "HandoffBuilder":
|
||||
"""Register the agents that will participate in the handoff workflow.
|
||||
|
||||
Args:
|
||||
participants: Sequence of AgentProtocol instances. Each must have a unique identifier.
|
||||
participants: Sequence of SupportsAgentRun instances. Each must have a unique identifier.
|
||||
(`.name` is preferred if set, otherwise `.id` is used).
|
||||
|
||||
Returns:
|
||||
@@ -703,7 +703,7 @@ class HandoffBuilder:
|
||||
Raises:
|
||||
ValueError: If participants is empty, contains duplicates, or `.participants()` or
|
||||
`.register_participants()` has already been called.
|
||||
TypeError: If participants are not AgentProtocol instances.
|
||||
TypeError: If participants are not SupportsAgentRun instances.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -729,13 +729,13 @@ class HandoffBuilder:
|
||||
if not participants:
|
||||
raise ValueError("participants cannot be empty")
|
||||
|
||||
named: dict[str, AgentProtocol] = {}
|
||||
named: dict[str, SupportsAgentRun] = {}
|
||||
for participant in participants:
|
||||
if isinstance(participant, AgentProtocol):
|
||||
if isinstance(participant, SupportsAgentRun):
|
||||
resolved_id = self._resolve_to_id(participant)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}."
|
||||
f"Participants must be SupportsAgentRun or Executor instances. Got {type(participant).__name__}."
|
||||
)
|
||||
|
||||
if resolved_id in named:
|
||||
@@ -748,8 +748,8 @@ class HandoffBuilder:
|
||||
|
||||
def add_handoff(
|
||||
self,
|
||||
source: str | AgentProtocol,
|
||||
targets: Sequence[str] | Sequence[AgentProtocol],
|
||||
source: str | SupportsAgentRun,
|
||||
targets: Sequence[str] | Sequence[SupportsAgentRun],
|
||||
*,
|
||||
description: str | None = None,
|
||||
) -> "HandoffBuilder":
|
||||
@@ -763,11 +763,11 @@ class HandoffBuilder:
|
||||
Args:
|
||||
source: The agent that can initiate the handoff. Can be:
|
||||
- Factory name (str): If using participant factories
|
||||
- AgentProtocol instance: The actual agent object
|
||||
- SupportsAgentRun instance: The actual agent object
|
||||
- Cannot mix factory names and instances across source and targets
|
||||
targets: One or more target agents that the source can hand off to. Can be:
|
||||
- Factory name (str): If using participant factories
|
||||
- AgentProtocol instance: The actual agent object
|
||||
- SupportsAgentRun instance: The actual agent object
|
||||
- Single target: ["billing_agent"] or [agent_instance]
|
||||
- Multiple targets: ["billing_agent", "support_agent"] or [agent1, agent2]
|
||||
- Cannot mix factory names and instances across source and targets
|
||||
@@ -786,7 +786,7 @@ class HandoffBuilder:
|
||||
participants(...) hasn't been called yet.
|
||||
2) If source or targets are factory names (str) but participant_factories(...)
|
||||
hasn't been called yet, or if they are not in the participant_factories list.
|
||||
TypeError: If mixing factory names (str) and AgentProtocol/Executor instances
|
||||
TypeError: If mixing factory names (str) and SupportsAgentRun/Executor instances
|
||||
|
||||
Examples:
|
||||
Single target (using factory name):
|
||||
@@ -848,7 +848,7 @@ class HandoffBuilder:
|
||||
self._handoff_config[source].add(HandoffConfiguration(target=t, description=description))
|
||||
return self
|
||||
|
||||
if isinstance(source, (AgentProtocol)) and all(isinstance(t, AgentProtocol) for t in targets):
|
||||
if isinstance(source, (SupportsAgentRun)) and all(isinstance(t, SupportsAgentRun) for t in targets):
|
||||
# Both source and targets are instances
|
||||
if not self._participants:
|
||||
raise ValueError("Call participants(...) before add_handoff(...)")
|
||||
@@ -881,10 +881,10 @@ class HandoffBuilder:
|
||||
return self
|
||||
|
||||
raise TypeError(
|
||||
"Cannot mix factory names (str) and AgentProtocol instances across source and targets in add_handoff()"
|
||||
"Cannot mix factory names (str) and SupportsAgentRun instances across source and targets in add_handoff()"
|
||||
)
|
||||
|
||||
def with_start_agent(self, agent: str | AgentProtocol) -> "HandoffBuilder":
|
||||
def with_start_agent(self, agent: str | SupportsAgentRun) -> "HandoffBuilder":
|
||||
"""Set the agent that will initiate the handoff workflow.
|
||||
|
||||
If not specified, the first registered participant will be used as the starting agent.
|
||||
@@ -892,7 +892,7 @@ class HandoffBuilder:
|
||||
Args:
|
||||
agent: The agent that will start the workflow. Can be:
|
||||
- Factory name (str): If using participant factories
|
||||
- AgentProtocol instance: The actual agent object
|
||||
- SupportsAgentRun instance: The actual agent object
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
@@ -903,7 +903,7 @@ class HandoffBuilder:
|
||||
else:
|
||||
raise ValueError("Call register_participants(...) before with_start_agent(...)")
|
||||
self._start_id = agent
|
||||
elif isinstance(agent, AgentProtocol):
|
||||
elif isinstance(agent, SupportsAgentRun):
|
||||
resolved_id = self._resolve_to_id(agent)
|
||||
if self._participants:
|
||||
if resolved_id not in self._participants:
|
||||
@@ -912,14 +912,14 @@ class HandoffBuilder:
|
||||
raise ValueError("Call participants(...) before with_start_agent(...)")
|
||||
self._start_id = resolved_id
|
||||
else:
|
||||
raise TypeError("Start agent must be a factory name (str) or an AgentProtocol instance")
|
||||
raise TypeError("Start agent must be a factory name (str) or an SupportsAgentRun instance")
|
||||
|
||||
return self
|
||||
|
||||
def with_autonomous_mode(
|
||||
self,
|
||||
*,
|
||||
agents: Sequence[AgentProtocol] | Sequence[str] | None = None,
|
||||
agents: Sequence[SupportsAgentRun] | Sequence[str] | None = None,
|
||||
prompts: dict[str, str] | None = None,
|
||||
turn_limits: dict[str, int] | None = None,
|
||||
) -> "HandoffBuilder":
|
||||
@@ -933,7 +933,7 @@ class HandoffBuilder:
|
||||
Args:
|
||||
agents: Optional list of agents to enable autonomous mode for. Can be:
|
||||
- Factory names (str): If using participant factories
|
||||
- AgentProtocol instances: The actual agent objects
|
||||
- SupportsAgentRun instances: The actual agent objects
|
||||
- If not provided, all agents will operate in autonomous mode.
|
||||
prompts: Optional mapping of agent identifiers/factory names to custom prompts to use when continuing
|
||||
in autonomous mode. If not provided, a default prompt will be used.
|
||||
@@ -1084,7 +1084,7 @@ class HandoffBuilder:
|
||||
|
||||
# region Internal Helper Methods
|
||||
|
||||
def _resolve_agents(self) -> dict[str, AgentProtocol]:
|
||||
def _resolve_agents(self) -> dict[str, SupportsAgentRun]:
|
||||
"""Resolve participant factories into agent instances.
|
||||
|
||||
If agent instances were provided directly via participants(...), those are
|
||||
@@ -1092,7 +1092,7 @@ class HandoffBuilder:
|
||||
those are invoked to create the agent instances.
|
||||
|
||||
Returns:
|
||||
Map of executor IDs or factory names to `AgentProtocol` instances
|
||||
Map of executor IDs or factory names to `SupportsAgentRun` instances
|
||||
"""
|
||||
if not self._participants and not self._participant_factories:
|
||||
raise ValueError("No participants provided. Call .participants() or .register_participants() first.")
|
||||
@@ -1103,13 +1103,13 @@ class HandoffBuilder:
|
||||
|
||||
if self._participant_factories:
|
||||
# Invoke each factory to create participant instances
|
||||
factory_names_to_agents: dict[str, AgentProtocol] = {}
|
||||
factory_names_to_agents: dict[str, SupportsAgentRun] = {}
|
||||
for factory_name, factory in self._participant_factories.items():
|
||||
instance = factory()
|
||||
if isinstance(instance, AgentProtocol):
|
||||
if isinstance(instance, SupportsAgentRun):
|
||||
resolved_id = self._resolve_to_id(instance)
|
||||
else:
|
||||
raise TypeError(f"Participants must be AgentProtocol instances. Got {type(instance).__name__}.")
|
||||
raise TypeError(f"Participants must be SupportsAgentRun instances. Got {type(instance).__name__}.")
|
||||
|
||||
if resolved_id in factory_names_to_agents:
|
||||
raise ValueError(f"Duplicate participant name '{resolved_id}' detected")
|
||||
@@ -1122,11 +1122,11 @@ class HandoffBuilder:
|
||||
|
||||
raise ValueError("No executors or participant_factories have been configured")
|
||||
|
||||
def _resolve_handoffs(self, agents: Mapping[str, AgentProtocol]) -> dict[str, list[HandoffConfiguration]]:
|
||||
def _resolve_handoffs(self, agents: Mapping[str, SupportsAgentRun]) -> dict[str, list[HandoffConfiguration]]:
|
||||
"""Handoffs may be specified using factory names or instances; resolve to executor IDs.
|
||||
|
||||
Args:
|
||||
agents: Map of agent IDs or factory names to `AgentProtocol` instances
|
||||
agents: Map of agent IDs or factory names to `SupportsAgentRun` instances
|
||||
|
||||
Returns:
|
||||
Map of executor IDs to list of HandoffConfiguration instances
|
||||
@@ -1173,13 +1173,13 @@ class HandoffBuilder:
|
||||
|
||||
def _resolve_executors(
|
||||
self,
|
||||
agents: dict[str, AgentProtocol],
|
||||
agents: dict[str, SupportsAgentRun],
|
||||
handoffs: dict[str, list[HandoffConfiguration]],
|
||||
) -> dict[str, HandoffAgentExecutor]:
|
||||
"""Resolve agents into HandoffAgentExecutors.
|
||||
|
||||
Args:
|
||||
agents: Map of agent IDs or factory names to `AgentProtocol` instances
|
||||
agents: Map of agent IDs or factory names to `SupportsAgentRun` instances
|
||||
handoffs: Map of executor IDs to list of HandoffConfiguration instances
|
||||
|
||||
Returns:
|
||||
@@ -1213,9 +1213,9 @@ class HandoffBuilder:
|
||||
|
||||
return executors
|
||||
|
||||
def _resolve_to_id(self, candidate: str | AgentProtocol) -> str:
|
||||
def _resolve_to_id(self, candidate: str | SupportsAgentRun) -> str:
|
||||
"""Resolve a participant reference into a concrete executor identifier."""
|
||||
if isinstance(candidate, AgentProtocol):
|
||||
if isinstance(candidate, SupportsAgentRun):
|
||||
return resolve_agent_id(candidate)
|
||||
if isinstance(candidate, str):
|
||||
return candidate
|
||||
|
||||
@@ -13,9 +13,9 @@ from enum import Enum
|
||||
from typing import Any, ClassVar, TypeVar, cast, overload
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentResponse,
|
||||
ChatMessage,
|
||||
SupportsAgentRun,
|
||||
)
|
||||
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
|
||||
from agent_framework._workflows._checkpoint import CheckpointStorage
|
||||
@@ -521,7 +521,7 @@ class StandardMagenticManager(MagenticManagerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
task_ledger: _MagenticTaskLedger | None = None,
|
||||
*,
|
||||
task_ledger_facts_prompt: str | None = None,
|
||||
@@ -562,7 +562,7 @@ class StandardMagenticManager(MagenticManagerBase):
|
||||
max_round_count=max_round_count,
|
||||
)
|
||||
|
||||
self._agent: AgentProtocol = agent
|
||||
self._agent: SupportsAgentRun = agent
|
||||
self.task_ledger: _MagenticTaskLedger | None = task_ledger
|
||||
|
||||
# Prompts may be overridden if needed
|
||||
@@ -1311,10 +1311,10 @@ class MagenticOrchestrator(BaseGroupChatOrchestrator):
|
||||
class MagenticAgentExecutor(AgentExecutor):
|
||||
"""Specialized AgentExecutor for Magentic agent participants."""
|
||||
|
||||
def __init__(self, agent: AgentProtocol) -> None:
|
||||
def __init__(self, agent: SupportsAgentRun) -> None:
|
||||
"""Initialize a Magentic Agent Executor.
|
||||
|
||||
This executor wraps an AgentProtocol instance to be used as a participant
|
||||
This executor wraps an SupportsAgentRun instance to be used as a participant
|
||||
in a Magentic One workflow.
|
||||
|
||||
Args:
|
||||
@@ -1377,13 +1377,13 @@ class MagenticBuilder:
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the Magentic workflow builder."""
|
||||
self._participants: dict[str, AgentProtocol | Executor] = {}
|
||||
self._participant_factories: list[Callable[[], AgentProtocol | Executor]] = []
|
||||
self._participants: dict[str, SupportsAgentRun | Executor] = {}
|
||||
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
|
||||
|
||||
# Manager related members
|
||||
self._manager: MagenticManagerBase | None = None
|
||||
self._manager_factory: Callable[[], MagenticManagerBase] | None = None
|
||||
self._manager_agent_factory: Callable[[], AgentProtocol] | None = None
|
||||
self._manager_agent_factory: Callable[[], SupportsAgentRun] | None = None
|
||||
self._standard_manager_options: dict[str, Any] = {}
|
||||
self._enable_plan_review: bool = False
|
||||
|
||||
@@ -1394,12 +1394,12 @@ class MagenticBuilder:
|
||||
|
||||
def register_participants(
|
||||
self,
|
||||
participant_factories: Sequence[Callable[[], AgentProtocol | Executor]],
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
|
||||
) -> "MagenticBuilder":
|
||||
"""Register participant factories for this Magentic workflow.
|
||||
|
||||
Args:
|
||||
participant_factories: Sequence of callables that return AgentProtocol or Executor instances.
|
||||
participant_factories: Sequence of callables that return SupportsAgentRun or Executor instances.
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
@@ -1420,10 +1420,10 @@ class MagenticBuilder:
|
||||
self._participant_factories = list(participant_factories)
|
||||
return self
|
||||
|
||||
def participants(self, participants: Sequence[AgentProtocol | Executor]) -> Self:
|
||||
def participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> Self:
|
||||
"""Define participants for this Magentic workflow.
|
||||
|
||||
Accepts AgentProtocol instances (auto-wrapped as AgentExecutor) or Executor instances.
|
||||
Accepts SupportsAgentRun instances (auto-wrapped as AgentExecutor) or Executor instances.
|
||||
|
||||
Args:
|
||||
participants: Sequence of participant definitions
|
||||
@@ -1434,7 +1434,7 @@ class MagenticBuilder:
|
||||
Raises:
|
||||
ValueError: If participants are empty, names are duplicated, or participants
|
||||
or participant factories are already set
|
||||
TypeError: If any participant is not AgentProtocol or Executor instance
|
||||
TypeError: If any participant is not SupportsAgentRun or Executor instance
|
||||
|
||||
Example:
|
||||
|
||||
@@ -1462,17 +1462,17 @@ class MagenticBuilder:
|
||||
raise ValueError("participants cannot be empty.")
|
||||
|
||||
# Name of the executor mapped to participant instance
|
||||
named: dict[str, AgentProtocol | Executor] = {}
|
||||
named: dict[str, SupportsAgentRun | Executor] = {}
|
||||
for participant in participants:
|
||||
if isinstance(participant, Executor):
|
||||
identifier = participant.id
|
||||
elif isinstance(participant, AgentProtocol):
|
||||
elif isinstance(participant, SupportsAgentRun):
|
||||
if not participant.name:
|
||||
raise ValueError("AgentProtocol participants must have a non-empty name.")
|
||||
raise ValueError("SupportsAgentRun participants must have a non-empty name.")
|
||||
identifier = participant.name
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}."
|
||||
f"Participants must be SupportsAgentRun or Executor instances. Got {type(participant).__name__}."
|
||||
)
|
||||
|
||||
if identifier in named:
|
||||
@@ -1608,7 +1608,7 @@ class MagenticBuilder:
|
||||
def with_manager(
|
||||
self,
|
||||
*,
|
||||
agent: AgentProtocol,
|
||||
agent: SupportsAgentRun,
|
||||
task_ledger: _MagenticTaskLedger | None = None,
|
||||
# Prompt overrides
|
||||
task_ledger_facts_prompt: str | None = None,
|
||||
@@ -1628,7 +1628,7 @@ class MagenticBuilder:
|
||||
This will create a StandardMagenticManager using the provided agent.
|
||||
|
||||
Args:
|
||||
agent: AgentProtocol instance for the standard magentic manager
|
||||
agent: SupportsAgentRun instance for the standard magentic manager
|
||||
(`StandardMagenticManager`)
|
||||
task_ledger: Optional custom task ledger implementation for specialized
|
||||
prompting or structured output requirements
|
||||
@@ -1661,7 +1661,7 @@ class MagenticBuilder:
|
||||
def with_manager(
|
||||
self,
|
||||
*,
|
||||
agent_factory: Callable[[], AgentProtocol],
|
||||
agent_factory: Callable[[], SupportsAgentRun],
|
||||
task_ledger: _MagenticTaskLedger | None = None,
|
||||
# Prompt overrides
|
||||
task_ledger_facts_prompt: str | None = None,
|
||||
@@ -1681,7 +1681,7 @@ class MagenticBuilder:
|
||||
This will create a StandardMagenticManager using the provided agent factory.
|
||||
|
||||
Args:
|
||||
agent_factory: Callable that returns a new AgentProtocol instance for the standard
|
||||
agent_factory: Callable that returns a new SupportsAgentRun instance for the standard
|
||||
magentic manager (`StandardMagenticManager`)
|
||||
task_ledger: Optional custom task ledger implementation for specialized
|
||||
prompting or structured output requirements
|
||||
@@ -1715,9 +1715,9 @@ class MagenticBuilder:
|
||||
*,
|
||||
manager: MagenticManagerBase | None = None,
|
||||
manager_factory: Callable[[], MagenticManagerBase] | None = None,
|
||||
agent_factory: Callable[[], AgentProtocol] | None = None,
|
||||
agent_factory: Callable[[], SupportsAgentRun] | None = None,
|
||||
# Constructor args for StandardMagenticManager when manager is not provided
|
||||
agent: AgentProtocol | None = None,
|
||||
agent: SupportsAgentRun | None = None,
|
||||
task_ledger: _MagenticTaskLedger | None = None,
|
||||
# Prompt overrides
|
||||
task_ledger_facts_prompt: str | None = None,
|
||||
@@ -1956,7 +1956,7 @@ class MagenticBuilder:
|
||||
raise ValueError("No participants provided. Call .participants() or .register_participants() first.")
|
||||
# We don't need to check if both are set since that is handled in the respective methods
|
||||
|
||||
participants: list[Executor | AgentProtocol] = []
|
||||
participants: list[Executor | SupportsAgentRun] = []
|
||||
if self._participant_factories:
|
||||
for factory in self._participant_factories:
|
||||
participant = factory()
|
||||
@@ -1968,11 +1968,11 @@ class MagenticBuilder:
|
||||
for participant in participants:
|
||||
if isinstance(participant, Executor):
|
||||
executors.append(participant)
|
||||
elif isinstance(participant, AgentProtocol):
|
||||
elif isinstance(participant, SupportsAgentRun):
|
||||
executors.append(MagenticAgentExecutor(participant))
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}."
|
||||
f"Participants must be SupportsAgentRun or Executor instances. Got {type(participant).__name__}."
|
||||
)
|
||||
|
||||
return executors
|
||||
|
||||
+6
-6
@@ -2,7 +2,7 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agent_framework._agents import AgentProtocol
|
||||
from agent_framework._agents import SupportsAgentRun
|
||||
from agent_framework._types import ChatMessage
|
||||
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
|
||||
from agent_framework._workflows._agent_utils import resolve_agent_id
|
||||
@@ -14,11 +14,11 @@ from agent_framework._workflows._workflow_context import WorkflowContext
|
||||
from agent_framework._workflows._workflow_executor import WorkflowExecutor
|
||||
|
||||
|
||||
def resolve_request_info_filter(agents: list[str | AgentProtocol] | None) -> set[str]:
|
||||
def resolve_request_info_filter(agents: list[str | SupportsAgentRun] | None) -> set[str]:
|
||||
"""Resolve a list of agent/executor references to a set of IDs for filtering.
|
||||
|
||||
Args:
|
||||
agents: List of agent names (str), AgentProtocol instances, or Executor instances.
|
||||
agents: List of agent names (str), SupportsAgentRun instances, or Executor instances.
|
||||
If None, returns None (meaning no filtering - pause for all).
|
||||
|
||||
Returns:
|
||||
@@ -31,7 +31,7 @@ def resolve_request_info_filter(agents: list[str | AgentProtocol] | None) -> set
|
||||
for agent in agents:
|
||||
if isinstance(agent, str):
|
||||
result.add(agent)
|
||||
elif isinstance(agent, AgentProtocol):
|
||||
elif isinstance(agent, SupportsAgentRun):
|
||||
result.add(resolve_agent_id(agent))
|
||||
else:
|
||||
raise TypeError(f"Unsupported type for request_info filter: {type(agent).__name__}")
|
||||
@@ -117,7 +117,7 @@ class AgentApprovalExecutor(WorkflowExecutor):
|
||||
agent's output or send the final response to down stream executors in the orchestration.
|
||||
"""
|
||||
|
||||
def __init__(self, agent: AgentProtocol) -> None:
|
||||
def __init__(self, agent: SupportsAgentRun) -> None:
|
||||
"""Initialize the AgentApprovalExecutor.
|
||||
|
||||
Args:
|
||||
@@ -126,7 +126,7 @@ class AgentApprovalExecutor(WorkflowExecutor):
|
||||
super().__init__(workflow=self._build_workflow(agent), id=resolve_agent_id(agent), propagate_request=True)
|
||||
self._description = agent.description
|
||||
|
||||
def _build_workflow(self, agent: AgentProtocol) -> Workflow:
|
||||
def _build_workflow(self, agent: SupportsAgentRun) -> Workflow:
|
||||
"""Build the internal workflow for the AgentApprovalExecutor."""
|
||||
agent_executor = AgentExecutor(agent)
|
||||
request_info_executor = AgentRequestInfoExecutor(id="agent_request_info_executor")
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
|
||||
This module provides a high-level, agent-focused API to assemble a sequential
|
||||
workflow where:
|
||||
- Participants can be provided as AgentProtocol or Executor instances via `.participants()`,
|
||||
or as factories returning AgentProtocol or Executor via `.register_participants()`
|
||||
- Participants can be provided as SupportsAgentRun or Executor instances via `.participants()`,
|
||||
or as factories returning SupportsAgentRun or Executor via `.register_participants()`
|
||||
- A shared conversation context (list[ChatMessage]) is passed along the chain
|
||||
- Agents append their assistant messages to the context
|
||||
- Custom executors can transform or summarize and return a refined context
|
||||
@@ -15,7 +15,7 @@ Typical wiring:
|
||||
input -> _InputToConversation -> participant1 -> (agent? -> _ResponseToConversation) -> ... -> participantN -> _EndWithConversation
|
||||
|
||||
Notes:
|
||||
- Participants can mix AgentProtocol and Executor objects
|
||||
- Participants can mix SupportsAgentRun and Executor objects
|
||||
- Agents are auto-wrapped by WorkflowBuilder as AgentExecutor (unless already wrapped)
|
||||
- AgentExecutor produces AgentExecutorResponse; _ResponseToConversation converts this to list[ChatMessage]
|
||||
- Non-agent executors must define a handler that consumes `list[ChatMessage]` and sends back
|
||||
@@ -41,7 +41,7 @@ import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import AgentProtocol, ChatMessage
|
||||
from agent_framework import ChatMessage, SupportsAgentRun
|
||||
from agent_framework._workflows._agent_executor import (
|
||||
AgentExecutor,
|
||||
AgentExecutorResponse,
|
||||
@@ -109,8 +109,8 @@ class _EndWithConversation(Executor):
|
||||
class SequentialBuilder:
|
||||
r"""High-level builder for sequential agent/executor workflows with shared context.
|
||||
|
||||
- `participants([...])` accepts a list of AgentProtocol (recommended) or Executor instances
|
||||
- `register_participants([...])` accepts a list of factories for AgentProtocol (recommended)
|
||||
- `participants([...])` accepts a list of SupportsAgentRun (recommended) or Executor instances
|
||||
- `register_participants([...])` accepts a list of factories for SupportsAgentRun (recommended)
|
||||
or Executor factories
|
||||
- Executors must define a handler that consumes list[ChatMessage] and sends out a list[ChatMessage]
|
||||
- The workflow wires participants in order, passing a list[ChatMessage] down the chain
|
||||
@@ -148,8 +148,8 @@ class SequentialBuilder:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._participants: list[AgentProtocol | Executor] = []
|
||||
self._participant_factories: list[Callable[[], AgentProtocol | Executor]] = []
|
||||
self._participants: list[SupportsAgentRun | Executor] = []
|
||||
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
|
||||
self._checkpoint_storage: CheckpointStorage | None = None
|
||||
self._request_info_enabled: bool = False
|
||||
self._request_info_filter: set[str] | None = None
|
||||
@@ -157,7 +157,7 @@ class SequentialBuilder:
|
||||
|
||||
def register_participants(
|
||||
self,
|
||||
participant_factories: Sequence[Callable[[], AgentProtocol | Executor]],
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
|
||||
) -> "SequentialBuilder":
|
||||
"""Register participant factories for this sequential workflow."""
|
||||
if self._participants:
|
||||
@@ -172,10 +172,10 @@ class SequentialBuilder:
|
||||
self._participant_factories = list(participant_factories)
|
||||
return self
|
||||
|
||||
def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "SequentialBuilder":
|
||||
def participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> "SequentialBuilder":
|
||||
"""Define the ordered participants for this sequential workflow.
|
||||
|
||||
Accepts AgentProtocol instances (auto-wrapped as AgentExecutor) or Executor instances.
|
||||
Accepts SupportsAgentRun instances (auto-wrapped as AgentExecutor) or Executor instances.
|
||||
Raises if empty or duplicates are provided for clarity.
|
||||
"""
|
||||
if self._participant_factories:
|
||||
@@ -196,7 +196,7 @@ class SequentialBuilder:
|
||||
raise ValueError(f"Duplicate executor participant detected: id '{p.id}'")
|
||||
seen_executor_ids.add(p.id)
|
||||
else:
|
||||
# Treat non-Executor as agent-like (AgentProtocol). Structural checks can be brittle at runtime.
|
||||
# Treat non-Executor as agent-like (SupportsAgentRun). Structural checks can be brittle at runtime.
|
||||
pid = id(p)
|
||||
if pid in seen_agent_ids:
|
||||
raise ValueError("Duplicate agent participant detected (same agent instance provided twice)")
|
||||
@@ -213,7 +213,7 @@ class SequentialBuilder:
|
||||
def with_request_info(
|
||||
self,
|
||||
*,
|
||||
agents: Sequence[str | AgentProtocol] | None = None,
|
||||
agents: Sequence[str | SupportsAgentRun] | None = None,
|
||||
) -> "SequentialBuilder":
|
||||
"""Enable request info after agent participant responses.
|
||||
|
||||
@@ -262,7 +262,7 @@ class SequentialBuilder:
|
||||
raise ValueError("No participants provided. Call .participants() or .register_participants() first.")
|
||||
# We don't need to check if both are set since that is handled in the respective methods
|
||||
|
||||
participants: list[Executor | AgentProtocol] = []
|
||||
participants: list[Executor | SupportsAgentRun] = []
|
||||
if self._participant_factories:
|
||||
# Resolve the participant factories now. This doesn't break the factory pattern
|
||||
# since the Sequential builder still creates new instances per workflow build.
|
||||
@@ -276,7 +276,7 @@ class SequentialBuilder:
|
||||
for p in participants:
|
||||
if isinstance(p, Executor):
|
||||
executors.append(p)
|
||||
elif isinstance(p, AgentProtocol):
|
||||
elif isinstance(p, SupportsAgentRun):
|
||||
if self._request_info_enabled and (
|
||||
not self._request_info_filter or resolve_agent_id(p) in self._request_info_filter
|
||||
):
|
||||
@@ -285,7 +285,7 @@ class SequentialBuilder:
|
||||
else:
|
||||
executors.append(AgentExecutor(p))
|
||||
else:
|
||||
raise TypeError(f"Participants must be AgentProtocol or Executor instances. Got {type(p).__name__}.")
|
||||
raise TypeError(f"Participants must be SupportsAgentRun or Executor instances. Got {type(p).__name__}.")
|
||||
|
||||
return executors
|
||||
|
||||
@@ -312,7 +312,7 @@ class SequentialBuilder:
|
||||
builder.set_start_executor(input_conv)
|
||||
|
||||
# Start of the chain is the input normalizer
|
||||
prior: Executor | AgentProtocol = input_conv
|
||||
prior: Executor | SupportsAgentRun = input_conv
|
||||
for p in participants:
|
||||
builder.add_edge(prior, p)
|
||||
prior = p
|
||||
|
||||
@@ -320,7 +320,7 @@ class TestGroupChatBuilder:
|
||||
|
||||
builder = GroupChatBuilder().with_orchestrator(selection_func=selector)
|
||||
|
||||
with pytest.raises(ValueError, match="AgentProtocol participants must have a non-empty name"):
|
||||
with pytest.raises(ValueError, match="SupportsAgentRun participants must have a non-empty name"):
|
||||
builder.participants([agent])
|
||||
|
||||
def test_empty_participant_name_raises_error(self) -> None:
|
||||
@@ -332,7 +332,7 @@ class TestGroupChatBuilder:
|
||||
|
||||
builder = GroupChatBuilder().with_orchestrator(selection_func=selector)
|
||||
|
||||
with pytest.raises(ValueError, match="AgentProtocol participants must have a non-empty name"):
|
||||
with pytest.raises(ValueError, match="SupportsAgentRun participants must have a non-empty name"):
|
||||
builder.participants([agent])
|
||||
|
||||
|
||||
|
||||
@@ -420,7 +420,7 @@ def test_handoff_builder_rejects_mixed_types_in_add_handoff_source():
|
||||
triage = MockHandoffAgent(name="triage")
|
||||
specialist = MockHandoffAgent(name="specialist")
|
||||
|
||||
with pytest.raises(TypeError, match="Cannot mix factory names \\(str\\) and AgentProtocol.*instances"):
|
||||
with pytest.raises(TypeError, match="Cannot mix factory names \\(str\\) and SupportsAgentRun.*instances"):
|
||||
(
|
||||
HandoffBuilder(participants=[triage, specialist])
|
||||
.with_start_agent(triage)
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Any, ClassVar, cast
|
||||
|
||||
import pytest
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentThread,
|
||||
@@ -15,6 +14,7 @@ from agent_framework import (
|
||||
ChatMessage,
|
||||
Content,
|
||||
Executor,
|
||||
SupportsAgentRun,
|
||||
Workflow,
|
||||
WorkflowCheckpoint,
|
||||
WorkflowCheckpointException,
|
||||
@@ -576,7 +576,7 @@ class StubAssistantsAgent(BaseAgent):
|
||||
)
|
||||
|
||||
|
||||
async def _collect_agent_responses_setup(participant: AgentProtocol) -> list[ChatMessage]:
|
||||
async def _collect_agent_responses_setup(participant: SupportsAgentRun) -> list[ChatMessage]:
|
||||
captured: list[ChatMessage] = []
|
||||
|
||||
wf = (
|
||||
@@ -1121,10 +1121,10 @@ async def test_magentic_with_agent_factory():
|
||||
"""Test workflow creation using agent_factory for StandardMagenticManager."""
|
||||
factory_call_count = 0
|
||||
|
||||
def agent_factory() -> AgentProtocol:
|
||||
def agent_factory() -> SupportsAgentRun:
|
||||
nonlocal factory_call_count
|
||||
factory_call_count += 1
|
||||
return cast(AgentProtocol, StubManagerAgent())
|
||||
return cast(SupportsAgentRun, StubManagerAgent())
|
||||
|
||||
participant = StubAgent("agentA", "reply from agentA")
|
||||
workflow = (
|
||||
@@ -1239,10 +1239,10 @@ def test_magentic_agent_factory_with_standard_manager_options():
|
||||
"""Test that agent_factory properly passes through standard manager options."""
|
||||
factory_call_count = 0
|
||||
|
||||
def agent_factory() -> AgentProtocol:
|
||||
def agent_factory() -> SupportsAgentRun:
|
||||
nonlocal factory_call_count
|
||||
factory_call_count += 1
|
||||
return cast(AgentProtocol, StubManagerAgent())
|
||||
return cast(SupportsAgentRun, StubManagerAgent())
|
||||
|
||||
# Custom options to verify they are passed through
|
||||
custom_max_stall_count = 5
|
||||
|
||||
@@ -8,11 +8,11 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentThread,
|
||||
ChatMessage,
|
||||
SupportsAgentRun,
|
||||
)
|
||||
from agent_framework._workflows._agent_executor import AgentExecutorRequest, AgentExecutorResponse
|
||||
from agent_framework._workflows._workflow_context import WorkflowContext
|
||||
@@ -44,10 +44,10 @@ class TestResolveRequestInfoFilter:
|
||||
assert result == {"agent1", "agent2"}
|
||||
|
||||
def test_resolves_agent_display_names(self):
|
||||
"""Test resolving AgentProtocol instances by name attribute."""
|
||||
agent1 = MagicMock(spec=AgentProtocol)
|
||||
"""Test resolving SupportsAgentRun instances by name attribute."""
|
||||
agent1 = MagicMock(spec=SupportsAgentRun)
|
||||
agent1.name = "writer"
|
||||
agent2 = MagicMock(spec=AgentProtocol)
|
||||
agent2 = MagicMock(spec=SupportsAgentRun)
|
||||
agent2.name = "reviewer"
|
||||
|
||||
result = resolve_request_info_filter([agent1, agent2])
|
||||
@@ -55,7 +55,7 @@ class TestResolveRequestInfoFilter:
|
||||
|
||||
def test_mixed_types(self):
|
||||
"""Test resolving a mix of strings and agents."""
|
||||
agent = MagicMock(spec=AgentProtocol)
|
||||
agent = MagicMock(spec=SupportsAgentRun)
|
||||
agent.name = "writer"
|
||||
|
||||
result = resolve_request_info_filter(["manual_name", agent])
|
||||
|
||||
@@ -131,7 +131,7 @@ sequenceDiagram
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `agent` | `AgentProtocol` | The agent being invoked |
|
||||
| `agent` | `SupportsAgentRun` | The agent being invoked |
|
||||
| `messages` | `list[ChatMessage]` | Input messages (mutable) |
|
||||
| `thread` | `AgentThread \| None` | Conversation thread |
|
||||
| `options` | `Mapping[str, Any]` | Chat options dict |
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import AgentProtocol, AgentResponse, AgentThread, ChatMessage, HostedMCPTool
|
||||
from agent_framework import SupportsAgentRun, AgentResponse, AgentThread, ChatMessage, HostedMCPTool
|
||||
from agent_framework.azure import AzureAIProjectAgentProvider
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
|
||||
@@ -14,7 +14,7 @@ This sample demonstrates integrating hosted Model Context Protocol (MCP) tools w
|
||||
"""
|
||||
|
||||
|
||||
async def handle_approvals_without_thread(query: str, agent: "AgentProtocol") -> AgentResponse:
|
||||
async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun") -> AgentResponse:
|
||||
"""When we don't have a thread, we need to ensure we return with the input, approval request and approval."""
|
||||
|
||||
result = await agent.run(query, store=False)
|
||||
@@ -35,7 +35,7 @@ async def handle_approvals_without_thread(query: str, agent: "AgentProtocol") ->
|
||||
return result
|
||||
|
||||
|
||||
async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread") -> AgentResponse:
|
||||
async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread") -> AgentResponse:
|
||||
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
|
||||
|
||||
result = await agent.run(query, thread=thread)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import AgentProtocol, AgentResponse, AgentThread, HostedMCPTool
|
||||
from agent_framework import SupportsAgentRun, AgentResponse, AgentThread, HostedMCPTool
|
||||
from agent_framework.azure import AzureAIAgentsProvider
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
|
||||
@@ -15,7 +15,7 @@ servers, including user approval workflows for function call security.
|
||||
"""
|
||||
|
||||
|
||||
async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread") -> AgentResponse:
|
||||
async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread") -> AgentResponse:
|
||||
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
|
||||
+2
-2
@@ -5,7 +5,7 @@ from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
SupportsAgentRun,
|
||||
AgentThread,
|
||||
HostedMCPTool,
|
||||
HostedWebSearchTool,
|
||||
@@ -43,7 +43,7 @@ def get_time() -> str:
|
||||
return f"The current UTC time is {current_time.strftime('%Y-%m-%d %H:%M:%S')}."
|
||||
|
||||
|
||||
async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread"):
|
||||
async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread"):
|
||||
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
|
||||
+4
-4
@@ -15,10 +15,10 @@ Azure OpenAI Responses Client, including user approval workflows for function ca
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import AgentProtocol, AgentThread
|
||||
from agent_framework import SupportsAgentRun, AgentThread
|
||||
|
||||
|
||||
async def handle_approvals_without_thread(query: str, agent: "AgentProtocol"):
|
||||
async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun"):
|
||||
"""When we don't have a thread, we need to ensure we return with the input, approval request and approval."""
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
@@ -40,7 +40,7 @@ async def handle_approvals_without_thread(query: str, agent: "AgentProtocol"):
|
||||
return result
|
||||
|
||||
|
||||
async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread"):
|
||||
async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread"):
|
||||
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
@@ -63,7 +63,7 @@ async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", threa
|
||||
return result
|
||||
|
||||
|
||||
async def handle_approvals_with_thread_streaming(query: str, agent: "AgentProtocol", thread: "AgentThread"):
|
||||
async def handle_approvals_with_thread_streaming(query: str, agent: "SupportsAgentRun", thread: "AgentThread"):
|
||||
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
|
||||
+4
-4
@@ -14,10 +14,10 @@ OpenAI Responses Client, including user approval workflows for function call sec
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import AgentProtocol, AgentThread
|
||||
from agent_framework import SupportsAgentRun, AgentThread
|
||||
|
||||
|
||||
async def handle_approvals_without_thread(query: str, agent: "AgentProtocol"):
|
||||
async def handle_approvals_without_thread(query: str, agent: "SupportsAgentRun"):
|
||||
"""When we don't have a thread, we need to ensure we return with the input, approval request and approval."""
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
@@ -39,7 +39,7 @@ async def handle_approvals_without_thread(query: str, agent: "AgentProtocol"):
|
||||
return result
|
||||
|
||||
|
||||
async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread"):
|
||||
async def handle_approvals_with_thread(query: str, agent: "SupportsAgentRun", thread: "AgentThread"):
|
||||
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
@@ -62,7 +62,7 @@ async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", threa
|
||||
return result
|
||||
|
||||
|
||||
async def handle_approvals_with_thread_streaming(query: str, agent: "AgentProtocol", thread: "AgentThread"):
|
||||
async def handle_approvals_with_thread_streaming(query: str, agent: "SupportsAgentRun", thread: "AgentThread"):
|
||||
"""Here we let the thread deal with the previous responses, and we just rerun with the approval."""
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ async def main() -> None:
|
||||
)
|
||||
|
||||
# 2) Build a concurrent workflow
|
||||
# Participants are either Agents (type of AgentProtocol) or Executors
|
||||
# Participants are either Agents (type of SupportsAgentRun) or Executors
|
||||
workflow = ConcurrentBuilder().participants([researcher, marketer, legal]).build()
|
||||
|
||||
# 3) Run with a single prompt and pretty-print the final combined messages
|
||||
|
||||
@@ -80,7 +80,7 @@ async def main() -> None:
|
||||
return response.messages[-1].text if response.messages else ""
|
||||
|
||||
# Build with a custom aggregator callback function
|
||||
# - participants([...]) accepts AgentProtocol (agents) or Executor instances.
|
||||
# - participants([...]) accepts SupportsAgentRun (agents) or Executor instances.
|
||||
# Each participant becomes a parallel branch (fan-out) from an internal dispatcher.
|
||||
# - with_aggregator(...) overrides the default aggregator:
|
||||
# • Default aggregator -> returns list[ChatMessage] (one user + one assistant per agent)
|
||||
|
||||
@@ -122,7 +122,7 @@ async def run_workflow(workflow: Workflow, query: str) -> None:
|
||||
async def main() -> None:
|
||||
# Create a concurrent builder with participant factories and a custom aggregator
|
||||
# - register_participants([...]) accepts factory functions that return
|
||||
# AgentProtocol (agents) or Executor instances.
|
||||
# SupportsAgentRun (agents) or Executor instances.
|
||||
# - register_aggregator(...) takes a factory function that returns an Executor instance.
|
||||
concurrent_builder = (
|
||||
ConcurrentBuilder()
|
||||
|
||||
@@ -51,8 +51,6 @@ async def main() -> None:
|
||||
"You are a Researcher. You find information without additional computation or quantitative analysis."
|
||||
),
|
||||
# This agent requires the gpt-4o-search-preview model to perform web searches.
|
||||
# Feel free to explore with other agents that support web search, for example,
|
||||
# the `OpenAIResponseAgent` or `AzureAgentProtocol` with bing grounding.
|
||||
chat_client=OpenAIChatClient(model_id="gpt-4o-search-preview"),
|
||||
)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from agent_framework import AgentResponse, ChatAgent, ChatMessage, tool
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import AgentProtocol
|
||||
from agent_framework import SupportsAgentRun
|
||||
|
||||
"""
|
||||
Demonstration of a tool with approvals.
|
||||
@@ -40,7 +40,7 @@ def get_weather_detail(location: Annotated[str, "The city and state, e.g. San Fr
|
||||
)
|
||||
|
||||
|
||||
async def handle_approvals(query: str, agent: "AgentProtocol") -> AgentResponse:
|
||||
async def handle_approvals(query: str, agent: "SupportsAgentRun") -> AgentResponse:
|
||||
"""Handle function call approvals.
|
||||
|
||||
When we don't have a thread, we need to ensure we include the original query,
|
||||
@@ -75,7 +75,7 @@ async def handle_approvals(query: str, agent: "AgentProtocol") -> AgentResponse:
|
||||
return result
|
||||
|
||||
|
||||
async def handle_approvals_streaming(query: str, agent: "AgentProtocol") -> None:
|
||||
async def handle_approvals_streaming(query: str, agent: "SupportsAgentRun") -> None:
|
||||
"""Handle function call approvals with streaming responses.
|
||||
|
||||
When we don't have a thread, we need to ensure we include the original query,
|
||||
|
||||
@@ -29,8 +29,6 @@ async def main() -> None:
|
||||
"You are a Researcher. You find information without additional computation or quantitative analysis."
|
||||
),
|
||||
# This agent requires the gpt-4o-search-preview model to perform web searches.
|
||||
# Feel free to explore with other agents that support web search, for example,
|
||||
# the `OpenAIResponseAgent` or `AzureAgentProtocol` with bing grounding.
|
||||
chat_client=OpenAIChatClient(model_id="gpt-4o-search-preview"),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user