mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: small streamline of agents and agent threads (#145)
* small streamline of agents and agent threads * slight update to tests
This commit is contained in:
committed by
GitHub
Unverified
parent
7c8ec5ec19
commit
e5ec41b869
@@ -1,28 +1,19 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterable, Sequence
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._types import ChatMessage, ChatResponse, ChatResponseUpdate
|
||||
|
||||
TThreadType = TypeVar("TThreadType", bound="AgentThread")
|
||||
|
||||
|
||||
# region AgentThread
|
||||
|
||||
|
||||
class AgentThread(ABC):
|
||||
class AgentThread(AFBaseModel):
|
||||
"""Base class for agent threads."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the agent thread."""
|
||||
self._id: str | None = None # type: ignore
|
||||
|
||||
@property
|
||||
def id(self) -> str | None:
|
||||
"""Returns the ID of the current thread (if any)."""
|
||||
return self._id
|
||||
id: str | None = None
|
||||
|
||||
async def create(self) -> str | None:
|
||||
"""Starts the thread and returns the thread ID."""
|
||||
@@ -31,42 +22,42 @@ class AgentThread(ABC):
|
||||
return self.id
|
||||
|
||||
# Otherwise, create the thread.
|
||||
self._id = await self._create()
|
||||
self.id = await self._create()
|
||||
return self.id
|
||||
|
||||
async def delete(self) -> None:
|
||||
"""Ends the current thread."""
|
||||
await self._delete()
|
||||
self._id = None
|
||||
self.id = None
|
||||
|
||||
async def on_new_message(
|
||||
self,
|
||||
new_message: ChatMessage,
|
||||
new_messages: ChatMessage | Sequence[ChatMessage],
|
||||
) -> None:
|
||||
"""Invoked when a new message has been contributed to the chat by any participant."""
|
||||
# If the thread is not created yet, create it.
|
||||
if self.id is None:
|
||||
await self.create()
|
||||
|
||||
await self._on_new_message(new_message)
|
||||
await self._on_new_message(new_messages=new_messages)
|
||||
|
||||
@abstractmethod
|
||||
async def _create(self) -> str:
|
||||
"""Starts the thread and returns the thread ID."""
|
||||
raise NotImplementedError
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def _delete(self) -> None:
|
||||
"""Ends the current thread."""
|
||||
raise NotImplementedError
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def _on_new_message(
|
||||
self,
|
||||
new_message: ChatMessage,
|
||||
new_messages: ChatMessage | Sequence[ChatMessage],
|
||||
) -> None:
|
||||
"""Invoked when a new message has been contributed to the chat by any participant."""
|
||||
raise NotImplementedError
|
||||
...
|
||||
|
||||
|
||||
# region Agent Protocol
|
||||
@@ -76,11 +67,30 @@ class AgentThread(ABC):
|
||||
class Agent(Protocol):
|
||||
"""A protocol for an agent that can be invoked."""
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
"""Returns the ID of the agent."""
|
||||
...
|
||||
|
||||
@property
|
||||
def name(self) -> str | None:
|
||||
"""Returns the name of the agent."""
|
||||
...
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
"""Returns the description of the agent."""
|
||||
...
|
||||
|
||||
@property
|
||||
def instructions(self) -> str | None:
|
||||
"""Returns the instructions for the agent."""
|
||||
...
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
messages: str | ChatMessage | list[ChatMessage] | None = None,
|
||||
*,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
@@ -98,7 +108,6 @@ class Agent(Protocol):
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the agent.
|
||||
arguments: Additional arguments to pass to the agent.
|
||||
thread: The conversation thread associated with the message(s).
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
@@ -107,13 +116,11 @@ class Agent(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
async def run_stream(
|
||||
def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
messages: str | ChatMessage | list[ChatMessage] | None = None,
|
||||
*,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
thread: AgentThread | None = None,
|
||||
on_intermediate_message: Callable[[ChatMessage], Awaitable[None]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
"""Run the agent as a stream.
|
||||
@@ -128,13 +135,14 @@ class Agent(Protocol):
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the agent.
|
||||
arguments: Additional arguments to pass to the agent.
|
||||
thread: The conversation thread associated with the message(s).
|
||||
on_intermediate_message: A callback function to handle intermediate steps of the
|
||||
agent's execution as fully formed messages.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Yields:
|
||||
An agent response item.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_new_thread(self) -> AgentThread:
|
||||
"""Creates a new conversation thread for the agent."""
|
||||
...
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from typing import Any, TypeVar, cast
|
||||
from collections.abc import AsyncIterable, Sequence
|
||||
from typing import Any, TypeVar
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pytest import fixture
|
||||
|
||||
from agent_framework import Agent, AgentThread, ChatMessage, ChatResponse, ChatResponseUpdate, ChatRole, TextContent
|
||||
@@ -14,22 +15,26 @@ TThreadType = TypeVar("TThreadType", bound=AgentThread)
|
||||
# Mock AgentThread implementation for testing
|
||||
class MockAgentThread(AgentThread):
|
||||
async def _create(self) -> str:
|
||||
return str(uuid.uuid4())
|
||||
return str(uuid4())
|
||||
|
||||
async def _delete(self) -> None:
|
||||
pass
|
||||
|
||||
async def _on_new_message(self, new_message: ChatMessage) -> None:
|
||||
async def _on_new_message(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# Mock Agent implementation for testing
|
||||
class MockAgent:
|
||||
class MockAgent(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
instructions: str | None = None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
messages: ChatMessage | str | list[ChatMessage] | None = None,
|
||||
*,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
@@ -37,15 +42,16 @@ class MockAgent:
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
messages: str | ChatMessage | list[ChatMessage] | None = None,
|
||||
*,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
thread: AgentThread | None = None,
|
||||
on_intermediate_message: Callable[[ChatMessage], Awaitable[None]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent("Response")])
|
||||
|
||||
def get_new_thread(self) -> AgentThread:
|
||||
return MockAgentThread()
|
||||
|
||||
|
||||
@fixture
|
||||
def agent_thread() -> AgentThread:
|
||||
@@ -53,53 +59,57 @@ def agent_thread() -> AgentThread:
|
||||
|
||||
|
||||
@fixture
|
||||
def agent() -> MockAgent:
|
||||
def agent() -> Agent:
|
||||
return MockAgent()
|
||||
|
||||
|
||||
async def test_agent_thread_id_property(agent_thread: MockAgentThread) -> None:
|
||||
def test_agent_thread_type(agent_thread: AgentThread) -> None:
|
||||
assert isinstance(agent_thread, AgentThread)
|
||||
|
||||
|
||||
async def test_agent_thread_id_property(agent_thread: AgentThread) -> None:
|
||||
assert agent_thread.id is None
|
||||
await agent_thread.create()
|
||||
assert isinstance(agent_thread.id, str)
|
||||
|
||||
|
||||
async def test_agent_thread_create(agent_thread: MockAgentThread) -> None:
|
||||
async def test_agent_thread_create(agent_thread: AgentThread) -> None:
|
||||
thread_id = await agent_thread.create()
|
||||
assert thread_id == agent_thread.id
|
||||
assert isinstance(thread_id, str)
|
||||
|
||||
|
||||
async def test_agent_thread_create_already_exists(agent_thread: MockAgentThread) -> None:
|
||||
async def test_agent_thread_create_already_exists(agent_thread: AgentThread) -> None:
|
||||
thread_id = await agent_thread.create()
|
||||
same_id = await agent_thread.create()
|
||||
assert thread_id == same_id
|
||||
|
||||
|
||||
async def test_agent_thread_delete_already_deleted(agent_thread: MockAgentThread) -> None:
|
||||
async def test_agent_thread_delete_already_deleted(agent_thread: AgentThread) -> None:
|
||||
await agent_thread.delete()
|
||||
await agent_thread.delete() # Should not raise error
|
||||
|
||||
|
||||
async def test_agent_thread_on_new_message_creates_thread(agent_thread: MockAgentThread) -> None:
|
||||
async def test_agent_thread_on_new_message_creates_thread(agent_thread: AgentThread) -> None:
|
||||
message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])
|
||||
await agent_thread.on_new_message(message)
|
||||
assert agent_thread.id is not None
|
||||
|
||||
|
||||
def test_agent_type(agent: MockAgent) -> None:
|
||||
def test_agent_type(agent: Agent) -> None:
|
||||
assert isinstance(agent, Agent)
|
||||
|
||||
|
||||
async def test_agent_run(agent: MockAgent) -> None:
|
||||
async def test_agent_run(agent: Agent) -> None:
|
||||
response = await agent.run("test")
|
||||
assert response.messages[0].role == ChatRole.ASSISTANT
|
||||
assert cast(TextContent, response.messages[0].contents[0]).text == "Response"
|
||||
assert response.messages[0].text == "Response"
|
||||
|
||||
|
||||
async def tesT_agent_run_stream(agent: MockAgent) -> None:
|
||||
async def test_agent_run_stream(agent: Agent) -> None:
|
||||
async def collect_updates(updates: AsyncIterable[ChatResponseUpdate]) -> list[ChatResponseUpdate]:
|
||||
return [u async for u in updates]
|
||||
|
||||
updates = await collect_updates(agent.run_stream(messages="test"))
|
||||
assert len(updates) == 1
|
||||
assert cast(TextContent, updates[0].contents[0]).text == "Response"
|
||||
assert updates[0].text == "Response"
|
||||
|
||||
Reference in New Issue
Block a user