.Python: Added Agent and AgentThread abstractions (#130)

* Added Agent and AgentThread classes

* Addressed PR feedback

* Converted Agent to protocol

* Removed thread deletion logic

* Small update

* Small updates to the Agent protocol
This commit is contained in:
Dmytro Struk
2025-07-07 10:47:14 -07:00
committed by GitHub
Unverified
parent 09309c1239
commit 35c938fb5b
5 changed files with 262 additions and 0 deletions
+2
View File
@@ -10,6 +10,8 @@ except importlib.metadata.PackageNotFoundError:
_IMPORTS = {
"get_logger": "._logging",
"Agent": "._agents",
"AgentThread": "._agents",
"AITool": "._tools",
"ai_function": "._tools",
"AIContent": "._types",
+3
View File
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
from . import __version__ # type: ignore[attr-defined]
from ._agents import Agent, AgentThread
from ._clients import ChatClient, EmbeddingGenerator
from ._logging import get_logger
from ._tools import AITool, ai_function
@@ -32,6 +33,8 @@ __all__ = [
"AIContent",
"AIContents",
"AITool",
"Agent",
"AgentThread",
"ChatClient",
"ChatFinishReason",
"ChatMessage",
+140
View File
@@ -0,0 +1,140 @@
# Copyright (c) Microsoft. All rights reserved.
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Awaitable, Callable
from typing import Any, Protocol, TypeVar, runtime_checkable
from ._types import ChatMessage, ChatResponse, ChatResponseUpdate
TThreadType = TypeVar("TThreadType", bound="AgentThread")
# region AgentThread
class AgentThread(ABC):
"""Base class for agent threads."""
def __init__(self) -> None:
"""Initialize the agent thread."""
self._id: str | None = None # type: ignore
@property
def id(self) -> str | None:
"""Returns the ID of the current thread (if any)."""
return self._id
async def create(self) -> str | None:
"""Starts the thread and returns the thread ID."""
# If the thread ID is already set, we're done, just return the Id.
if self.id is not None:
return self.id
# Otherwise, create the thread.
self._id = await self._create()
return self.id
async def delete(self) -> None:
"""Ends the current thread."""
await self._delete()
self._id = None
async def on_new_message(
self,
new_message: ChatMessage,
) -> None:
"""Invoked when a new message has been contributed to the chat by any participant."""
# If the thread is not created yet, create it.
if self.id is None:
await self.create()
await self._on_new_message(new_message)
@abstractmethod
async def _create(self) -> str:
"""Starts the thread and returns the thread ID."""
raise NotImplementedError
@abstractmethod
async def _delete(self) -> None:
"""Ends the current thread."""
raise NotImplementedError
@abstractmethod
async def _on_new_message(
self,
new_message: ChatMessage,
) -> None:
"""Invoked when a new message has been contributed to the chat by any participant."""
raise NotImplementedError
# region Agent Protocol
@runtime_checkable
class Agent(Protocol):
"""A protocol for an agent that can be invoked."""
async def run(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
*,
arguments: dict[str, Any] | None = None,
thread: AgentThread | None = None,
**kwargs: Any,
) -> ChatResponse:
"""Get a response from the agent.
This method returns the final result of the agent's execution
as a single ChatResponse object. The caller is blocked until
the final result is available.
Note: For streaming responses, use the run_stream method, which returns
intermediate steps and the final result as a stream of ChatResponseUpdate
objects. Streaming only the final result is not feasible because the timing of
the final result's availability is unknown, and blocking the caller until then
is undesirable in streaming scenarios.
Args:
messages: The message(s) to send to the agent.
arguments: Additional arguments to pass to the agent.
thread: The conversation thread associated with the message(s).
kwargs: Additional keyword arguments.
Returns:
An agent response item.
"""
...
async def run_stream(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
*,
arguments: dict[str, Any] | None = None,
thread: AgentThread | None = None,
on_intermediate_message: Callable[[ChatMessage], Awaitable[None]] | None = None,
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
"""Run the agent as a stream.
This method will return the intermediate steps and final results of the
agent's execution as a stream of ChatResponseUpdate objects to the caller.
To get the intermediate steps of the agent's execution as fully formed messages,
use the on_intermediate_message callback.
Note: A ChatResponseUpdate object contains a chunk of a message.
Args:
messages: The message(s) to send to the agent.
arguments: Additional arguments to pass to the agent.
thread: The conversation thread associated with the message(s).
on_intermediate_message: A callback function to handle intermediate steps of the
agent's execution as fully formed messages.
kwargs: Additional keyword arguments.
Yields:
An agent response item.
"""
...
+12
View File
@@ -5,3 +5,15 @@ class AgentFrameworkException(Exception):
"""Base class for exceptions in the Agent Framework."""
pass
class AgentException(AgentFrameworkException):
"""Base class for all agent exceptions."""
pass
class AgentExecutionException(AgentException):
"""An error occurred while executing the agent."""
pass
+105
View File
@@ -0,0 +1,105 @@
# Copyright (c) Microsoft. All rights reserved.
import uuid
from collections.abc import AsyncIterable, Awaitable, Callable
from typing import Any, TypeVar, cast
from pytest import fixture
from agent_framework import Agent, AgentThread, ChatMessage, ChatResponse, ChatResponseUpdate, ChatRole, TextContent
TThreadType = TypeVar("TThreadType", bound=AgentThread)
# Mock AgentThread implementation for testing
class MockAgentThread(AgentThread):
async def _create(self) -> str:
return str(uuid.uuid4())
async def _delete(self) -> None:
pass
async def _on_new_message(self, new_message: ChatMessage) -> None:
pass
# Mock Agent implementation for testing
class MockAgent:
async def run(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
*,
arguments: dict[str, Any] | None = None,
thread: AgentThread | None = None,
**kwargs: Any,
) -> ChatResponse:
return ChatResponse(messages=[ChatMessage(role=ChatRole.ASSISTANT, contents=[TextContent("Response")])])
async def run_stream(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
*,
arguments: dict[str, Any] | None = None,
thread: AgentThread | None = None,
on_intermediate_message: Callable[[ChatMessage], Awaitable[None]] | None = None,
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent("Response")])
@fixture
def agent_thread() -> AgentThread:
return MockAgentThread()
@fixture
def agent() -> MockAgent:
return MockAgent()
async def test_agent_thread_id_property(agent_thread: MockAgentThread) -> None:
assert agent_thread.id is None
await agent_thread.create()
assert isinstance(agent_thread.id, str)
async def test_agent_thread_create(agent_thread: MockAgentThread) -> None:
thread_id = await agent_thread.create()
assert thread_id == agent_thread.id
assert isinstance(thread_id, str)
async def test_agent_thread_create_already_exists(agent_thread: MockAgentThread) -> None:
thread_id = await agent_thread.create()
same_id = await agent_thread.create()
assert thread_id == same_id
async def test_agent_thread_delete_already_deleted(agent_thread: MockAgentThread) -> None:
await agent_thread.delete()
await agent_thread.delete() # Should not raise error
async def test_agent_thread_on_new_message_creates_thread(agent_thread: MockAgentThread) -> None:
message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])
await agent_thread.on_new_message(message)
assert agent_thread.id is not None
def test_agent_type(agent: MockAgent) -> None:
assert isinstance(agent, Agent)
async def test_agent_run(agent: MockAgent) -> None:
response = await agent.run("test")
assert response.messages[0].role == ChatRole.ASSISTANT
assert cast(TextContent, response.messages[0].contents[0]).text == "Response"
async def tesT_agent_run_stream(agent: MockAgent) -> None:
async def collect_updates(updates: AsyncIterable[ChatResponseUpdate]) -> list[ChatResponseUpdate]:
return [u async for u in updates]
updates = await collect_updates(agent.run_stream(messages="test"))
assert len(updates) == 1
assert cast(TextContent, updates[0].contents[0]).text == "Response"