mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
.Python: Added Agent and AgentThread abstractions (#130)
* Added Agent and AgentThread classes * Addressed PR feedback * Converted Agent to protocol * Removed thread deletion logic * Small update * Small updates to the Agent protocol
This commit is contained in:
committed by
GitHub
Unverified
parent
09309c1239
commit
35c938fb5b
@@ -10,6 +10,8 @@ except importlib.metadata.PackageNotFoundError:
|
||||
|
||||
_IMPORTS = {
|
||||
"get_logger": "._logging",
|
||||
"Agent": "._agents",
|
||||
"AgentThread": "._agents",
|
||||
"AITool": "._tools",
|
||||
"ai_function": "._tools",
|
||||
"AIContent": "._types",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from . import __version__ # type: ignore[attr-defined]
|
||||
from ._agents import Agent, AgentThread
|
||||
from ._clients import ChatClient, EmbeddingGenerator
|
||||
from ._logging import get_logger
|
||||
from ._tools import AITool, ai_function
|
||||
@@ -32,6 +33,8 @@ __all__ = [
|
||||
"AIContent",
|
||||
"AIContents",
|
||||
"AITool",
|
||||
"Agent",
|
||||
"AgentThread",
|
||||
"ChatClient",
|
||||
"ChatFinishReason",
|
||||
"ChatMessage",
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
from ._types import ChatMessage, ChatResponse, ChatResponseUpdate
|
||||
|
||||
TThreadType = TypeVar("TThreadType", bound="AgentThread")
|
||||
|
||||
|
||||
# region AgentThread
|
||||
|
||||
|
||||
class AgentThread(ABC):
|
||||
"""Base class for agent threads."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the agent thread."""
|
||||
self._id: str | None = None # type: ignore
|
||||
|
||||
@property
|
||||
def id(self) -> str | None:
|
||||
"""Returns the ID of the current thread (if any)."""
|
||||
return self._id
|
||||
|
||||
async def create(self) -> str | None:
|
||||
"""Starts the thread and returns the thread ID."""
|
||||
# If the thread ID is already set, we're done, just return the Id.
|
||||
if self.id is not None:
|
||||
return self.id
|
||||
|
||||
# Otherwise, create the thread.
|
||||
self._id = await self._create()
|
||||
return self.id
|
||||
|
||||
async def delete(self) -> None:
|
||||
"""Ends the current thread."""
|
||||
await self._delete()
|
||||
self._id = None
|
||||
|
||||
async def on_new_message(
|
||||
self,
|
||||
new_message: ChatMessage,
|
||||
) -> None:
|
||||
"""Invoked when a new message has been contributed to the chat by any participant."""
|
||||
# If the thread is not created yet, create it.
|
||||
if self.id is None:
|
||||
await self.create()
|
||||
|
||||
await self._on_new_message(new_message)
|
||||
|
||||
@abstractmethod
|
||||
async def _create(self) -> str:
|
||||
"""Starts the thread and returns the thread ID."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _delete(self) -> None:
|
||||
"""Ends the current thread."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _on_new_message(
|
||||
self,
|
||||
new_message: ChatMessage,
|
||||
) -> None:
|
||||
"""Invoked when a new message has been contributed to the chat by any participant."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# region Agent Protocol
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agent(Protocol):
|
||||
"""A protocol for an agent that can be invoked."""
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
*,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
"""Get a response from the agent.
|
||||
|
||||
This method returns the final result of the agent's execution
|
||||
as a single ChatResponse object. The caller is blocked until
|
||||
the final result is available.
|
||||
|
||||
Note: For streaming responses, use the run_stream method, which returns
|
||||
intermediate steps and the final result as a stream of ChatResponseUpdate
|
||||
objects. Streaming only the final result is not feasible because the timing of
|
||||
the final result's availability is unknown, and blocking the caller until then
|
||||
is undesirable in streaming scenarios.
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the agent.
|
||||
arguments: Additional arguments to pass to the agent.
|
||||
thread: The conversation thread associated with the message(s).
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
An agent response item.
|
||||
"""
|
||||
...
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
*,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
thread: AgentThread | None = None,
|
||||
on_intermediate_message: Callable[[ChatMessage], Awaitable[None]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
"""Run the agent as a stream.
|
||||
|
||||
This method will return the intermediate steps and final results of the
|
||||
agent's execution as a stream of ChatResponseUpdate objects to the caller.
|
||||
|
||||
To get the intermediate steps of the agent's execution as fully formed messages,
|
||||
use the on_intermediate_message callback.
|
||||
|
||||
Note: A ChatResponseUpdate object contains a chunk of a message.
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the agent.
|
||||
arguments: Additional arguments to pass to the agent.
|
||||
thread: The conversation thread associated with the message(s).
|
||||
on_intermediate_message: A callback function to handle intermediate steps of the
|
||||
agent's execution as fully formed messages.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Yields:
|
||||
An agent response item.
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -5,3 +5,15 @@ class AgentFrameworkException(Exception):
|
||||
"""Base class for exceptions in the Agent Framework."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentException(AgentFrameworkException):
|
||||
"""Base class for all agent exceptions."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentExecutionException(AgentException):
|
||||
"""An error occurred while executing the agent."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
from pytest import fixture
|
||||
|
||||
from agent_framework import Agent, AgentThread, ChatMessage, ChatResponse, ChatResponseUpdate, ChatRole, TextContent
|
||||
|
||||
TThreadType = TypeVar("TThreadType", bound=AgentThread)
|
||||
|
||||
|
||||
# Mock AgentThread implementation for testing
|
||||
class MockAgentThread(AgentThread):
|
||||
async def _create(self) -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
async def _delete(self) -> None:
|
||||
pass
|
||||
|
||||
async def _on_new_message(self, new_message: ChatMessage) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# Mock Agent implementation for testing
|
||||
class MockAgent:
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
*,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
return ChatResponse(messages=[ChatMessage(role=ChatRole.ASSISTANT, contents=[TextContent("Response")])])
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
*,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
thread: AgentThread | None = None,
|
||||
on_intermediate_message: Callable[[ChatMessage], Awaitable[None]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent("Response")])
|
||||
|
||||
|
||||
@fixture
|
||||
def agent_thread() -> AgentThread:
|
||||
return MockAgentThread()
|
||||
|
||||
|
||||
@fixture
|
||||
def agent() -> MockAgent:
|
||||
return MockAgent()
|
||||
|
||||
|
||||
async def test_agent_thread_id_property(agent_thread: MockAgentThread) -> None:
|
||||
assert agent_thread.id is None
|
||||
await agent_thread.create()
|
||||
assert isinstance(agent_thread.id, str)
|
||||
|
||||
|
||||
async def test_agent_thread_create(agent_thread: MockAgentThread) -> None:
|
||||
thread_id = await agent_thread.create()
|
||||
assert thread_id == agent_thread.id
|
||||
assert isinstance(thread_id, str)
|
||||
|
||||
|
||||
async def test_agent_thread_create_already_exists(agent_thread: MockAgentThread) -> None:
|
||||
thread_id = await agent_thread.create()
|
||||
same_id = await agent_thread.create()
|
||||
assert thread_id == same_id
|
||||
|
||||
|
||||
async def test_agent_thread_delete_already_deleted(agent_thread: MockAgentThread) -> None:
|
||||
await agent_thread.delete()
|
||||
await agent_thread.delete() # Should not raise error
|
||||
|
||||
|
||||
async def test_agent_thread_on_new_message_creates_thread(agent_thread: MockAgentThread) -> None:
|
||||
message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])
|
||||
await agent_thread.on_new_message(message)
|
||||
assert agent_thread.id is not None
|
||||
|
||||
|
||||
def test_agent_type(agent: MockAgent) -> None:
|
||||
assert isinstance(agent, Agent)
|
||||
|
||||
|
||||
async def test_agent_run(agent: MockAgent) -> None:
|
||||
response = await agent.run("test")
|
||||
assert response.messages[0].role == ChatRole.ASSISTANT
|
||||
assert cast(TextContent, response.messages[0].contents[0]).text == "Response"
|
||||
|
||||
|
||||
async def tesT_agent_run_stream(agent: MockAgent) -> None:
|
||||
async def collect_updates(updates: AsyncIterable[ChatResponseUpdate]) -> list[ChatResponseUpdate]:
|
||||
return [u async for u in updates]
|
||||
|
||||
updates = await collect_updates(agent.run_stream(messages="test"))
|
||||
assert len(updates) == 1
|
||||
assert cast(TextContent, updates[0].contents[0]).text == "Response"
|
||||
Reference in New Issue
Block a user