mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: updated API in sync with dotnet (#269)
* updated API in sync with dotnet * fix test * updated name and display_name * fixed mypy setup * add pre-commit cache
This commit is contained in:
committed by
GitHub
Unverified
parent
dc993b4734
commit
1ed89d4db8
@@ -74,6 +74,8 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure"
|
||||
|
||||
[tool.uv.build-backend]
|
||||
module-name = "agent_framework_azure"
|
||||
|
||||
@@ -96,7 +96,7 @@ class FoundrySettings(AFBaseSettings):
|
||||
class FoundryChatClient(ChatClientBase):
|
||||
"""Azure AI Foundry Chat client."""
|
||||
|
||||
MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride]
|
||||
MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride, misc]
|
||||
client: AIProjectClient = Field(...)
|
||||
credential: AsyncTokenCredential | None = Field(...)
|
||||
agent_id: str | None = Field(default=None)
|
||||
|
||||
@@ -76,6 +76,8 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_foundry"
|
||||
|
||||
[tool.uv.build-backend]
|
||||
module-name = "agent_framework_foundry"
|
||||
|
||||
@@ -6,11 +6,6 @@ from enum import Enum
|
||||
from typing import Any, Literal, Protocol, TypeVar, runtime_checkable
|
||||
from uuid import uuid4
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ._clients import ChatClient
|
||||
@@ -28,12 +23,17 @@ from ._types import (
|
||||
)
|
||||
from .exceptions import AgentExecutionException
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
|
||||
TThreadType = TypeVar("TThreadType", bound="AgentThread")
|
||||
|
||||
# region AgentThread
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"AIAgent",
|
||||
"AgentBase",
|
||||
"AgentThread",
|
||||
"ChatClientAgent",
|
||||
@@ -77,7 +77,7 @@ class MessagesRetrievableThread(Protocol):
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agent(Protocol):
|
||||
class AIAgent(Protocol):
|
||||
"""A protocol for an agent that can be invoked."""
|
||||
|
||||
@property
|
||||
@@ -90,6 +90,11 @@ class Agent(Protocol):
|
||||
"""Returns the name of the agent."""
|
||||
...
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Returns the display name of the agent."""
|
||||
...
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
"""Returns the description of the agent."""
|
||||
@@ -124,7 +129,7 @@ class Agent(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def run_stream(
|
||||
def run_streaming(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
@@ -157,17 +162,19 @@ class Agent(Protocol):
|
||||
|
||||
|
||||
class AgentBase(AFBaseModel):
|
||||
"""Base class for all agents.
|
||||
"""Base class for all Agent Framework agents.
|
||||
|
||||
Attributes:
|
||||
id: The unique identifier of the agent If no id is provided,
|
||||
a new UUID will be generated.
|
||||
name: The name of the agent
|
||||
description: The description of the agent
|
||||
name: The name of the agent, can be None.
|
||||
description: The description of the agent.
|
||||
display_name: The display name of the agent, which is either the name or id.
|
||||
|
||||
"""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str = Field(default="UnnamedAgent")
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
|
||||
async def _notify_thread_of_new_messages(
|
||||
@@ -177,6 +184,43 @@ class AgentBase(AFBaseModel):
|
||||
if isinstance(new_messages, ChatMessage) or len(new_messages) > 0:
|
||||
await thread.on_new_messages(new_messages)
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Returns the display name of the agent.
|
||||
|
||||
This is the name if present, otherwise the id.
|
||||
"""
|
||||
return self.name or self.id
|
||||
|
||||
def _validate_or_create_thread_type(
|
||||
self,
|
||||
thread: AgentThread | None,
|
||||
construct_thread: Callable[[], TThreadType],
|
||||
expected_type: type[TThreadType],
|
||||
) -> TThreadType:
|
||||
"""Validate or create a AgentThread of the right type.
|
||||
|
||||
Args:
|
||||
thread: The thread to validate or create.
|
||||
construct_thread: A callable that constructs a new thread if `thread` is None.
|
||||
expected_type: The expected type of the thread.
|
||||
|
||||
Returns:
|
||||
The validated or newly created thread of the expected type.
|
||||
|
||||
Raises:
|
||||
AgentExecutionException: If the thread is not of the expected type.
|
||||
"""
|
||||
if thread is None:
|
||||
return construct_thread()
|
||||
|
||||
if not isinstance(thread, expected_type):
|
||||
raise AgentExecutionException(
|
||||
f"{self.__class__.__name__} currently only supports agent threads of type {expected_type.__name__}."
|
||||
)
|
||||
|
||||
return thread
|
||||
|
||||
|
||||
# region ChatClientAgentThread
|
||||
|
||||
@@ -442,13 +486,7 @@ class ChatClientAgent(AgentBase):
|
||||
will only be passed to functions that are called.
|
||||
"""
|
||||
input_messages = self._normalize_messages(messages)
|
||||
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(
|
||||
thread=thread,
|
||||
input_messages=input_messages,
|
||||
construct_thread=lambda: ChatClientAgentThread(),
|
||||
expected_type=ChatClientAgentThread,
|
||||
)
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
|
||||
|
||||
response = await self.chat_client.get_response(
|
||||
messages=thread_messages,
|
||||
@@ -491,7 +529,7 @@ class ChatClientAgent(AgentBase):
|
||||
additional_properties=response.additional_properties,
|
||||
)
|
||||
|
||||
async def run_stream(
|
||||
async def run_streaming(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
@@ -523,7 +561,7 @@ class ChatClientAgent(AgentBase):
|
||||
"""Stream the agent with the given messages and options.
|
||||
|
||||
Remarks:
|
||||
Since you won't always call the agent.run_stream directly, but it get's called
|
||||
Since you won't always call the agent.run_streaming directly, but it get's called
|
||||
through orchestration, it is advised to set your default values for
|
||||
all the chat client parameters in the agent constructor.
|
||||
If both parameters are used, the ones passed to the run methods take precedence.
|
||||
@@ -552,14 +590,7 @@ class ChatClientAgent(AgentBase):
|
||||
|
||||
"""
|
||||
input_messages = self._normalize_messages(messages)
|
||||
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(
|
||||
thread=thread,
|
||||
input_messages=input_messages,
|
||||
construct_thread=lambda: ChatClientAgentThread(),
|
||||
expected_type=ChatClientAgentThread,
|
||||
)
|
||||
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
|
||||
response_updates: list[ChatResponseUpdate] = []
|
||||
|
||||
async for update in self.chat_client.get_streaming_response(
|
||||
@@ -607,7 +638,7 @@ class ChatClientAgent(AgentBase):
|
||||
await self._notify_thread_of_new_messages(thread, input_messages)
|
||||
await self._notify_thread_of_new_messages(thread, response.messages)
|
||||
|
||||
def get_new_thread(self) -> AgentThread:
|
||||
def get_new_thread(self) -> ChatClientAgentThread:
|
||||
return ChatClientAgentThread()
|
||||
|
||||
def _update_thread_with_type_and_conversation_id(
|
||||
@@ -644,45 +675,32 @@ class ChatClientAgent(AgentBase):
|
||||
*,
|
||||
thread: AgentThread | None,
|
||||
input_messages: list[ChatMessage] | None = None,
|
||||
construct_thread: Callable[[], TThreadType],
|
||||
expected_type: type[TThreadType],
|
||||
) -> tuple[TThreadType, list[ChatMessage]]:
|
||||
"""Prepare thread and messages for agent execution.
|
||||
) -> tuple[ChatClientAgentThread, list[ChatMessage]]:
|
||||
"""Prepare the messages for agent execution.
|
||||
|
||||
Args:
|
||||
thread: The conversation thread, or None to create a new one.
|
||||
thread: The conversation thread.
|
||||
input_messages: Messages to process.
|
||||
construct_thread: Factory function to create a new thread.
|
||||
expected_type: Expected thread type for validation.
|
||||
|
||||
Returns:
|
||||
Tuple of the thread and normalized messages.
|
||||
The validated thread and normalized messages.
|
||||
|
||||
Raises:
|
||||
AgentExecutionException: If thread type is incompatible.
|
||||
AgentExecutionException: If the thread is not of the expected type.
|
||||
"""
|
||||
validated_thread: ChatClientAgentThread = self._validate_or_create_thread_type( # type: ignore[reportAssignmentType]
|
||||
thread=thread,
|
||||
construct_thread=self.get_new_thread,
|
||||
expected_type=ChatClientAgentThread,
|
||||
)
|
||||
messages: list[ChatMessage] = []
|
||||
if self.instructions:
|
||||
messages.append(ChatMessage(role=ChatRole.SYSTEM, text=self.instructions))
|
||||
|
||||
if thread is None:
|
||||
thread = construct_thread()
|
||||
|
||||
if not isinstance(thread, expected_type):
|
||||
raise AgentExecutionException(
|
||||
f"{self.__class__.__name__} currently only supports agent threads of type {expected_type.__name__}."
|
||||
)
|
||||
|
||||
# Add any existing messages from the thread to the messages to be sent to the chat client.
|
||||
if isinstance(thread, MessagesRetrievableThread):
|
||||
async for message in thread.get_messages():
|
||||
if isinstance(validated_thread, MessagesRetrievableThread):
|
||||
async for message in validated_thread.get_messages():
|
||||
messages.append(message)
|
||||
|
||||
if input_messages is None:
|
||||
return thread, messages
|
||||
|
||||
messages.extend(input_messages)
|
||||
return thread, messages
|
||||
messages.extend(input_messages or [])
|
||||
return validated_thread, messages
|
||||
|
||||
def _normalize_messages(
|
||||
self,
|
||||
|
||||
@@ -87,6 +87,8 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework"
|
||||
|
||||
[tool.uv.build-backend]
|
||||
module-name = "agent_framework"
|
||||
|
||||
@@ -7,10 +7,10 @@ from uuid import uuid4
|
||||
from pytest import fixture, raises
|
||||
|
||||
from agent_framework import (
|
||||
Agent,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
AgentThread,
|
||||
AIAgent,
|
||||
ChatClient,
|
||||
ChatClientAgent,
|
||||
ChatClientAgentThread,
|
||||
@@ -33,7 +33,7 @@ class MockAgentThread(AgentThread):
|
||||
|
||||
|
||||
# Mock Agent implementation for testing
|
||||
class MockAgent(Agent):
|
||||
class MockAgent(AIAgent):
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return str(uuid4())
|
||||
@@ -56,7 +56,7 @@ class MockAgent(Agent):
|
||||
) -> AgentRunResponse:
|
||||
return AgentRunResponse(messages=[ChatMessage(role=ChatRole.ASSISTANT, contents=[TextContent("Response")])])
|
||||
|
||||
async def run_stream(
|
||||
async def run_streaming(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
@@ -105,7 +105,7 @@ def agent_thread() -> AgentThread:
|
||||
|
||||
|
||||
@fixture
|
||||
def agent() -> Agent:
|
||||
def agent() -> AIAgent:
|
||||
return MockAgent()
|
||||
|
||||
|
||||
@@ -118,21 +118,21 @@ def test_agent_thread_type(agent_thread: AgentThread) -> None:
|
||||
assert isinstance(agent_thread, AgentThread)
|
||||
|
||||
|
||||
def test_agent_type(agent: Agent) -> None:
|
||||
assert isinstance(agent, Agent)
|
||||
def test_agent_type(agent: AIAgent) -> None:
|
||||
assert isinstance(agent, AIAgent)
|
||||
|
||||
|
||||
async def test_agent_run(agent: Agent) -> None:
|
||||
async def test_agent_run(agent: AIAgent) -> None:
|
||||
response = await agent.run("test")
|
||||
assert response.messages[0].role == ChatRole.ASSISTANT
|
||||
assert response.messages[0].text == "Response"
|
||||
|
||||
|
||||
async def test_agent_run_stream(agent: Agent) -> None:
|
||||
async def test_agent_run_streaming(agent: AIAgent) -> None:
|
||||
async def collect_updates(updates: AsyncIterable[AgentRunResponseUpdate]) -> list[AgentRunResponseUpdate]:
|
||||
return [u async for u in updates]
|
||||
|
||||
updates = await collect_updates(agent.run_stream(messages="test"))
|
||||
updates = await collect_updates(agent.run_streaming(messages="test"))
|
||||
assert len(updates) == 1
|
||||
assert updates[0].text == "Response"
|
||||
|
||||
@@ -191,7 +191,7 @@ async def test_chat_client_agent_thread_on_new_messages_in_memory() -> None:
|
||||
|
||||
def test_chat_client_agent_type(chat_client: ChatClient) -> None:
|
||||
chat_client_agent = ChatClientAgent(chat_client=chat_client)
|
||||
assert isinstance(chat_client_agent, Agent)
|
||||
assert isinstance(chat_client_agent, AIAgent)
|
||||
|
||||
|
||||
async def test_chat_client_agent_init(chat_client: ChatClient) -> None:
|
||||
@@ -199,8 +199,19 @@ async def test_chat_client_agent_init(chat_client: ChatClient) -> None:
|
||||
agent = ChatClientAgent(chat_client=chat_client, id=agent_id, description="Test")
|
||||
|
||||
assert agent.id == agent_id
|
||||
assert agent.name == "UnnamedAgent"
|
||||
assert agent.name is None
|
||||
assert agent.description == "Test"
|
||||
assert agent.display_name == agent_id # Display name defaults to id if name is None
|
||||
|
||||
|
||||
async def test_chat_client_agent_init_with_name(chat_client: ChatClient) -> None:
|
||||
agent_id = str(uuid4())
|
||||
agent = ChatClientAgent(chat_client=chat_client, id=agent_id, name="Test Agent", description="Test")
|
||||
|
||||
assert agent.id == agent_id
|
||||
assert agent.name == "Test Agent"
|
||||
assert agent.description == "Test"
|
||||
assert agent.display_name == "Test Agent" # Display name is the name if present
|
||||
|
||||
|
||||
async def test_chat_client_agent_run(chat_client: ChatClient) -> None:
|
||||
@@ -211,10 +222,10 @@ async def test_chat_client_agent_run(chat_client: ChatClient) -> None:
|
||||
assert result.text == "test response"
|
||||
|
||||
|
||||
async def test_chat_client_agent_run_stream(chat_client: ChatClient) -> None:
|
||||
async def test_chat_client_agent_run_streaming(chat_client: ChatClient) -> None:
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
|
||||
result = await AgentRunResponse.from_agent_response_generator(agent.run_stream("Hello"))
|
||||
result = await AgentRunResponse.from_agent_response_generator(agent.run_streaming("Hello"))
|
||||
|
||||
assert result.text == "test streaming response"
|
||||
|
||||
@@ -232,19 +243,35 @@ async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatCl
|
||||
message = ChatMessage(role=ChatRole.USER, text="Hello")
|
||||
thread = ChatClientAgentThread(messages=[message])
|
||||
|
||||
result_thread, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
|
||||
thread=thread,
|
||||
input_messages=[ChatMessage(role=ChatRole.USER, text="Test")],
|
||||
construct_thread=lambda: ChatClientAgentThread(),
|
||||
expected_type=ChatClientAgentThread,
|
||||
)
|
||||
result_thread = agent._validate_or_create_thread_type(
|
||||
thread, lambda: ChatClientAgentThread(), expected_type=ChatClientAgentThread
|
||||
) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert result_thread == thread
|
||||
assert isinstance(result_thread, ChatClientAgentThread)
|
||||
|
||||
_, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
|
||||
thread=result_thread,
|
||||
input_messages=[ChatMessage(role=ChatRole.USER, text="Test")],
|
||||
)
|
||||
|
||||
assert len(result_messages) == 2
|
||||
assert result_messages[0] == message
|
||||
assert result_messages[1].text == "Test"
|
||||
|
||||
|
||||
async def test_chat_client_agent_validate_or_create_thread(chat_client: ChatClient) -> None:
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
thread = None
|
||||
|
||||
result_thread = agent._validate_or_create_thread_type(
|
||||
thread, lambda: ChatClientAgentThread(), expected_type=ChatClientAgentThread
|
||||
) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert result_thread != thread
|
||||
assert isinstance(result_thread, ChatClientAgentThread)
|
||||
|
||||
|
||||
async def test_chat_client_agent_update_thread_id() -> None:
|
||||
chat_client = MockChatClient(
|
||||
mock_response=ChatResponse(
|
||||
|
||||
@@ -48,7 +48,7 @@ async def streaming_example() -> None:
|
||||
query = "What's the weather like in Portland?"
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
async for chunk in agent.run_stream(query):
|
||||
async for chunk in agent.run_streaming(query):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
print("\n")
|
||||
|
||||
@@ -48,7 +48,7 @@ async def streaming_example() -> None:
|
||||
query = "What's the weather like in Portland?"
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
async for chunk in agent.run_stream(query):
|
||||
async for chunk in agent.run_streaming(query):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
print("\n")
|
||||
|
||||
@@ -44,7 +44,7 @@ async def main() -> None:
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
generated_code = ""
|
||||
async for chunk in agent.run_stream(query):
|
||||
async for chunk in agent.run_streaming(query):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
code_interpreter_chunk = get_code_interpreter_chunk(chunk)
|
||||
|
||||
+1
-1
@@ -46,7 +46,7 @@ async def streaming_example() -> None:
|
||||
query = "What's the weather like in Portland?"
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
async for chunk in agent.run_stream(query):
|
||||
async for chunk in agent.run_streaming(query):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
print("\n")
|
||||
|
||||
+1
-1
@@ -46,7 +46,7 @@ async def streaming_example() -> None:
|
||||
query = "What's the weather like in Portland?"
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
async for chunk in agent.run_stream(query):
|
||||
async for chunk in agent.run_streaming(query):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
print("\n")
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
fmt = "ruff format"
|
||||
format.ref = "fmt"
|
||||
lint = "ruff check"
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework"
|
||||
pyright = "pyright"
|
||||
build = "uv build"
|
||||
test = "pytest --cov=agent_framework --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
Reference in New Issue
Block a user