From 35c938fb5bab2527d146206b390b4f9d9b84980d Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Mon, 7 Jul 2025 10:47:14 -0700 Subject: [PATCH] .Python: Added Agent and AgentThread abstractions (#130) * Added Agent and AgentThread classes * Addressed PR feedback * Converted Agent to protocol * Removed thread deletion logic * Small update * Small updates to the Agent protocol --- python/agent_framework/__init__.py | 2 + python/agent_framework/__init__.pyi | 3 + python/agent_framework/_agents.py | 140 +++++++++++++++++++++++++++ python/agent_framework/exceptions.py | 12 +++ python/tests/unit/test_agents.py | 105 ++++++++++++++++++++ 5 files changed, 262 insertions(+) create mode 100644 python/tests/unit/test_agents.py diff --git a/python/agent_framework/__init__.py b/python/agent_framework/__init__.py index c933ae97e8..269797b340 100644 --- a/python/agent_framework/__init__.py +++ b/python/agent_framework/__init__.py @@ -10,6 +10,8 @@ except importlib.metadata.PackageNotFoundError: _IMPORTS = { "get_logger": "._logging", + "Agent": "._agents", + "AgentThread": "._agents", "AITool": "._tools", "ai_function": "._tools", "AIContent": "._types", diff --git a/python/agent_framework/__init__.pyi b/python/agent_framework/__init__.pyi index b4d78c6828..89c6301740 100644 --- a/python/agent_framework/__init__.pyi +++ b/python/agent_framework/__init__.pyi @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from . import __version__ # type: ignore[attr-defined] +from ._agents import Agent, AgentThread from ._clients import ChatClient, EmbeddingGenerator from ._logging import get_logger from ._tools import AITool, ai_function @@ -32,6 +33,8 @@ __all__ = [ "AIContent", "AIContents", "AITool", + "Agent", + "AgentThread", "ChatClient", "ChatFinishReason", "ChatMessage", diff --git a/python/agent_framework/_agents.py b/python/agent_framework/_agents.py index e69de29bb2..25f722b447 100644 --- a/python/agent_framework/_agents.py +++ b/python/agent_framework/_agents.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft. All rights reserved. + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterable, Awaitable, Callable +from typing import Any, Protocol, TypeVar, runtime_checkable + +from ._types import ChatMessage, ChatResponse, ChatResponseUpdate + +TThreadType = TypeVar("TThreadType", bound="AgentThread") + + +# region AgentThread + + +class AgentThread(ABC): + """Base class for agent threads.""" + + def __init__(self) -> None: + """Initialize the agent thread.""" + self._id: str | None = None # type: ignore + + @property + def id(self) -> str | None: + """Returns the ID of the current thread (if any).""" + return self._id + + async def create(self) -> str | None: + """Starts the thread and returns the thread ID.""" + # If the thread ID is already set, we're done, just return the Id. + if self.id is not None: + return self.id + + # Otherwise, create the thread. + self._id = await self._create() + return self.id + + async def delete(self) -> None: + """Ends the current thread.""" + await self._delete() + self._id = None + + async def on_new_message( + self, + new_message: ChatMessage, + ) -> None: + """Invoked when a new message has been contributed to the chat by any participant.""" + # If the thread is not created yet, create it. + if self.id is None: + await self.create() + + await self._on_new_message(new_message) + + @abstractmethod + async def _create(self) -> str: + """Starts the thread and returns the thread ID.""" + raise NotImplementedError + + @abstractmethod + async def _delete(self) -> None: + """Ends the current thread.""" + raise NotImplementedError + + @abstractmethod + async def _on_new_message( + self, + new_message: ChatMessage, + ) -> None: + """Invoked when a new message has been contributed to the chat by any participant.""" + raise NotImplementedError + + +# region Agent Protocol + + +@runtime_checkable +class Agent(Protocol): + """A protocol for an agent that can be invoked.""" + + async def run( + self, + messages: str | ChatMessage | list[str | ChatMessage] | None = None, + *, + arguments: dict[str, Any] | None = None, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> ChatResponse: + """Get a response from the agent. + + This method returns the final result of the agent's execution + as a single ChatResponse object. The caller is blocked until + the final result is available. + + Note: For streaming responses, use the run_stream method, which returns + intermediate steps and the final result as a stream of ChatResponseUpdate + objects. Streaming only the final result is not feasible because the timing of + the final result's availability is unknown, and blocking the caller until then + is undesirable in streaming scenarios. + + Args: + messages: The message(s) to send to the agent. + arguments: Additional arguments to pass to the agent. + thread: The conversation thread associated with the message(s). + kwargs: Additional keyword arguments. + + Returns: + An agent response item. + """ + ... + + async def run_stream( + self, + messages: str | ChatMessage | list[str | ChatMessage] | None = None, + *, + arguments: dict[str, Any] | None = None, + thread: AgentThread | None = None, + on_intermediate_message: Callable[[ChatMessage], Awaitable[None]] | None = None, + **kwargs: Any, + ) -> AsyncIterable[ChatResponseUpdate]: + """Run the agent as a stream. + + This method will return the intermediate steps and final results of the + agent's execution as a stream of ChatResponseUpdate objects to the caller. + + To get the intermediate steps of the agent's execution as fully formed messages, + use the on_intermediate_message callback. + + Note: A ChatResponseUpdate object contains a chunk of a message. + + Args: + messages: The message(s) to send to the agent. + arguments: Additional arguments to pass to the agent. + thread: The conversation thread associated with the message(s). + on_intermediate_message: A callback function to handle intermediate steps of the + agent's execution as fully formed messages. + kwargs: Additional keyword arguments. + + Yields: + An agent response item. + """ + ... diff --git a/python/agent_framework/exceptions.py b/python/agent_framework/exceptions.py index 599d58914e..6c1a16801c 100644 --- a/python/agent_framework/exceptions.py +++ b/python/agent_framework/exceptions.py @@ -5,3 +5,15 @@ class AgentFrameworkException(Exception): """Base class for exceptions in the Agent Framework.""" pass + + +class AgentException(AgentFrameworkException): + """Base class for all agent exceptions.""" + + pass + + +class AgentExecutionException(AgentException): + """An error occurred while executing the agent.""" + + pass diff --git a/python/tests/unit/test_agents.py b/python/tests/unit/test_agents.py new file mode 100644 index 0000000000..5b6c4ea089 --- /dev/null +++ b/python/tests/unit/test_agents.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft. All rights reserved. + +import uuid +from collections.abc import AsyncIterable, Awaitable, Callable +from typing import Any, TypeVar, cast + +from pytest import fixture + +from agent_framework import Agent, AgentThread, ChatMessage, ChatResponse, ChatResponseUpdate, ChatRole, TextContent + +TThreadType = TypeVar("TThreadType", bound=AgentThread) + + +# Mock AgentThread implementation for testing +class MockAgentThread(AgentThread): + async def _create(self) -> str: + return str(uuid.uuid4()) + + async def _delete(self) -> None: + pass + + async def _on_new_message(self, new_message: ChatMessage) -> None: + pass + + +# Mock Agent implementation for testing +class MockAgent: + async def run( + self, + messages: str | ChatMessage | list[str | ChatMessage] | None = None, + *, + arguments: dict[str, Any] | None = None, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> ChatResponse: + return ChatResponse(messages=[ChatMessage(role=ChatRole.ASSISTANT, contents=[TextContent("Response")])]) + + async def run_stream( + self, + messages: str | ChatMessage | list[str | ChatMessage] | None = None, + *, + arguments: dict[str, Any] | None = None, + thread: AgentThread | None = None, + on_intermediate_message: Callable[[ChatMessage], Awaitable[None]] | None = None, + **kwargs: Any, + ) -> AsyncIterable[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent("Response")]) + + +@fixture +def agent_thread() -> AgentThread: + return MockAgentThread() + + +@fixture +def agent() -> MockAgent: + return MockAgent() + + +async def test_agent_thread_id_property(agent_thread: MockAgentThread) -> None: + assert agent_thread.id is None + await agent_thread.create() + assert isinstance(agent_thread.id, str) + + +async def test_agent_thread_create(agent_thread: MockAgentThread) -> None: + thread_id = await agent_thread.create() + assert thread_id == agent_thread.id + assert isinstance(thread_id, str) + + +async def test_agent_thread_create_already_exists(agent_thread: MockAgentThread) -> None: + thread_id = await agent_thread.create() + same_id = await agent_thread.create() + assert thread_id == same_id + + +async def test_agent_thread_delete_already_deleted(agent_thread: MockAgentThread) -> None: + await agent_thread.delete() + await agent_thread.delete() # Should not raise error + + +async def test_agent_thread_on_new_message_creates_thread(agent_thread: MockAgentThread) -> None: + message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")]) + await agent_thread.on_new_message(message) + assert agent_thread.id is not None + + +def test_agent_type(agent: MockAgent) -> None: + assert isinstance(agent, Agent) + + +async def test_agent_run(agent: MockAgent) -> None: + response = await agent.run("test") + assert response.messages[0].role == ChatRole.ASSISTANT + assert cast(TextContent, response.messages[0].contents[0]).text == "Response" + + +async def tesT_agent_run_stream(agent: MockAgent) -> None: + async def collect_updates(updates: AsyncIterable[ChatResponseUpdate]) -> list[ChatResponseUpdate]: + return [u async for u in updates] + + updates = await collect_updates(agent.run_stream(messages="test")) + assert len(updates) == 1 + assert cast(TextContent, updates[0].contents[0]).text == "Response"