From 1ed89d4db86928649a8a845b5fbfcd2dacddf541 Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Wed, 30 Jul 2025 10:50:22 +0200 Subject: [PATCH] Python: updated API in sync with dotnet (#269) * updated API in sync with dotnet * fix test * updated name and display_name * fixed mypy setup * add pre-commit cache --- .github/workflows/python-code-quality.yml | 6 +- python/packages/azure/pyproject.toml | 2 + .../agent_framework_foundry/_chat_client.py | 2 +- python/packages/foundry/pyproject.toml | 2 + .../packages/main/agent_framework/_agents.py | 132 ++++++++++-------- python/packages/main/pyproject.toml | 2 + .../packages/main/tests/main/test_agents.py | 65 ++++++--- .../azure_chat_client_basic.py | 2 +- .../agents/foundry/foundry_basic.py | 2 +- .../foundry/foundry_with_code_interpreter.py | 2 +- .../openai_chat_client_basic.py | 2 +- .../openai_responses_client_basic.py | 2 +- python/shared_tasks.toml | 1 - 13 files changed, 138 insertions(+), 84 deletions(-) diff --git a/.github/workflows/python-code-quality.yml b/.github/workflows/python-code-quality.yml index b6b0591216..2cd25fa5d7 100644 --- a/.github/workflows/python-code-quality.yml +++ b/.github/workflows/python-code-quality.yml @@ -38,9 +38,13 @@ jobs: cache-dependency-glob: "**/uv.lock" - name: Install the project run: uv sync --all-extras --dev + - uses: actions/cache@v3 + with: + path: ~/.cache/pre-commit + key: pre-commit|${{ matrix.python-version }}|${{ hashFiles('python/.pre-commit-config.yaml') }} - uses: pre-commit/action@v3.0.1 name: Run Pre-Commit Hooks with: extra_args: --config python/.pre-commit-config.yaml --all-files - name: Run Mypy - run: uv run mypy -p agent_framework + run: uv run poe mypy diff --git a/python/packages/azure/pyproject.toml b/python/packages/azure/pyproject.toml index 3a048cb1d7..152cb15ddb 100644 --- a/python/packages/azure/pyproject.toml +++ b/python/packages/azure/pyproject.toml @@ -74,6 +74,8 @@ exclude_dirs = ["tests"] [tool.poe] executor.type = "uv" include = "../../shared_tasks.toml" +[tool.poe.tasks] +mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure" [tool.uv.build-backend] module-name = "agent_framework_azure" diff --git a/python/packages/foundry/agent_framework_foundry/_chat_client.py b/python/packages/foundry/agent_framework_foundry/_chat_client.py index 135cf1fe42..4a1994583f 100644 --- a/python/packages/foundry/agent_framework_foundry/_chat_client.py +++ b/python/packages/foundry/agent_framework_foundry/_chat_client.py @@ -96,7 +96,7 @@ class FoundrySettings(AFBaseSettings): class FoundryChatClient(ChatClientBase): """Azure AI Foundry Chat client.""" - MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride] + MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride, misc] client: AIProjectClient = Field(...) credential: AsyncTokenCredential | None = Field(...) agent_id: str | None = Field(default=None) diff --git a/python/packages/foundry/pyproject.toml b/python/packages/foundry/pyproject.toml index ce6828400f..b8621931e8 100644 --- a/python/packages/foundry/pyproject.toml +++ b/python/packages/foundry/pyproject.toml @@ -76,6 +76,8 @@ exclude_dirs = ["tests"] [tool.poe] executor.type = "uv" include = "../../shared_tasks.toml" +[tool.poe.tasks] +mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_foundry" [tool.uv.build-backend] module-name = "agent_framework_foundry" diff --git a/python/packages/main/agent_framework/_agents.py b/python/packages/main/agent_framework/_agents.py index 30c98fc1e5..c9566bff77 100644 --- a/python/packages/main/agent_framework/_agents.py +++ b/python/packages/main/agent_framework/_agents.py @@ -6,11 +6,6 @@ from enum import Enum from typing import Any, Literal, Protocol, TypeVar, runtime_checkable from uuid import uuid4 -if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover -else: - from typing_extensions import Self # pragma: no cover - from pydantic import BaseModel, Field from ._clients import ChatClient @@ -28,12 +23,17 @@ from ._types import ( ) from .exceptions import AgentExecutionException +if sys.version_info >= (3, 11): + from typing import Self # pragma: no cover +else: + from typing_extensions import Self # pragma: no cover + TThreadType = TypeVar("TThreadType", bound="AgentThread") # region AgentThread __all__ = [ - "Agent", + "AIAgent", "AgentBase", "AgentThread", "ChatClientAgent", @@ -77,7 +77,7 @@ class MessagesRetrievableThread(Protocol): @runtime_checkable -class Agent(Protocol): +class AIAgent(Protocol): """A protocol for an agent that can be invoked.""" @property @@ -90,6 +90,11 @@ class Agent(Protocol): """Returns the name of the agent.""" ... + @property + def display_name(self) -> str: + """Returns the display name of the agent.""" + ... + @property def description(self) -> str | None: """Returns the description of the agent.""" @@ -124,7 +129,7 @@ class Agent(Protocol): """ ... - def run_stream( + def run_streaming( self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, @@ -157,17 +162,19 @@ class Agent(Protocol): class AgentBase(AFBaseModel): - """Base class for all agents. + """Base class for all Agent Framework agents. Attributes: id: The unique identifier of the agent If no id is provided, a new UUID will be generated. - name: The name of the agent - description: The description of the agent + name: The name of the agent, can be None. + description: The description of the agent. + display_name: The display name of the agent, which is either the name or id. + """ id: str = Field(default_factory=lambda: str(uuid4())) - name: str = Field(default="UnnamedAgent") + name: str | None = None description: str | None = None async def _notify_thread_of_new_messages( @@ -177,6 +184,43 @@ class AgentBase(AFBaseModel): if isinstance(new_messages, ChatMessage) or len(new_messages) > 0: await thread.on_new_messages(new_messages) + @property + def display_name(self) -> str: + """Returns the display name of the agent. + + This is the name if present, otherwise the id. + """ + return self.name or self.id + + def _validate_or_create_thread_type( + self, + thread: AgentThread | None, + construct_thread: Callable[[], TThreadType], + expected_type: type[TThreadType], + ) -> TThreadType: + """Validate or create a AgentThread of the right type. + + Args: + thread: The thread to validate or create. + construct_thread: A callable that constructs a new thread if `thread` is None. + expected_type: The expected type of the thread. + + Returns: + The validated or newly created thread of the expected type. + + Raises: + AgentExecutionException: If the thread is not of the expected type. + """ + if thread is None: + return construct_thread() + + if not isinstance(thread, expected_type): + raise AgentExecutionException( + f"{self.__class__.__name__} currently only supports agent threads of type {expected_type.__name__}." + ) + + return thread + # region ChatClientAgentThread @@ -442,13 +486,7 @@ class ChatClientAgent(AgentBase): will only be passed to functions that are called. """ input_messages = self._normalize_messages(messages) - - thread, thread_messages = await self._prepare_thread_and_messages( - thread=thread, - input_messages=input_messages, - construct_thread=lambda: ChatClientAgentThread(), - expected_type=ChatClientAgentThread, - ) + thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages) response = await self.chat_client.get_response( messages=thread_messages, @@ -491,7 +529,7 @@ class ChatClientAgent(AgentBase): additional_properties=response.additional_properties, ) - async def run_stream( + async def run_streaming( self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, @@ -523,7 +561,7 @@ class ChatClientAgent(AgentBase): """Stream the agent with the given messages and options. Remarks: - Since you won't always call the agent.run_stream directly, but it get's called + Since you won't always call the agent.run_streaming directly, but it get's called through orchestration, it is advised to set your default values for all the chat client parameters in the agent constructor. If both parameters are used, the ones passed to the run methods take precedence. @@ -552,14 +590,7 @@ class ChatClientAgent(AgentBase): """ input_messages = self._normalize_messages(messages) - - thread, thread_messages = await self._prepare_thread_and_messages( - thread=thread, - input_messages=input_messages, - construct_thread=lambda: ChatClientAgentThread(), - expected_type=ChatClientAgentThread, - ) - + thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages) response_updates: list[ChatResponseUpdate] = [] async for update in self.chat_client.get_streaming_response( @@ -607,7 +638,7 @@ class ChatClientAgent(AgentBase): await self._notify_thread_of_new_messages(thread, input_messages) await self._notify_thread_of_new_messages(thread, response.messages) - def get_new_thread(self) -> AgentThread: + def get_new_thread(self) -> ChatClientAgentThread: return ChatClientAgentThread() def _update_thread_with_type_and_conversation_id( @@ -644,45 +675,32 @@ class ChatClientAgent(AgentBase): *, thread: AgentThread | None, input_messages: list[ChatMessage] | None = None, - construct_thread: Callable[[], TThreadType], - expected_type: type[TThreadType], - ) -> tuple[TThreadType, list[ChatMessage]]: - """Prepare thread and messages for agent execution. + ) -> tuple[ChatClientAgentThread, list[ChatMessage]]: + """Prepare the messages for agent execution. Args: - thread: The conversation thread, or None to create a new one. + thread: The conversation thread. input_messages: Messages to process. - construct_thread: Factory function to create a new thread. - expected_type: Expected thread type for validation. Returns: - Tuple of the thread and normalized messages. + The validated thread and normalized messages. Raises: - AgentExecutionException: If thread type is incompatible. + AgentExecutionException: If the thread is not of the expected type. """ + validated_thread: ChatClientAgentThread = self._validate_or_create_thread_type( # type: ignore[reportAssignmentType] + thread=thread, + construct_thread=self.get_new_thread, + expected_type=ChatClientAgentThread, + ) messages: list[ChatMessage] = [] if self.instructions: messages.append(ChatMessage(role=ChatRole.SYSTEM, text=self.instructions)) - - if thread is None: - thread = construct_thread() - - if not isinstance(thread, expected_type): - raise AgentExecutionException( - f"{self.__class__.__name__} currently only supports agent threads of type {expected_type.__name__}." - ) - - # Add any existing messages from the thread to the messages to be sent to the chat client. - if isinstance(thread, MessagesRetrievableThread): - async for message in thread.get_messages(): + if isinstance(validated_thread, MessagesRetrievableThread): + async for message in validated_thread.get_messages(): messages.append(message) - - if input_messages is None: - return thread, messages - - messages.extend(input_messages) - return thread, messages + messages.extend(input_messages or []) + return validated_thread, messages def _normalize_messages( self, diff --git a/python/packages/main/pyproject.toml b/python/packages/main/pyproject.toml index d36d5633a7..a98a8fbf08 100644 --- a/python/packages/main/pyproject.toml +++ b/python/packages/main/pyproject.toml @@ -87,6 +87,8 @@ exclude_dirs = ["tests"] [tool.poe] executor.type = "uv" include = "../../shared_tasks.toml" +[tool.poe.tasks] +mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework" [tool.uv.build-backend] module-name = "agent_framework" diff --git a/python/packages/main/tests/main/test_agents.py b/python/packages/main/tests/main/test_agents.py index 8bcb75a35a..bd514c0a60 100644 --- a/python/packages/main/tests/main/test_agents.py +++ b/python/packages/main/tests/main/test_agents.py @@ -7,10 +7,10 @@ from uuid import uuid4 from pytest import fixture, raises from agent_framework import ( - Agent, AgentRunResponse, AgentRunResponseUpdate, AgentThread, + AIAgent, ChatClient, ChatClientAgent, ChatClientAgentThread, @@ -33,7 +33,7 @@ class MockAgentThread(AgentThread): # Mock Agent implementation for testing -class MockAgent(Agent): +class MockAgent(AIAgent): @property def id(self) -> str: return str(uuid4()) @@ -56,7 +56,7 @@ class MockAgent(Agent): ) -> AgentRunResponse: return AgentRunResponse(messages=[ChatMessage(role=ChatRole.ASSISTANT, contents=[TextContent("Response")])]) - async def run_stream( + async def run_streaming( self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, @@ -105,7 +105,7 @@ def agent_thread() -> AgentThread: @fixture -def agent() -> Agent: +def agent() -> AIAgent: return MockAgent() @@ -118,21 +118,21 @@ def test_agent_thread_type(agent_thread: AgentThread) -> None: assert isinstance(agent_thread, AgentThread) -def test_agent_type(agent: Agent) -> None: - assert isinstance(agent, Agent) +def test_agent_type(agent: AIAgent) -> None: + assert isinstance(agent, AIAgent) -async def test_agent_run(agent: Agent) -> None: +async def test_agent_run(agent: AIAgent) -> None: response = await agent.run("test") assert response.messages[0].role == ChatRole.ASSISTANT assert response.messages[0].text == "Response" -async def test_agent_run_stream(agent: Agent) -> None: +async def test_agent_run_streaming(agent: AIAgent) -> None: async def collect_updates(updates: AsyncIterable[AgentRunResponseUpdate]) -> list[AgentRunResponseUpdate]: return [u async for u in updates] - updates = await collect_updates(agent.run_stream(messages="test")) + updates = await collect_updates(agent.run_streaming(messages="test")) assert len(updates) == 1 assert updates[0].text == "Response" @@ -191,7 +191,7 @@ async def test_chat_client_agent_thread_on_new_messages_in_memory() -> None: def test_chat_client_agent_type(chat_client: ChatClient) -> None: chat_client_agent = ChatClientAgent(chat_client=chat_client) - assert isinstance(chat_client_agent, Agent) + assert isinstance(chat_client_agent, AIAgent) async def test_chat_client_agent_init(chat_client: ChatClient) -> None: @@ -199,8 +199,19 @@ async def test_chat_client_agent_init(chat_client: ChatClient) -> None: agent = ChatClientAgent(chat_client=chat_client, id=agent_id, description="Test") assert agent.id == agent_id - assert agent.name == "UnnamedAgent" + assert agent.name is None assert agent.description == "Test" + assert agent.display_name == agent_id # Display name defaults to id if name is None + + +async def test_chat_client_agent_init_with_name(chat_client: ChatClient) -> None: + agent_id = str(uuid4()) + agent = ChatClientAgent(chat_client=chat_client, id=agent_id, name="Test Agent", description="Test") + + assert agent.id == agent_id + assert agent.name == "Test Agent" + assert agent.description == "Test" + assert agent.display_name == "Test Agent" # Display name is the name if present async def test_chat_client_agent_run(chat_client: ChatClient) -> None: @@ -211,10 +222,10 @@ async def test_chat_client_agent_run(chat_client: ChatClient) -> None: assert result.text == "test response" -async def test_chat_client_agent_run_stream(chat_client: ChatClient) -> None: +async def test_chat_client_agent_run_streaming(chat_client: ChatClient) -> None: agent = ChatClientAgent(chat_client=chat_client) - result = await AgentRunResponse.from_agent_response_generator(agent.run_stream("Hello")) + result = await AgentRunResponse.from_agent_response_generator(agent.run_streaming("Hello")) assert result.text == "test streaming response" @@ -232,19 +243,35 @@ async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatCl message = ChatMessage(role=ChatRole.USER, text="Hello") thread = ChatClientAgentThread(messages=[message]) - result_thread, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] - thread=thread, - input_messages=[ChatMessage(role=ChatRole.USER, text="Test")], - construct_thread=lambda: ChatClientAgentThread(), - expected_type=ChatClientAgentThread, - ) + result_thread = agent._validate_or_create_thread_type( + thread, lambda: ChatClientAgentThread(), expected_type=ChatClientAgentThread + ) # type: ignore[reportPrivateUsage] assert result_thread == thread + assert isinstance(result_thread, ChatClientAgentThread) + + _, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] + thread=result_thread, + input_messages=[ChatMessage(role=ChatRole.USER, text="Test")], + ) + assert len(result_messages) == 2 assert result_messages[0] == message assert result_messages[1].text == "Test" +async def test_chat_client_agent_validate_or_create_thread(chat_client: ChatClient) -> None: + agent = ChatClientAgent(chat_client=chat_client) + thread = None + + result_thread = agent._validate_or_create_thread_type( + thread, lambda: ChatClientAgentThread(), expected_type=ChatClientAgentThread + ) # type: ignore[reportPrivateUsage] + + assert result_thread != thread + assert isinstance(result_thread, ChatClientAgentThread) + + async def test_chat_client_agent_update_thread_id() -> None: chat_client = MockChatClient( mock_response=ChatResponse( diff --git a/python/samples/getting_started/agents/azure_chat_client/azure_chat_client_basic.py b/python/samples/getting_started/agents/azure_chat_client/azure_chat_client_basic.py index 21a8a25b27..80564a8785 100644 --- a/python/samples/getting_started/agents/azure_chat_client/azure_chat_client_basic.py +++ b/python/samples/getting_started/agents/azure_chat_client/azure_chat_client_basic.py @@ -48,7 +48,7 @@ async def streaming_example() -> None: query = "What's the weather like in Portland?" print(f"User: {query}") print("Agent: ", end="", flush=True) - async for chunk in agent.run_stream(query): + async for chunk in agent.run_streaming(query): if chunk.text: print(chunk.text, end="", flush=True) print("\n") diff --git a/python/samples/getting_started/agents/foundry/foundry_basic.py b/python/samples/getting_started/agents/foundry/foundry_basic.py index bfcbed40ba..c1759ebdfc 100644 --- a/python/samples/getting_started/agents/foundry/foundry_basic.py +++ b/python/samples/getting_started/agents/foundry/foundry_basic.py @@ -48,7 +48,7 @@ async def streaming_example() -> None: query = "What's the weather like in Portland?" print(f"User: {query}") print("Agent: ", end="", flush=True) - async for chunk in agent.run_stream(query): + async for chunk in agent.run_streaming(query): if chunk.text: print(chunk.text, end="", flush=True) print("\n") diff --git a/python/samples/getting_started/agents/foundry/foundry_with_code_interpreter.py b/python/samples/getting_started/agents/foundry/foundry_with_code_interpreter.py index 1d11123a5e..9f321a6470 100644 --- a/python/samples/getting_started/agents/foundry/foundry_with_code_interpreter.py +++ b/python/samples/getting_started/agents/foundry/foundry_with_code_interpreter.py @@ -44,7 +44,7 @@ async def main() -> None: print(f"User: {query}") print("Agent: ", end="", flush=True) generated_code = "" - async for chunk in agent.run_stream(query): + async for chunk in agent.run_streaming(query): if chunk.text: print(chunk.text, end="", flush=True) code_interpreter_chunk = get_code_interpreter_chunk(chunk) diff --git a/python/samples/getting_started/agents/openai_chat_client/openai_chat_client_basic.py b/python/samples/getting_started/agents/openai_chat_client/openai_chat_client_basic.py index 01425caba2..41ed787aa5 100644 --- a/python/samples/getting_started/agents/openai_chat_client/openai_chat_client_basic.py +++ b/python/samples/getting_started/agents/openai_chat_client/openai_chat_client_basic.py @@ -46,7 +46,7 @@ async def streaming_example() -> None: query = "What's the weather like in Portland?" print(f"User: {query}") print("Agent: ", end="", flush=True) - async for chunk in agent.run_stream(query): + async for chunk in agent.run_streaming(query): if chunk.text: print(chunk.text, end="", flush=True) print("\n") diff --git a/python/samples/getting_started/agents/openai_responses_client/openai_responses_client_basic.py b/python/samples/getting_started/agents/openai_responses_client/openai_responses_client_basic.py index 6c29260a69..174973fd84 100644 --- a/python/samples/getting_started/agents/openai_responses_client/openai_responses_client_basic.py +++ b/python/samples/getting_started/agents/openai_responses_client/openai_responses_client_basic.py @@ -46,7 +46,7 @@ async def streaming_example() -> None: query = "What's the weather like in Portland?" print(f"User: {query}") print("Agent: ", end="", flush=True) - async for chunk in agent.run_stream(query): + async for chunk in agent.run_streaming(query): if chunk.text: print(chunk.text, end="", flush=True) print("\n") diff --git a/python/shared_tasks.toml b/python/shared_tasks.toml index 9448be94d4..f0f52d99ad 100644 --- a/python/shared_tasks.toml +++ b/python/shared_tasks.toml @@ -2,7 +2,6 @@ fmt = "ruff format" format.ref = "fmt" lint = "ruff check" -mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework" pyright = "pyright" build = "uv build" test = "pytest --cov=agent_framework --cov-report=term-missing:skip-covered tests"