Python: small streamline of agents and agent threads (#145)

* small streamline of agents and agent threads

* slight update to tests
This commit is contained in:
Eduard van Valkenburg
2025-07-08 21:01:22 +02:00
committed by GitHub
Unverified
parent 7c8ec5ec19
commit e5ec41b869
2 changed files with 73 additions and 55 deletions
+41 -33
View File
@@ -1,28 +1,19 @@
# Copyright (c) Microsoft. All rights reserved.
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Awaitable, Callable
from typing import Any, Protocol, TypeVar, runtime_checkable
from abc import abstractmethod
from collections.abc import AsyncIterable, Sequence
from typing import Any, Protocol, runtime_checkable
from ._pydantic import AFBaseModel
from ._types import ChatMessage, ChatResponse, ChatResponseUpdate
TThreadType = TypeVar("TThreadType", bound="AgentThread")
# region AgentThread
class AgentThread(ABC):
class AgentThread(AFBaseModel):
"""Base class for agent threads."""
def __init__(self) -> None:
"""Initialize the agent thread."""
self._id: str | None = None # type: ignore
@property
def id(self) -> str | None:
"""Returns the ID of the current thread (if any)."""
return self._id
id: str | None = None
async def create(self) -> str | None:
"""Starts the thread and returns the thread ID."""
@@ -31,42 +22,42 @@ class AgentThread(ABC):
return self.id
# Otherwise, create the thread.
self._id = await self._create()
self.id = await self._create()
return self.id
async def delete(self) -> None:
"""Ends the current thread."""
await self._delete()
self._id = None
self.id = None
async def on_new_message(
self,
new_message: ChatMessage,
new_messages: ChatMessage | Sequence[ChatMessage],
) -> None:
"""Invoked when a new message has been contributed to the chat by any participant."""
# If the thread is not created yet, create it.
if self.id is None:
await self.create()
await self._on_new_message(new_message)
await self._on_new_message(new_messages=new_messages)
@abstractmethod
async def _create(self) -> str:
"""Starts the thread and returns the thread ID."""
raise NotImplementedError
...
@abstractmethod
async def _delete(self) -> None:
"""Ends the current thread."""
raise NotImplementedError
...
@abstractmethod
async def _on_new_message(
self,
new_message: ChatMessage,
new_messages: ChatMessage | Sequence[ChatMessage],
) -> None:
"""Invoked when a new message has been contributed to the chat by any participant."""
raise NotImplementedError
...
# region Agent Protocol
@@ -76,11 +67,30 @@ class AgentThread(ABC):
class Agent(Protocol):
"""A protocol for an agent that can be invoked."""
@property
def id(self) -> str:
"""Returns the ID of the agent."""
...
@property
def name(self) -> str | None:
"""Returns the name of the agent."""
...
@property
def description(self) -> str | None:
"""Returns the description of the agent."""
...
@property
def instructions(self) -> str | None:
"""Returns the instructions for the agent."""
...
async def run(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
messages: str | ChatMessage | list[ChatMessage] | None = None,
*,
arguments: dict[str, Any] | None = None,
thread: AgentThread | None = None,
**kwargs: Any,
) -> ChatResponse:
@@ -98,7 +108,6 @@ class Agent(Protocol):
Args:
messages: The message(s) to send to the agent.
arguments: Additional arguments to pass to the agent.
thread: The conversation thread associated with the message(s).
kwargs: Additional keyword arguments.
@@ -107,13 +116,11 @@ class Agent(Protocol):
"""
...
async def run_stream(
def run_stream(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
messages: str | ChatMessage | list[ChatMessage] | None = None,
*,
arguments: dict[str, Any] | None = None,
thread: AgentThread | None = None,
on_intermediate_message: Callable[[ChatMessage], Awaitable[None]] | None = None,
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
"""Run the agent as a stream.
@@ -128,13 +135,14 @@ class Agent(Protocol):
Args:
messages: The message(s) to send to the agent.
arguments: Additional arguments to pass to the agent.
thread: The conversation thread associated with the message(s).
on_intermediate_message: A callback function to handle intermediate steps of the
agent's execution as fully formed messages.
kwargs: Additional keyword arguments.
Yields:
An agent response item.
"""
...
def get_new_thread(self) -> AgentThread:
"""Creates a new conversation thread for the agent."""
...
+32 -22
View File
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft. All rights reserved.
import uuid
from collections.abc import AsyncIterable, Awaitable, Callable
from typing import Any, TypeVar, cast
from collections.abc import AsyncIterable, Sequence
from typing import Any, TypeVar
from uuid import uuid4
from pydantic import BaseModel, Field
from pytest import fixture
from agent_framework import Agent, AgentThread, ChatMessage, ChatResponse, ChatResponseUpdate, ChatRole, TextContent
@@ -14,22 +15,26 @@ TThreadType = TypeVar("TThreadType", bound=AgentThread)
# Mock AgentThread implementation for testing
class MockAgentThread(AgentThread):
async def _create(self) -> str:
return str(uuid.uuid4())
return str(uuid4())
async def _delete(self) -> None:
pass
async def _on_new_message(self, new_message: ChatMessage) -> None:
async def _on_new_message(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
pass
# Mock Agent implementation for testing
class MockAgent:
class MockAgent(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str | None = None
description: str | None = None
instructions: str | None = None
async def run(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
messages: ChatMessage | str | list[ChatMessage] | None = None,
*,
arguments: dict[str, Any] | None = None,
thread: AgentThread | None = None,
**kwargs: Any,
) -> ChatResponse:
@@ -37,15 +42,16 @@ class MockAgent:
async def run_stream(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
messages: str | ChatMessage | list[ChatMessage] | None = None,
*,
arguments: dict[str, Any] | None = None,
thread: AgentThread | None = None,
on_intermediate_message: Callable[[ChatMessage], Awaitable[None]] | None = None,
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent("Response")])
def get_new_thread(self) -> AgentThread:
return MockAgentThread()
@fixture
def agent_thread() -> AgentThread:
@@ -53,53 +59,57 @@ def agent_thread() -> AgentThread:
@fixture
def agent() -> MockAgent:
def agent() -> Agent:
return MockAgent()
async def test_agent_thread_id_property(agent_thread: MockAgentThread) -> None:
def test_agent_thread_type(agent_thread: AgentThread) -> None:
assert isinstance(agent_thread, AgentThread)
async def test_agent_thread_id_property(agent_thread: AgentThread) -> None:
assert agent_thread.id is None
await agent_thread.create()
assert isinstance(agent_thread.id, str)
async def test_agent_thread_create(agent_thread: MockAgentThread) -> None:
async def test_agent_thread_create(agent_thread: AgentThread) -> None:
thread_id = await agent_thread.create()
assert thread_id == agent_thread.id
assert isinstance(thread_id, str)
async def test_agent_thread_create_already_exists(agent_thread: MockAgentThread) -> None:
async def test_agent_thread_create_already_exists(agent_thread: AgentThread) -> None:
thread_id = await agent_thread.create()
same_id = await agent_thread.create()
assert thread_id == same_id
async def test_agent_thread_delete_already_deleted(agent_thread: MockAgentThread) -> None:
async def test_agent_thread_delete_already_deleted(agent_thread: AgentThread) -> None:
await agent_thread.delete()
await agent_thread.delete() # Should not raise error
async def test_agent_thread_on_new_message_creates_thread(agent_thread: MockAgentThread) -> None:
async def test_agent_thread_on_new_message_creates_thread(agent_thread: AgentThread) -> None:
message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])
await agent_thread.on_new_message(message)
assert agent_thread.id is not None
def test_agent_type(agent: MockAgent) -> None:
def test_agent_type(agent: Agent) -> None:
assert isinstance(agent, Agent)
async def test_agent_run(agent: MockAgent) -> None:
async def test_agent_run(agent: Agent) -> None:
response = await agent.run("test")
assert response.messages[0].role == ChatRole.ASSISTANT
assert cast(TextContent, response.messages[0].contents[0]).text == "Response"
assert response.messages[0].text == "Response"
async def tesT_agent_run_stream(agent: MockAgent) -> None:
async def test_agent_run_stream(agent: Agent) -> None:
async def collect_updates(updates: AsyncIterable[ChatResponseUpdate]) -> list[ChatResponseUpdate]:
return [u async for u in updates]
updates = await collect_updates(agent.run_stream(messages="test"))
assert len(updates) == 1
assert cast(TextContent, updates[0].contents[0]).text == "Response"
assert updates[0].text == "Response"