Python: updated API in sync with dotnet (#269)

* updated API in sync with dotnet

* fix test

* updated name and display_name

* fixed mypy setup

* add pre-commit cache
This commit is contained in:
Eduard van Valkenburg
2025-07-30 10:50:22 +02:00
committed by GitHub
Unverified
parent dc993b4734
commit 1ed89d4db8
13 changed files with 138 additions and 84 deletions
+2
View File
@@ -74,6 +74,8 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure"
[tool.uv.build-backend]
module-name = "agent_framework_azure"
@@ -96,7 +96,7 @@ class FoundrySettings(AFBaseSettings):
class FoundryChatClient(ChatClientBase):
"""Azure AI Foundry Chat client."""
MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride]
MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride, misc]
client: AIProjectClient = Field(...)
credential: AsyncTokenCredential | None = Field(...)
agent_id: str | None = Field(default=None)
+2
View File
@@ -76,6 +76,8 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_foundry"
[tool.uv.build-backend]
module-name = "agent_framework_foundry"
+75 -57
View File
@@ -6,11 +6,6 @@ from enum import Enum
from typing import Any, Literal, Protocol, TypeVar, runtime_checkable
from uuid import uuid4
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
from pydantic import BaseModel, Field
from ._clients import ChatClient
@@ -28,12 +23,17 @@ from ._types import (
)
from .exceptions import AgentExecutionException
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
TThreadType = TypeVar("TThreadType", bound="AgentThread")
# region AgentThread
__all__ = [
"Agent",
"AIAgent",
"AgentBase",
"AgentThread",
"ChatClientAgent",
@@ -77,7 +77,7 @@ class MessagesRetrievableThread(Protocol):
@runtime_checkable
class Agent(Protocol):
class AIAgent(Protocol):
"""A protocol for an agent that can be invoked."""
@property
@@ -90,6 +90,11 @@ class Agent(Protocol):
"""Returns the name of the agent."""
...
@property
def display_name(self) -> str:
"""Returns the display name of the agent."""
...
@property
def description(self) -> str | None:
"""Returns the description of the agent."""
@@ -124,7 +129,7 @@ class Agent(Protocol):
"""
...
def run_stream(
def run_streaming(
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
@@ -157,17 +162,19 @@ class Agent(Protocol):
class AgentBase(AFBaseModel):
"""Base class for all agents.
"""Base class for all Agent Framework agents.
Attributes:
id: The unique identifier of the agent If no id is provided,
a new UUID will be generated.
name: The name of the agent
description: The description of the agent
name: The name of the agent, can be None.
description: The description of the agent.
display_name: The display name of the agent, which is either the name or id.
"""
id: str = Field(default_factory=lambda: str(uuid4()))
name: str = Field(default="UnnamedAgent")
name: str | None = None
description: str | None = None
async def _notify_thread_of_new_messages(
@@ -177,6 +184,43 @@ class AgentBase(AFBaseModel):
if isinstance(new_messages, ChatMessage) or len(new_messages) > 0:
await thread.on_new_messages(new_messages)
@property
def display_name(self) -> str:
"""Returns the display name of the agent.
This is the name if present, otherwise the id.
"""
return self.name or self.id
def _validate_or_create_thread_type(
self,
thread: AgentThread | None,
construct_thread: Callable[[], TThreadType],
expected_type: type[TThreadType],
) -> TThreadType:
"""Validate or create a AgentThread of the right type.
Args:
thread: The thread to validate or create.
construct_thread: A callable that constructs a new thread if `thread` is None.
expected_type: The expected type of the thread.
Returns:
The validated or newly created thread of the expected type.
Raises:
AgentExecutionException: If the thread is not of the expected type.
"""
if thread is None:
return construct_thread()
if not isinstance(thread, expected_type):
raise AgentExecutionException(
f"{self.__class__.__name__} currently only supports agent threads of type {expected_type.__name__}."
)
return thread
# region ChatClientAgentThread
@@ -442,13 +486,7 @@ class ChatClientAgent(AgentBase):
will only be passed to functions that are called.
"""
input_messages = self._normalize_messages(messages)
thread, thread_messages = await self._prepare_thread_and_messages(
thread=thread,
input_messages=input_messages,
construct_thread=lambda: ChatClientAgentThread(),
expected_type=ChatClientAgentThread,
)
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
response = await self.chat_client.get_response(
messages=thread_messages,
@@ -491,7 +529,7 @@ class ChatClientAgent(AgentBase):
additional_properties=response.additional_properties,
)
async def run_stream(
async def run_streaming(
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
@@ -523,7 +561,7 @@ class ChatClientAgent(AgentBase):
"""Stream the agent with the given messages and options.
Remarks:
Since you won't always call the agent.run_stream directly, but it get's called
Since you won't always call the agent.run_streaming directly, but it get's called
through orchestration, it is advised to set your default values for
all the chat client parameters in the agent constructor.
If both parameters are used, the ones passed to the run methods take precedence.
@@ -552,14 +590,7 @@ class ChatClientAgent(AgentBase):
"""
input_messages = self._normalize_messages(messages)
thread, thread_messages = await self._prepare_thread_and_messages(
thread=thread,
input_messages=input_messages,
construct_thread=lambda: ChatClientAgentThread(),
expected_type=ChatClientAgentThread,
)
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
response_updates: list[ChatResponseUpdate] = []
async for update in self.chat_client.get_streaming_response(
@@ -607,7 +638,7 @@ class ChatClientAgent(AgentBase):
await self._notify_thread_of_new_messages(thread, input_messages)
await self._notify_thread_of_new_messages(thread, response.messages)
def get_new_thread(self) -> AgentThread:
def get_new_thread(self) -> ChatClientAgentThread:
return ChatClientAgentThread()
def _update_thread_with_type_and_conversation_id(
@@ -644,45 +675,32 @@ class ChatClientAgent(AgentBase):
*,
thread: AgentThread | None,
input_messages: list[ChatMessage] | None = None,
construct_thread: Callable[[], TThreadType],
expected_type: type[TThreadType],
) -> tuple[TThreadType, list[ChatMessage]]:
"""Prepare thread and messages for agent execution.
) -> tuple[ChatClientAgentThread, list[ChatMessage]]:
"""Prepare the messages for agent execution.
Args:
thread: The conversation thread, or None to create a new one.
thread: The conversation thread.
input_messages: Messages to process.
construct_thread: Factory function to create a new thread.
expected_type: Expected thread type for validation.
Returns:
Tuple of the thread and normalized messages.
The validated thread and normalized messages.
Raises:
AgentExecutionException: If thread type is incompatible.
AgentExecutionException: If the thread is not of the expected type.
"""
validated_thread: ChatClientAgentThread = self._validate_or_create_thread_type( # type: ignore[reportAssignmentType]
thread=thread,
construct_thread=self.get_new_thread,
expected_type=ChatClientAgentThread,
)
messages: list[ChatMessage] = []
if self.instructions:
messages.append(ChatMessage(role=ChatRole.SYSTEM, text=self.instructions))
if thread is None:
thread = construct_thread()
if not isinstance(thread, expected_type):
raise AgentExecutionException(
f"{self.__class__.__name__} currently only supports agent threads of type {expected_type.__name__}."
)
# Add any existing messages from the thread to the messages to be sent to the chat client.
if isinstance(thread, MessagesRetrievableThread):
async for message in thread.get_messages():
if isinstance(validated_thread, MessagesRetrievableThread):
async for message in validated_thread.get_messages():
messages.append(message)
if input_messages is None:
return thread, messages
messages.extend(input_messages)
return thread, messages
messages.extend(input_messages or [])
return validated_thread, messages
def _normalize_messages(
self,
+2
View File
@@ -87,6 +87,8 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework"
[tool.uv.build-backend]
module-name = "agent_framework"
+46 -19
View File
@@ -7,10 +7,10 @@ from uuid import uuid4
from pytest import fixture, raises
from agent_framework import (
Agent,
AgentRunResponse,
AgentRunResponseUpdate,
AgentThread,
AIAgent,
ChatClient,
ChatClientAgent,
ChatClientAgentThread,
@@ -33,7 +33,7 @@ class MockAgentThread(AgentThread):
# Mock Agent implementation for testing
class MockAgent(Agent):
class MockAgent(AIAgent):
@property
def id(self) -> str:
return str(uuid4())
@@ -56,7 +56,7 @@ class MockAgent(Agent):
) -> AgentRunResponse:
return AgentRunResponse(messages=[ChatMessage(role=ChatRole.ASSISTANT, contents=[TextContent("Response")])])
async def run_stream(
async def run_streaming(
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
@@ -105,7 +105,7 @@ def agent_thread() -> AgentThread:
@fixture
def agent() -> Agent:
def agent() -> AIAgent:
return MockAgent()
@@ -118,21 +118,21 @@ def test_agent_thread_type(agent_thread: AgentThread) -> None:
assert isinstance(agent_thread, AgentThread)
def test_agent_type(agent: Agent) -> None:
assert isinstance(agent, Agent)
def test_agent_type(agent: AIAgent) -> None:
assert isinstance(agent, AIAgent)
async def test_agent_run(agent: Agent) -> None:
async def test_agent_run(agent: AIAgent) -> None:
response = await agent.run("test")
assert response.messages[0].role == ChatRole.ASSISTANT
assert response.messages[0].text == "Response"
async def test_agent_run_stream(agent: Agent) -> None:
async def test_agent_run_streaming(agent: AIAgent) -> None:
async def collect_updates(updates: AsyncIterable[AgentRunResponseUpdate]) -> list[AgentRunResponseUpdate]:
return [u async for u in updates]
updates = await collect_updates(agent.run_stream(messages="test"))
updates = await collect_updates(agent.run_streaming(messages="test"))
assert len(updates) == 1
assert updates[0].text == "Response"
@@ -191,7 +191,7 @@ async def test_chat_client_agent_thread_on_new_messages_in_memory() -> None:
def test_chat_client_agent_type(chat_client: ChatClient) -> None:
chat_client_agent = ChatClientAgent(chat_client=chat_client)
assert isinstance(chat_client_agent, Agent)
assert isinstance(chat_client_agent, AIAgent)
async def test_chat_client_agent_init(chat_client: ChatClient) -> None:
@@ -199,8 +199,19 @@ async def test_chat_client_agent_init(chat_client: ChatClient) -> None:
agent = ChatClientAgent(chat_client=chat_client, id=agent_id, description="Test")
assert agent.id == agent_id
assert agent.name == "UnnamedAgent"
assert agent.name is None
assert agent.description == "Test"
assert agent.display_name == agent_id # Display name defaults to id if name is None
async def test_chat_client_agent_init_with_name(chat_client: ChatClient) -> None:
agent_id = str(uuid4())
agent = ChatClientAgent(chat_client=chat_client, id=agent_id, name="Test Agent", description="Test")
assert agent.id == agent_id
assert agent.name == "Test Agent"
assert agent.description == "Test"
assert agent.display_name == "Test Agent" # Display name is the name if present
async def test_chat_client_agent_run(chat_client: ChatClient) -> None:
@@ -211,10 +222,10 @@ async def test_chat_client_agent_run(chat_client: ChatClient) -> None:
assert result.text == "test response"
async def test_chat_client_agent_run_stream(chat_client: ChatClient) -> None:
async def test_chat_client_agent_run_streaming(chat_client: ChatClient) -> None:
agent = ChatClientAgent(chat_client=chat_client)
result = await AgentRunResponse.from_agent_response_generator(agent.run_stream("Hello"))
result = await AgentRunResponse.from_agent_response_generator(agent.run_streaming("Hello"))
assert result.text == "test streaming response"
@@ -232,19 +243,35 @@ async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatCl
message = ChatMessage(role=ChatRole.USER, text="Hello")
thread = ChatClientAgentThread(messages=[message])
result_thread, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=thread,
input_messages=[ChatMessage(role=ChatRole.USER, text="Test")],
construct_thread=lambda: ChatClientAgentThread(),
expected_type=ChatClientAgentThread,
)
result_thread = agent._validate_or_create_thread_type(
thread, lambda: ChatClientAgentThread(), expected_type=ChatClientAgentThread
) # type: ignore[reportPrivateUsage]
assert result_thread == thread
assert isinstance(result_thread, ChatClientAgentThread)
_, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=result_thread,
input_messages=[ChatMessage(role=ChatRole.USER, text="Test")],
)
assert len(result_messages) == 2
assert result_messages[0] == message
assert result_messages[1].text == "Test"
async def test_chat_client_agent_validate_or_create_thread(chat_client: ChatClient) -> None:
agent = ChatClientAgent(chat_client=chat_client)
thread = None
result_thread = agent._validate_or_create_thread_type(
thread, lambda: ChatClientAgentThread(), expected_type=ChatClientAgentThread
) # type: ignore[reportPrivateUsage]
assert result_thread != thread
assert isinstance(result_thread, ChatClientAgentThread)
async def test_chat_client_agent_update_thread_id() -> None:
chat_client = MockChatClient(
mock_response=ChatResponse(
@@ -48,7 +48,7 @@ async def streaming_example() -> None:
query = "What's the weather like in Portland?"
print(f"User: {query}")
print("Agent: ", end="", flush=True)
async for chunk in agent.run_stream(query):
async for chunk in agent.run_streaming(query):
if chunk.text:
print(chunk.text, end="", flush=True)
print("\n")
@@ -48,7 +48,7 @@ async def streaming_example() -> None:
query = "What's the weather like in Portland?"
print(f"User: {query}")
print("Agent: ", end="", flush=True)
async for chunk in agent.run_stream(query):
async for chunk in agent.run_streaming(query):
if chunk.text:
print(chunk.text, end="", flush=True)
print("\n")
@@ -44,7 +44,7 @@ async def main() -> None:
print(f"User: {query}")
print("Agent: ", end="", flush=True)
generated_code = ""
async for chunk in agent.run_stream(query):
async for chunk in agent.run_streaming(query):
if chunk.text:
print(chunk.text, end="", flush=True)
code_interpreter_chunk = get_code_interpreter_chunk(chunk)
@@ -46,7 +46,7 @@ async def streaming_example() -> None:
query = "What's the weather like in Portland?"
print(f"User: {query}")
print("Agent: ", end="", flush=True)
async for chunk in agent.run_stream(query):
async for chunk in agent.run_streaming(query):
if chunk.text:
print(chunk.text, end="", flush=True)
print("\n")
@@ -46,7 +46,7 @@ async def streaming_example() -> None:
query = "What's the weather like in Portland?"
print(f"User: {query}")
print("Agent: ", end="", flush=True)
async for chunk in agent.run_stream(query):
async for chunk in agent.run_streaming(query):
if chunk.text:
print(chunk.text, end="", flush=True)
print("\n")
-1
View File
@@ -2,7 +2,6 @@
fmt = "ruff format"
format.ref = "fmt"
lint = "ruff check"
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework"
pyright = "pyright"
build = "uv build"
test = "pytest --cov=agent_framework --cov-report=term-missing:skip-covered tests"