diff --git a/.github/workflows/python-merge-tests.yml b/.github/workflows/python-merge-tests.yml index 819822ae27..9868dc6c80 100644 --- a/.github/workflows/python-merge-tests.yml +++ b/.github/workflows/python-merge-tests.yml @@ -78,11 +78,13 @@ jobs: run: | uv sync --all-packages --all-extras --dev -U --prerelease=if-necessary-or-explicit - name: Test with pytest - run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test --junitxml=coverage.xml + timeout-minutes: 10 + run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test -n logical --dist worksteal --junitxml=coverage.xml working-directory: ./python - name: Test openai samples + - name: Test main samples + timeout-minutes: 10 if: env.RUN_SAMPLES_TESTS == 'true' - run: uv run pytest tests/samples/ -m "openai" --junitxml=coverage_samples_main.xml + run: uv run pytest tests/samples/ -m "openai" working-directory: ./python - name: Move coverage file run: | @@ -144,15 +146,17 @@ jobs: tenant-id: ${{ secrets.AZURE_TENANT_ID }} subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }} - name: Test with pytest - run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test --junitxml=coverage.xml + timeout-minutes: 10 + run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test -n logical --dist worksteal --junitxml=coverage.xml working-directory: ./python - name: Test azure samples + timeout-minutes: 10 if: env.RUN_SAMPLES_TESTS == 'true' - run: uv run pytest tests/samples/ -m "azure" --junitxml=coverage_samples_azure.xml + run: uv run pytest tests/samples/ -m "azure" working-directory: ./python - name: Move coverage file run: | - mv ./packages/${{ env.PACKAGE_NAME }}/coverage.xml coverage_${{ env.PACKAGE_NAME }}.xml + mv ./packages/${{ env.PACKAGE_NAME }}/coverage.xml ./coverage_${{ env.PACKAGE_NAME }}.xml working-directory: ./python - name: Upload coverage artifact uses: actions/upload-artifact@v4 @@ -209,11 +213,13 @@ jobs: tenant-id: ${{ 
secrets.AZURE_TENANT_ID }} subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }} - name: Test with pytest - run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test --junitxml=coverage.xml + timeout-minutes: 10 + run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test -n logical --dist worksteal --junitxml=coverage.xml working-directory: ./python - name: Test foundry samples + timeout-minutes: 10 if: env.RUN_SAMPLES_TESTS == 'true' - run: uv run pytest tests/samples/ -m "foundry" --junitxml=coverage_samples_foundry.xml + run: uv run pytest tests/samples/ -m "foundry" working-directory: ./python - name: Move coverage file run: | diff --git a/.github/workflows/python-test-coverage.yml b/.github/workflows/python-test-coverage.yml index 59c8b4b4bf..08ffe47158 100644 --- a/.github/workflows/python-test-coverage.yml +++ b/.github/workflows/python-test-coverage.yml @@ -36,7 +36,7 @@ jobs: - name: Install the project run: uv sync --all-extras --dev - name: Run all tests with coverage report - run: uv run poe all-tests --cov-report=xml:python-coverage.xml -q --junitxml=pytest.xml + run: uv run poe all-tests -n logical --dist worksteal --cov-report=xml:python-coverage.xml -q --junitxml=pytest.xml - name: Upload coverage report uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 6552b7829c..5981f27757 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -48,7 +48,7 @@ jobs: run: | echo "PACKAGE_NAME=main" >> $GITHUB_ENV - name: Test with pytest - main - run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test --junitxml=coverage.xml + run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test -n logical --dist worksteal --junitxml=coverage.xml working-directory: ./python - name: Move coverage file - main run: | @@ -70,7 +70,7 @@ jobs: run: | echo "PACKAGE_NAME=azure" >> 
$GITHUB_ENV - name: Test with pytest - azure - run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test --junitxml=coverage.xml + run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test -n logical --dist worksteal --junitxml=coverage.xml working-directory: ./python - name: Move coverage file - azure run: | @@ -92,7 +92,7 @@ jobs: run: | echo "PACKAGE_NAME=foundry" >> $GITHUB_ENV - name: Test with pytest - foundry - run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test --junitxml=coverage.xml + run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test -n logical --dist worksteal --junitxml=coverage.xml working-directory: ./python - name: Move coverage file - foundry run: | @@ -114,7 +114,7 @@ jobs: run: | echo "PACKAGE_NAME=workflow" >> $GITHUB_ENV - name: Test with pytest - workflow - run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test --junitxml=coverage.xml + run: uv run poe --directory ./packages/${{ env.PACKAGE_NAME }} test -n logical --dist worksteal --junitxml=coverage.xml working-directory: ./python - name: Move coverage file - workflow run: | diff --git a/python/packages/azure/tests/test_azure_assistants_client.py b/python/packages/azure/tests/test_azure_assistants_client.py index abc128151a..7bd2f6571b 100644 --- a/python/packages/azure/tests/test_azure_assistants_client.py +++ b/python/packages/azure/tests/test_azure_assistants_client.py @@ -16,7 +16,6 @@ from agent_framework import ( ChatResponseUpdate, HostedCodeInterpreterTool, TextContent, - ai_function, ) from agent_framework.exceptions import ServiceInitializationError from azure.identity import AzureCliCredential @@ -543,45 +542,6 @@ async def test_azure_assistants_client_agent_level_tool_persistence(): assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"]) -@skip_if_azure_integration_tests_disabled -async def 
test_azure_assistants_client_run_level_tool_isolation(): - """Test that run-level tools are isolated to specific runs and don't persist with Azure Assistants Client.""" - # Counter to track how many times the weather tool is called - call_count = 0 - - @ai_function - async def get_weather_with_counter(location: Annotated[str, "The location as a city name"]) -> str: - """Get the current weather in a given location.""" - nonlocal call_count - call_count += 1 - return f"The weather in {location} is sunny and 72°F." - - async with ChatAgent( - chat_client=AzureAssistantsClient(credential=AzureCliCredential()), - instructions="You are a helpful assistant.", - ) as agent: - # First run - use run-level tool - first_response = await agent.run( - "What's the weather like in Chicago?", - tools=[get_weather_with_counter], # Run-level tool - ) - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - # Should use the run-level weather tool (call count should be 1) - assert call_count == 1 - assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) - - # Second run - run-level tool should NOT persist (key isolation test) - second_response = await agent.run("What's the weather like in Miami?") - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - # Should NOT use the weather tool since it was only run-level in previous call - # Call count should still be 1 (no additional calls) - assert call_count == 1 - - def test_azure_assistants_client_entra_id_authentication() -> None: """Test Entra ID authentication path with credential.""" mock_credential = MagicMock() diff --git a/python/packages/azure/tests/test_azure_chat_client.py b/python/packages/azure/tests/test_azure_chat_client.py index 60417581fb..b4669d1ec0 100644 --- a/python/packages/azure/tests/test_azure_chat_client.py +++ b/python/packages/azure/tests/test_azure_chat_client.py @@ -2,7 +2,6 @@ import json 
import os -from typing import Annotated from unittest.mock import AsyncMock, MagicMock, patch import openai @@ -834,42 +833,3 @@ async def test_azure_chat_client_agent_level_tool_persistence(): assert second_response.text is not None # Should use the agent-level weather tool again assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"]) - - -@skip_if_azure_integration_tests_disabled -async def test_azure_chat_client_run_level_tool_isolation(): - """Test that run-level tools are isolated to specific runs and don't persist with Azure Chat Client.""" - # Counter to track how many times the weather tool is called - call_count = 0 - - @ai_function - async def get_weather_with_counter(location: Annotated[str, "The location as a city name"]) -> str: - """Get the current weather in a given location.""" - nonlocal call_count - call_count += 1 - return f"The weather in {location} is sunny and 72°F." - - async with ChatAgent( - chat_client=AzureChatClient(credential=AzureCliCredential()), - instructions="You are a helpful assistant.", - ) as agent: - # First run - use run-level tool - first_response = await agent.run( - "What's the weather like in Chicago?", - tools=[get_weather_with_counter], # Run-level tool - ) - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - # Should use the run-level weather tool (call count should be 1) - assert call_count == 1 - assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) - - # Second run - run-level tool should NOT persist (key isolation test) - second_response = await agent.run("What's the weather like in Miami?") - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - # Should NOT use the weather tool since it was only run-level in previous call - # Call count should still be 1 (no additional calls) - assert call_count == 1 diff --git 
a/python/packages/azure/tests/test_azure_responses_client.py b/python/packages/azure/tests/test_azure_responses_client.py index dcbc41e839..2e02989478 100644 --- a/python/packages/azure/tests/test_azure_responses_client.py +++ b/python/packages/azure/tests/test_azure_responses_client.py @@ -459,42 +459,3 @@ async def test_azure_responses_client_agent_level_tool_persistence(): assert second_response.text is not None # Should use the agent-level weather tool again assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"]) - - -@skip_if_azure_integration_tests_disabled -async def test_azure_responses_client_run_level_tool_isolation(): - """Test that run-level tools are isolated to specific runs and don't persist with Azure Responses Client.""" - # Counter to track how many times the weather tool is called - call_count = 0 - - @ai_function - async def get_weather_with_counter(location: Annotated[str, "The location as a city name"]) -> str: - """Get the current weather in a given location.""" - nonlocal call_count - call_count += 1 - return f"The weather in {location} is sunny and 72°F." 
- - async with ChatAgent( - chat_client=AzureResponsesClient(credential=AzureCliCredential()), - instructions="You are a helpful assistant.", - ) as agent: - # First run - use run-level tool - first_response = await agent.run( - "What's the weather like in Chicago?", - tools=[get_weather_with_counter], # Run-level tool - ) - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - # Should use the run-level weather tool (call count should be 1) - assert call_count == 1 - assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) - - # Second run - run-level tool should NOT persist (key isolation test) - second_response = await agent.run("What's the weather like in Miami?") - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - # Should NOT use the weather tool since it was only run-level in previous call - # Call count should still be 1 (no additional calls) - assert call_count == 1 diff --git a/python/packages/foundry/tests/test_foundry_chat_client.py b/python/packages/foundry/tests/test_foundry_chat_client.py index 967b178ede..b77efe5954 100644 --- a/python/packages/foundry/tests/test_foundry_chat_client.py +++ b/python/packages/foundry/tests/test_foundry_chat_client.py @@ -22,7 +22,6 @@ from agent_framework import ( Role, TextContent, UriContent, - ai_function, ) from agent_framework import __version__ as AF_VERSION from agent_framework.exceptions import ServiceInitializationError @@ -933,42 +932,3 @@ async def test_foundry_chat_client_agent_level_tool_persistence(): assert second_response.text is not None # Should use the agent-level weather tool again assert any(term in second_response.text.lower() for term in ["miami", "sunny", "25"]) - - -@skip_if_foundry_integration_tests_disabled -async def test_foundry_chat_client_run_level_tool_isolation(): - """Test that run-level tools are isolated to specific runs and don't persist with 
FoundryChatClient.""" - # Counter to track how many times the weather tool is called - call_count = 0 - - @ai_function - async def get_weather_with_counter(location: Annotated[str, "The location as a city name"]) -> str: - """Get the current weather in a given location.""" - nonlocal call_count - call_count += 1 - return f"The weather in {location} is sunny and 25°C." - - async with ChatAgent( - chat_client=FoundryChatClient(async_credential=AzureCliCredential()), - instructions="You are a helpful assistant.", - ) as agent: - # First run - use run-level tool - first_response = await agent.run( - "What's the weather like in Chicago?", - tools=[get_weather_with_counter], # Run-level tool - ) - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - # Should use the run-level weather tool (call count should be 1) - assert call_count == 1 - assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "25"]) - - # Second run - run-level tool should NOT persist (key isolation test) - second_response = await agent.run("What's the weather like in Miami?") - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - # Should NOT use the weather tool since it was only run-level in previous call - # Call count should still be 1 (no additional calls) - assert call_count == 1 diff --git a/python/packages/main/agent_framework/_clients.py b/python/packages/main/agent_framework/_clients.py index 4ad1119162..f80d1f5dc1 100644 --- a/python/packages/main/agent_framework/_clients.py +++ b/python/packages/main/agent_framework/_clients.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, runt from pydantic import BaseModel from ._logging import get_logger +from ._mcp import MCPTool from ._pydantic import AFBaseModel from ._threads import ChatMessageStore from ._tools import AIFunction, ToolProtocol @@ -391,6 +392,25 @@ class 
BaseChatClient(AFBaseModel, ABC): return_messages.append(msg) return return_messages + @staticmethod + def _normalize_tools( + tools: ToolProtocol + | MutableMapping[str, Any] + | Callable[..., Any] + | list[ToolProtocol | MutableMapping[str, Any] | Callable[..., Any]] + | None = None, + ) -> list[ToolProtocol | dict[str, Any] | Callable[..., Any]]: + """Normalize the tools input to a list of tools.""" + final_tools: list[ToolProtocol | dict[str, Any] | Callable[..., Any]] = [] + if not tools: + return final_tools + for tool in tools if isinstance(tools, list) else [tools]: # type: ignore[reportUnknownType] + if isinstance(tool, MCPTool): + final_tools.extend(tool.functions) # type: ignore + continue + final_tools.append(tool) # type: ignore + return final_tools + # region Internal methods to be implemented by the derived classes @abstractmethod @@ -513,7 +533,7 @@ class BaseChatClient(AFBaseModel, ABC): temperature=temperature, top_p=top_p, tool_choice=tool_choice, - tools=tools, # type: ignore + tools=self._normalize_tools(tools), # type: ignore user=user, additional_properties=additional_properties or {}, ) @@ -592,7 +612,7 @@ class BaseChatClient(AFBaseModel, ABC): temperature=temperature, top_p=top_p, tool_choice=tool_choice, - tools=tools, # type: ignore + tools=self._normalize_tools(tools), # type: ignore user=user, additional_properties=additional_properties or {}, **kwargs, diff --git a/python/packages/main/agent_framework/_mcp.py b/python/packages/main/agent_framework/_mcp.py index 27a4c3d5af..180aecfc95 100644 --- a/python/packages/main/agent_framework/_mcp.py +++ b/python/packages/main/agent_framework/_mcp.py @@ -12,7 +12,6 @@ from typing import TYPE_CHECKING, Any from mcp import types from mcp.client.session import ClientSession -from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.client.streamable_http import streamablehttp_client from mcp.client.websocket import websocket_client @@ -49,7 
+48,6 @@ LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = { } __all__ = [ - "MCPSseTools", "MCPStdioTool", "MCPStreamableHTTPTool", "MCPWebsocketTool", @@ -224,7 +222,7 @@ def _normalize_mcp_name(name: str) -> str: class MCPTool: - """Base class with the MCP logic.""" + """Main MCP class, to initialize use one of the subclasses.""" def __init__( self, @@ -567,82 +565,6 @@ class MCPStdioTool(MCPTool): return stdio_client(server=StdioServerParameters(**args)) -class MCPSseTools(MCPTool): - """MCP sse server configuration.""" - - def __init__( - self, - name: str, - url: str, - *, - load_tools: bool = True, - load_prompts: bool = True, - request_timeout: int | None = None, - session: ClientSession | None = None, - description: str | None = None, - additional_properties: dict[str, Any] | None = None, - headers: dict[str, Any] | None = None, - timeout: float | None = None, - sse_read_timeout: float | None = None, - chat_client: "ChatClientProtocol | None" = None, - **kwargs: Any, - ) -> None: - """Initialize the MCP sse plugin. - - The arguments are used to create a sse client. - see mcp.client.sse.sse_client for more details. - - Any extra arguments passed to the constructor will be passed to the - sse client constructor. - - Args: - name: The name of the plugin. - url: The URL of the MCP server. - load_tools: Whether to load tools from the MCP server. - load_prompts: Whether to load prompts from the MCP server. - request_timeout: The default timeout used for all requests. - session: The session to use for the MCP connection. - description: The description of the plugin. - additional_properties: Additional properties. - headers: The headers to send with the request. - timeout: The timeout for the request. - sse_read_timeout: The timeout for reading from the SSE stream. - chat_client: The chat client to use for sampling. - kwargs: Any extra arguments to pass to the sse client. 
- - """ - super().__init__( - name=name, - description=description, - additional_properties=additional_properties, - session=session, - chat_client=chat_client, - load_tools=load_tools, - load_prompts=load_prompts, - request_timeout=request_timeout, - ) - self.url = url - self.headers = headers or {} - self.timeout = timeout - self.sse_read_timeout = sse_read_timeout - self._client_kwargs = kwargs - - def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - """Get an MCP SSE client.""" - args: dict[str, Any] = { - "url": self.url, - } - if self.headers: - args["headers"] = self.headers - if self.timeout is not None: - args["timeout"] = self.timeout - if self.sse_read_timeout is not None: - args["sse_read_timeout"] = self.sse_read_timeout - if self._client_kwargs: - args.update(self._client_kwargs) - return sse_client(**args) - - class MCPStreamableHTTPTool(MCPTool): """MCP streamable http server configuration.""" diff --git a/python/packages/main/agent_framework/_tools.py b/python/packages/main/agent_framework/_tools.py index 5bd9566cb5..4b12835d43 100644 --- a/python/packages/main/agent_framework/_tools.py +++ b/python/packages/main/agent_framework/_tools.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. 
- import inspect -from collections.abc import Awaitable, Callable +import sys +from collections.abc import Awaitable, Callable, Collection from functools import wraps from time import perf_counter from typing import ( @@ -9,6 +9,7 @@ from typing import ( Annotated, Any, Generic, + Literal, Protocol, TypeVar, get_args, @@ -17,15 +18,21 @@ from typing import ( ) from opentelemetry import metrics, trace -from pydantic import BaseModel, Field, PrivateAttr, create_model +from pydantic import AnyUrl, BaseModel, Field, PrivateAttr, ValidationError, create_model, field_validator from ._logging import get_logger from ._pydantic import AFBaseModel +from .exceptions import ToolException from .telemetry import GenAIAttributes, start_as_current_span if TYPE_CHECKING: from ._types import Contents +if sys.version_info >= (3, 12): + from typing import TypedDict # pragma: no cover +else: + from typing_extensions import TypedDict # pragma: no cover + tracer: trace.Tracer = trace.get_tracer("agent_framework") meter: metrics.Meter = metrics.get_meter_provider().get_meter("agent_framework") logger = get_logger() @@ -34,6 +41,8 @@ __all__ = [ "AIFunction", "HostedCodeInterpreterTool", "HostedFileSearchTool", + "HostedMCPSpecificApproval", + "HostedMCPTool", "HostedWebSearchTool", "ToolProtocol", "ai_function", @@ -197,13 +206,88 @@ class HostedWebSearchTool(BaseTool): args: dict[str, Any] = { "name": "web_search", } + super().__init__(**args, **kwargs) + + +class HostedMCPSpecificApproval(TypedDict, total=False): + """Represents the `specific` mode for a hosted tool. + + When using this mode, the user must specify which tools always or never require approval. + This is represented as a dictionary with two optional keys: + - `always_require_approval`: A sequence of tool names that always require approval. + - `never_require_approval`: A sequence of tool names that never require approval. 
+ + """ + + always_require_approval: Collection[str] | None + never_require_approval: Collection[str] | None + + +class HostedMCPTool(BaseTool): + """Represents an MCP tool that is managed and executed by the service.""" + + url: AnyUrl + approval_mode: Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None = None + allowed_tools: set[str] | None = None + headers: dict[str, str] | None = None + + def __init__( + self, + *, + name: str, + description: str | None = None, + url: AnyUrl | str, + approval_mode: Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None = None, + allowed_tools: Collection[str] | None = None, + headers: dict[str, str] | None = None, + additional_properties: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Create a hosted MCP tool. + + Args: + name: The name of the tool. + description: A description of the tool. + url: The URL of the tool. + approval_mode: The approval mode for the tool. This can be: + - "always_require": The tool always requires approval before use. + - "never_require": The tool never requires approval before use. + - A dict with keys `always_require_approval` or `never_require_approval`, + followed by a sequence of strings with the names of the relevant tools. + allowed_tools: A collection of tool names that are allowed to be used from this MCP server. + headers: Headers to include in requests to the tool. + additional_properties: Additional properties to include in the tool definition. + **kwargs: Additional keyword arguments to pass to the base class. 
+ """ + args: dict[str, Any] = { + "name": name, + "url": url, + } + if allowed_tools is not None: + args["allowed_tools"] = allowed_tools + if approval_mode is not None: + args["approval_mode"] = approval_mode + if headers is not None: + args["headers"] = headers if description is not None: args["description"] = description if additional_properties is not None: args["additional_properties"] = additional_properties - if "name" in kwargs: - raise ValueError("The 'name' argument is reserved for the HostedFileSearchTool and cannot be set.") - super().__init__(**args, **kwargs) + try: + super().__init__(**args, **kwargs) + except ValidationError as err: + raise ToolException(f"Error initializing HostedMCPTool: {err}", inner_exception=err) from err + + @field_validator("approval_mode") + def validate_approval_mode(cls, approval_mode: str | dict[str, Any] | None) -> str | dict[str, Any] | None: + """Validate the approval_mode field to ensure it is one of the accepted values.""" + if approval_mode is None or not isinstance(approval_mode, dict): + return approval_mode + # Normalize the dict values to sets; None is a declared value and is left as-is + for key, value in approval_mode.items(): + if value is not None and not isinstance(value, set): + approval_mode[key] = set(value) # Convert to set if it's a list or other collection + return approval_mode class HostedFileSearchTool(BaseTool): diff --git a/python/packages/main/agent_framework/_types.py b/python/packages/main/agent_framework/_types.py index c80d01dc60..60108d982a 100644 --- a/python/packages/main/agent_framework/_types.py +++ b/python/packages/main/agent_framework/_types.py @@ -29,7 +29,7 @@ from pydantic import ( from ._logging import get_logger from ._pydantic import AFBaseModel from ._tools import ToolProtocol, ai_function -from .exceptions import AgentFrameworkException +from .exceptions import AdditionItemMismatch if sys.version_info >= (3, 11): from typing import Self # pragma: no cover @@ -55,6 +55,7 @@ 
KNOWN_MEDIA_TYPES = [ "application/pdf", "application/xml", 
"audio/mpeg", + "audio/mp3", "audio/ogg", "audio/wav", "image/apng", @@ -93,6 +94,8 @@ __all__ = [ "DataContent", "ErrorContent", "FinishReason", + "FunctionApprovalRequestContent", + "FunctionApprovalResponseContent", "FunctionCallContent", "FunctionResultContent", "GeneratedEmbeddings", @@ -224,7 +227,11 @@ def _process_update( is_new_message = False if ( not response.messages - or (update.message_id and response.messages[-1].message_id != update.message_id) + or ( + update.message_id + and response.messages[-1].message_id + and response.messages[-1].message_id != update.message_id + ) or (update.role and response.messages[-1].role != update.role) ): is_new_message = True @@ -249,7 +256,7 @@ def _process_update( ): try: message.contents[-1] += content - except AgentFrameworkException: + except AdditionItemMismatch: message.contents.append(content) elif isinstance(content, UsageContent): if response.usage_details is None: @@ -718,7 +725,7 @@ class DataContent(BaseContent): raise ValueError(f"Unknown media type: {media_type}") return uri - def has_top_level_media_type(self, top_level_media_type: str) -> bool: + def has_top_level_media_type(self, top_level_media_type: Literal["application", "audio", "image", "text"]) -> bool: return _has_top_level_media_type(self.media_type, top_level_media_type) @@ -776,11 +783,13 @@ class UriContent(BaseContent): **kwargs, ) - def has_top_level_media_type(self, top_level_media_type: str) -> bool: + def has_top_level_media_type(self, top_level_media_type: Literal["application", "audio", "image", "text"]) -> bool: return _has_top_level_media_type(self.media_type, top_level_media_type) -def _has_top_level_media_type(media_type: str | None, top_level_media_type: str) -> bool: +def _has_top_level_media_type( + media_type: str | None, top_level_media_type: Literal["application", "audio", "image", "text"] +) -> bool: if media_type is None: return False @@ -924,7 +933,7 @@ class FunctionCallContent(BaseContent): if not isinstance(other, 
FunctionCallContent): raise TypeError("Incompatible type") if other.call_id and self.call_id != other.call_id: - raise AgentFrameworkException("Incompatible function call contents") + raise AdditionItemMismatch if not self.arguments: arguments = other.arguments elif not other.arguments: @@ -1093,6 +1102,110 @@ class HostedVectorStoreContent(BaseContent): ) +class BaseUserInputRequest(BaseContent): + """Base class for all user requests.""" + + type: Literal["user_input_request"] = "user_input_request" # type: ignore[assignment] + id: Annotated[str, Field(..., min_length=1)] + + +class BaseUserInputResponse(BaseContent): + """Base class for all user responses.""" + + type: Literal["user_input_response"] = "user_input_response" # type: ignore[assignment] + id: Annotated[str, Field(..., min_length=1)] + + +class FunctionApprovalResponseContent(BaseUserInputResponse): + """Represents a response for user approval of a function call.""" + + type: Literal["function_approval_response"] = "function_approval_response" # type: ignore[assignment] + approved: bool + function_call: FunctionCallContent + + def __init__( + self, + approved: bool, + *, + id: str, + function_call: FunctionCallContent, + annotations: list[Annotations] | None = None, + additional_properties: dict[str, Any] | None = None, + raw_representation: Any | None = None, + **kwargs: Any, + ) -> None: + """Initializes a FunctionApprovalResponseContent instance. + + Args: + approved: Whether the function call was approved. + id: The unique identifier for the request. + function_call: The function call content to be approved. + annotations: Optional list of annotations for the request. + additional_properties: Optional additional properties for the request. + raw_representation: Optional raw representation of the request. + **kwargs: Additional keyword arguments. 
+ """ + super().__init__( + approved=approved, # type: ignore[reportCallIssue] + id=id, # type: ignore[reportCallIssue] + function_call=function_call, # type: ignore[reportCallIssue] + annotations=annotations, + additional_properties=additional_properties, + raw_representation=raw_representation, + **kwargs, + ) + + +class FunctionApprovalRequestContent(BaseUserInputRequest): + """Represents a request for user approval of a function call.""" + + type: Literal["function_approval_request"] = "function_approval_request" # type: ignore[assignment] + function_call: FunctionCallContent + + def __init__( + self, + *, + id: str, + function_call: FunctionCallContent, + annotations: list[Annotations] | None = None, + additional_properties: dict[str, Any] | None = None, + raw_representation: Any | None = None, + **kwargs: Any, + ) -> None: + """Initializes a FunctionApprovalRequestContent instance. + + Args: + id: The unique identifier for the request. + function_call: The function call content to be approved. + annotations: Optional list of annotations for the request. + additional_properties: Optional additional properties for the request. + raw_representation: Optional raw representation of the request. + **kwargs: Additional keyword arguments. 
+ """ + super().__init__( + id=id, # type: ignore[reportCallIssue] + function_call=function_call, # type: ignore[reportCallIssue] + annotations=annotations, + additional_properties=additional_properties, + raw_representation=raw_representation, + **kwargs, + ) + + def create_response(self, approved: bool) -> "FunctionApprovalResponseContent": + """Create a response for the function approval request.""" + return FunctionApprovalResponseContent( + approved, + id=self.id, + function_call=self.function_call, + additional_properties=self.additional_properties, + ) + + +UserInputRequestContents = Annotated[ + FunctionApprovalRequestContent, + Field(discriminator="type"), +] + Contents = Annotated[ TextContent | DataContent @@ -1103,7 +1216,9 @@ Contents = Annotated[ | ErrorContent | UsageContent | HostedFileContent - | HostedVectorStoreContent, + | HostedVectorStoreContent + | FunctionApprovalRequestContent + | FunctionApprovalResponseContent, Field(discriminator="type"), ] @@ -1957,6 +2072,13 @@ class AgentRunResponse(AFBaseModel): """Get the concatenated text of all messages.""" return "".join(msg.text for msg in self.messages) if self.messages else "" + @property + def user_input_requests(self) -> list[UserInputRequestContents]: + """Get all BaseUserInputRequest messages from the response.""" + return [ + content for msg in self.messages for content in msg.contents if isinstance(content, BaseUserInputRequest) + ] + @classmethod def from_agent_run_response_updates( cls: type[TAgentRunResponse], updates: Sequence["AgentRunResponseUpdate"] @@ -2007,6 +2129,11 @@ class AgentRunResponseUpdate(AFBaseModel): else "" ) + @property + def user_input_requests(self) -> list[UserInputRequestContents]: + """Get all BaseUserInputRequest messages from the response.""" + return [content for content in self.contents if isinstance(content, BaseUserInputRequest)] + def __str__(self) -> str: return self.text @@ -2082,9 +2209,3 @@ class TextToSpeechOptions(AFBaseModel): for key in 
merged_exclude: settings.pop(key, None) return settings - - -# endregion - - -# endregion diff --git a/python/packages/main/agent_framework/exceptions.py b/python/packages/main/agent_framework/exceptions.py index 66a2c14033..5e6a28b1e2 100644 --- a/python/packages/main/agent_framework/exceptions.py +++ b/python/packages/main/agent_framework/exceptions.py @@ -12,13 +12,15 @@ class AgentFrameworkException(Exception): Automatically logs the message as debug. """ - def __init__(self, message: str, inner_exception: Exception | None = None, *args: Any, **kwargs: Any): + def __init__(self, message: str, inner_exception: Exception | None = None, *args: Any): """Create an AgentFrameworkException. This emits a debug log, with the inner_exception if provided. """ logger.debug(message, exc_info=inner_exception) - super().__init__(message, *args, **kwargs) # type: ignore + if inner_exception: + super().__init__(message, inner_exception, *args) # type: ignore + super().__init__(message, *args) # type: ignore class AgentException(AgentFrameworkException): @@ -94,3 +96,14 @@ class ToolExecutionException(ToolException): """An error occurred while executing a tool.""" pass + + +class AdditionItemMismatch(AgentFrameworkException): + """An error occurred while adding two types.""" + + def __init__(self) -> None: + """Create an AdditionItemMismatch. 
+ + Unlike the AgentFrameworkException, this does not log the message automatically, + """ + pass diff --git a/python/packages/main/agent_framework/openai/_responses_client.py b/python/packages/main/agent_framework/openai/_responses_client.py index 9abfc5772b..f684dcc31e 100644 --- a/python/packages/main/agent_framework/openai/_responses_client.py +++ b/python/packages/main/agent_framework/openai/_responses_client.py @@ -13,18 +13,12 @@ from openai.types.responses.parsed_response import ( ParsedResponse, ) from openai.types.responses.response import Response as OpenAIResponse -from openai.types.responses.response_completed_event import ResponseCompletedEvent -from openai.types.responses.response_content_part_added_event import ResponseContentPartAddedEvent -from openai.types.responses.response_function_call_arguments_delta_event import ResponseFunctionCallArgumentsDeltaEvent -from openai.types.responses.response_output_item_added_event import ResponseOutputItemAddedEvent -from openai.types.responses.response_output_refusal import ResponseOutputRefusal -from openai.types.responses.response_output_text import ResponseOutputText from openai.types.responses.response_stream_event import ResponseStreamEvent as OpenAIResponseStreamEvent -from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent from openai.types.responses.response_usage import ResponseUsage from openai.types.responses.tool_param import ( CodeInterpreter, CodeInterpreterContainerCodeInterpreterToolAuto, + Mcp, ToolParam, ) from openai.types.responses.web_search_tool_param import UserLocation as WebSearchUserLocation @@ -33,7 +27,14 @@ from pydantic import BaseModel, SecretStr, ValidationError from .._clients import BaseChatClient, use_tool_calling from .._logging import get_logger -from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, HostedWebSearchTool, ToolProtocol +from .._tools import ( + AIFunction, + HostedCodeInterpreterTool, + 
HostedFileSearchTool, + HostedMCPTool, + HostedWebSearchTool, + ToolProtocol, +) from .._types import ( ChatMessage, ChatOptions, @@ -42,6 +43,8 @@ from .._types import ( CitationAnnotation, Contents, DataContent, + FunctionApprovalRequestContent, + FunctionApprovalResponseContent, FunctionCallContent, FunctionResultContent, HostedFileContent, @@ -364,15 +367,41 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient): # region Prep methods - def _chat_to_response_tool_spec( + def _tools_to_response_tools( self, tools: list[ToolProtocol | MutableMapping[str, Any]] ) -> list[ToolParam | dict[str, Any]]: response_tools: list[ToolParam | dict[str, Any]] = [] for tool in tools: if isinstance(tool, ToolProtocol): match tool: + case HostedMCPTool(): + mcp: Mcp = { + "type": "mcp", + "server_label": tool.name.replace(" ", "_"), + "server_url": str(tool.url), + "server_description": tool.description, + "headers": tool.headers, + } + if tool.allowed_tools: + mcp["allowed_tools"] = list(tool.allowed_tools) + if tool.approval_mode: + match tool.approval_mode: + case str(): + mcp["require_approval"] = ( + "always" if tool.approval_mode == "always_require" else "never" + ) + case _: + if always_require_approvals := tool.approval_mode.get("always_require_approval"): + mcp["require_approval"] = { + "always": {"tool_names": list(always_require_approvals)} + } + if never_require_approvals := tool.approval_mode.get("never_require_approval"): + mcp["require_approval"] = { + "never": {"tool_names": list(never_require_approvals)} + } + response_tools.append(mcp) case HostedCodeInterpreterTool(): - tool_args: dict[str, Any] = {"type": "auto"} + tool_args: CodeInterpreterContainerCodeInterpreterToolAuto = {"type": "auto"} if tool.inputs: tool_args["file_ids"] = [] for tool_input in tool.inputs: @@ -383,7 +412,7 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient): response_tools.append( CodeInterpreter( type="code_interpreter", - 
container=CodeInterpreterContainerCodeInterpreterToolAuto(**tool_args), # type: ignore[typeddict-item] + container=tool_args, ) ) case AIFunction(): @@ -455,7 +484,7 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient): if chat_options.tools is None: options_dict.pop("parallel_tool_calls", None) else: - options_dict["tools"] = self._chat_to_response_tool_spec(chat_options.tools) + options_dict["tools"] = self._tools_to_response_tools(chat_options.tools) # other settings if "store" not in options_dict: options_dict["store"] = False @@ -496,6 +525,137 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient): # Flatten the list of lists into a single list return list(chain.from_iterable(list_of_list)) + def _openai_chat_message_parser( + self, + message: ChatMessage, + call_id_to_id: dict[str, str], + ) -> list[dict[str, Any]]: + """Parse a chat message into the openai format.""" + all_messages: list[dict[str, Any]] = [] + args: dict[str, Any] = { + "role": message.role.value if isinstance(message.role, Role) else message.role, + } + if message.additional_properties: + args["metadata"] = message.additional_properties + for content in message.contents: + match content: + case FunctionResultContent(): + new_args: dict[str, Any] = {} + new_args.update(self._openai_content_parser(message.role, content, call_id_to_id)) + all_messages.append(new_args) + case FunctionCallContent(): + function_call = self._openai_content_parser(message.role, content, call_id_to_id) + all_messages.append(function_call) # type: ignore + case FunctionApprovalResponseContent() | FunctionApprovalRequestContent(): + all_messages.append(self._openai_content_parser(message.role, content, call_id_to_id)) # type: ignore + case _: + if "content" not in args: + args["content"] = [] + args["content"].append(self._openai_content_parser(message.role, content, call_id_to_id)) # type: ignore + if "content" in args or "tool_calls" in args: + all_messages.append(args) + return all_messages + + 
def _openai_content_parser( + self, + role: Role, + content: Contents, + call_id_to_id: dict[str, str], + ) -> dict[str, Any]: + """Parse contents into the openai format.""" + match content: + case TextContent(): + return { + "type": "output_text" if role == Role.ASSISTANT else "input_text", + "text": content.text, + } + case TextReasoningContent(): + ret: dict[str, Any] = { + "type": "reasoning", + "summary": { + "type": "summary_text", + "text": content.text, + }, + } + if content.additional_properties is not None: + if status := content.additional_properties.get("status"): + ret["status"] = status + if reasoning_text := content.additional_properties.get("reasoning_text"): + ret["content"] = {"type": "reasoning_text", "text": reasoning_text} + if encrypted_content := content.additional_properties.get("encrypted_content"): + ret["encrypted_content"] = encrypted_content + return ret + case DataContent() | UriContent(): + if content.has_top_level_media_type("image"): + return { + "type": "input_image", + "image_url": content.uri, + "detail": content.additional_properties.get("detail", "auto") + if content.additional_properties + else "auto", + "file_id": content.additional_properties.get("file_id", None) + if content.additional_properties + else None, + } + if content.has_top_level_media_type("audio"): + if content.media_type and "wav" in content.media_type: + format = "wav" + elif content.media_type and "mp3" in content.media_type: + format = "mp3" + else: + logger.warning("Unsupported audio media type: %s", content.media_type) + return {} + return { + "type": "input_audio", + "input_audio": { + "data": content.uri, + "format": format, + }, + } + return {} + case FunctionCallContent(): + return { + "call_id": content.call_id, + "id": call_id_to_id[content.call_id], + "type": "function_call", + "name": content.name, + "arguments": content.arguments, + } + case FunctionResultContent(): + # call_id for the result needs to be the same as the call_id for the function 
call + args: dict[str, Any] = { + "call_id": content.call_id, + "id": call_id_to_id.get(content.call_id), + "type": "function_call_output", + } + if content.result: + args["output"] = prepare_function_call_results(content.result) + return args + case FunctionApprovalRequestContent(): + return { + "type": "mcp_approval_request", + "id": content.id, + "arguments": content.function_call.arguments, + "name": content.function_call.name, + "server_label": content.function_call.additional_properties.get("server_label") + if content.function_call.additional_properties + else None, + } + case FunctionApprovalResponseContent(): + return { + "type": "mcp_approval_response", + "approval_request_id": content.id, + "approve": content.approved, + } + case HostedFileContent(): + return { + "type": "input_file", + "file_id": content.file_id, + } + case _: # should catch UsageDetails and ErrorContent and HostedVectorStoreContent + logger.debug("Unsupported content type passed (type: %s)", type(content)) + return {} + # region Response creation methods def _create_response_content( @@ -533,7 +693,8 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient): match message_content.type: case "output_text": text_content = TextContent( - text=message_content.text, raw_representation=message_content + text=message_content.text, + raw_representation=message_content, # type: ignore[reportUnknownArgumentType] ) metadata.update(self._get_metadata_from_response(message_content)) if message_content.annotations: @@ -639,6 +800,19 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient): raw_representation=item, ) ) + case "mcp_approval_request": # ResponseOutputMcpApprovalRequest + contents.append( + FunctionApprovalRequestContent( + id=item.id, + function_call=FunctionCallContent( + call_id=item.id, + name=item.name, + arguments=item.arguments, + additional_properties={"server_label": item.server_label}, + raw_representation=item, + ), + ) + ) case "image_generation_call": # 
ResponseOutputImageGenerationCall if item.result: contents.append( @@ -649,7 +823,7 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient): ) # TODO(peterychang): Add support for other content types case _: - logger.debug("Unparsed content of type: %s: %s", item.type, item) + logger.debug("Unparsed output of type: %s: %s", item.type, item) response_message = ChatMessage(role="assistant", contents=contents) args: dict[str, Any] = { "response_id": response.id, @@ -677,35 +851,151 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient): ) -> ChatResponseUpdate: """Create a streaming chat message content object from a choice.""" metadata: dict[str, Any] = {} - items: list[Contents] = [] + contents: list[Contents] = [] conversation_id: str | None = None model = self.ai_model_id # TODO(peterychang): Add support for other content types - match event: - case ResponseContentPartAddedEvent(): - match event.part: - case ResponseOutputText(): - items.append(TextContent(text=event.part.text, raw_representation=event)) - metadata.update(self._get_metadata_from_response(event.part)) - case ResponseOutputRefusal(): - items.append(TextContent(text=event.part.refusal, raw_representation=event)) - case ResponseTextDeltaEvent(): - items.append(TextContent(text=event.delta, raw_representation=event)) + match event.type: + # types: + # ResponseAudioDeltaEvent, + # ResponseAudioDoneEvent, + # ResponseAudioTranscriptDeltaEvent, + # ResponseAudioTranscriptDoneEvent, + # ResponseCodeInterpreterCallCodeDeltaEvent, + # ResponseCodeInterpreterCallCodeDoneEvent, + # ResponseCodeInterpreterCallCompletedEvent, + # ResponseCodeInterpreterCallInProgressEvent, + # ResponseCodeInterpreterCallInterpretingEvent, + # ResponseCompletedEvent, + # ResponseContentPartAddedEvent, + # ResponseContentPartDoneEvent, + # ResponseCreatedEvent, + # ResponseErrorEvent, + # ResponseFileSearchCallCompletedEvent, + # ResponseFileSearchCallInProgressEvent, + # ResponseFileSearchCallSearchingEvent, + # 
ResponseFunctionCallArgumentsDeltaEvent, + # ResponseFunctionCallArgumentsDoneEvent, + # ResponseInProgressEvent, + # ResponseFailedEvent, + # ResponseIncompleteEvent, + # ResponseOutputItemAddedEvent, + # ResponseOutputItemDoneEvent, + # ResponseReasoningSummaryPartAddedEvent, + # ResponseReasoningSummaryPartDoneEvent, + # ResponseReasoningSummaryTextDeltaEvent, + # ResponseReasoningSummaryTextDoneEvent, + # ResponseReasoningTextDeltaEvent, + # ResponseReasoningTextDoneEvent, + # ResponseRefusalDeltaEvent, + # ResponseRefusalDoneEvent, + # ResponseTextDeltaEvent, + # ResponseTextDoneEvent, + # ResponseWebSearchCallCompletedEvent, + # ResponseWebSearchCallInProgressEvent, + # ResponseWebSearchCallSearchingEvent, + # ResponseImageGenCallCompletedEvent, + # ResponseImageGenCallGeneratingEvent, + # ResponseImageGenCallInProgressEvent, + # ResponseImageGenCallPartialImageEvent, + # ResponseMcpCallArgumentsDeltaEvent, + # ResponseMcpCallArgumentsDoneEvent, + # ResponseMcpCallCompletedEvent, + # ResponseMcpCallFailedEvent, + # ResponseMcpCallInProgressEvent, + # ResponseMcpListToolsCompletedEvent, + # ResponseMcpListToolsFailedEvent, + # ResponseMcpListToolsInProgressEvent, + # ResponseOutputTextAnnotationAddedEvent, + # ResponseQueuedEvent, + # ResponseCustomToolCallInputDeltaEvent, + # ResponseCustomToolCallInputDoneEvent, + case "response.content_part.added": + event_part = event.part + match event_part.type: + case "output_text": + contents.append(TextContent(text=event_part.text, raw_representation=event)) + metadata.update(self._get_metadata_from_response(event_part)) + case "refusal": + contents.append(TextContent(text=event_part.refusal, raw_representation=event)) + case "response.output_text.delta": + contents.append(TextContent(text=event.delta, raw_representation=event)) metadata.update(self._get_metadata_from_response(event)) - case ResponseCompletedEvent(): + case "response.completed": conversation_id = event.response.id if chat_options.store is True else 
None model = event.response.model if event.response.usage: usage = self._usage_details_from_openai(event.response.usage) if usage: - items.append(UsageContent(details=usage, raw_representation=event)) - case ResponseOutputItemAddedEvent(): - if event.item.type == "function_call": - function_call_ids[event.output_index] = (event.item.call_id, event.item.name) - case ResponseFunctionCallArgumentsDeltaEvent(): + contents.append(UsageContent(details=usage, raw_representation=event)) + case "response.output_item.added": + event_item = event.item + match event_item.type: + # types: + # ResponseOutputMessage, + # ResponseFileSearchToolCall, + # ResponseFunctionToolCall, + # ResponseFunctionWebSearch, + # ResponseComputerToolCall, + # ResponseReasoningItem, + # ImageGenerationCall, + # ResponseCodeInterpreterToolCall, + # LocalShellCall, + # McpCall, + # McpListTools, + # McpApprovalRequest, + # ResponseCustomToolCall, + case "function_call": + function_call_ids[event.output_index] = (event_item.call_id, event_item.name) + case "mcp_approval_request": + contents.append( + FunctionApprovalRequestContent( + id=event_item.id, + function_call=FunctionCallContent( + call_id=event_item.id, + name=event_item.name, + arguments=event_item.arguments, + additional_properties={"server_label": event_item.server_label}, + raw_representation=event_item, + ), + ) + ) + case "code_interpreter_call": # ResponseOutputCodeInterpreterCall + if event_item.outputs: + for code_output in event_item.outputs: + if code_output.type == "logs": + contents.append(TextContent(text=code_output.logs, raw_representation=event_item)) + if code_output.type == "image": + contents.append( + UriContent( + uri=code_output.url, + raw_representation=event_item, + # no more specific media type then this can be inferred + media_type="image", + ) + ) + elif event_item.code: + # fallback if no output was returned is the code: + contents.append(TextContent(text=event_item.code, raw_representation=event_item)) + case 
"reasoning": # ResponseOutputReasoning + if event_item.content: + for index, reasoning_content in enumerate(event_item.content): + additional_properties = None + if event_item.summary and index < len(event_item.summary): + additional_properties = {"summary": event_item.summary[index]} + contents.append( + TextReasoningContent( + text=reasoning_content.text, + raw_representation=reasoning_content, + additional_properties=additional_properties, + ) + ) + case _: + logger.debug("Unparsed event of type: %s: %s", event.type, event) + case "response.function_call_arguments.delta": call_id, name = function_call_ids.get(event.output_index, (None, None)) if call_id and name: - items.append( + contents.append( FunctionCallContent( call_id=call_id, name=name, @@ -715,10 +1005,10 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient): ) ) case _: - logger.debug("Unparsed event: %s", event) + logger.debug("Unparsed event of type: %s: %s", event.type, event) return ChatResponseUpdate( - contents=items, + contents=contents, conversation_id=conversation_id, role=Role.ASSISTANT, ai_model_id=model, @@ -738,69 +1028,6 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient): details["openai.reasoning_tokens"] = usage.output_tokens_details.reasoning_tokens return details - def _openai_chat_message_parser( - self, - message: ChatMessage, - call_id_to_id: dict[str, str], - ) -> list[dict[str, Any]]: - """Parse a chat message into the openai format.""" - all_messages: list[dict[str, Any]] = [] - args: dict[str, Any] = { - "role": message.role.value if isinstance(message.role, Role) else message.role, - } - if message.additional_properties: - args["metadata"] = message.additional_properties - for content in message.contents: - match content: - case FunctionResultContent(): - new_args: dict[str, Any] = {} - new_args.update(self._openai_content_parser(message.role, content, call_id_to_id)) - all_messages.append(new_args) - case FunctionCallContent(): - function_call = 
self._openai_content_parser(message.role, content, call_id_to_id) - all_messages.append(function_call) # type: ignore - case _: - if "content" not in args: - args["content"] = [] - args["content"].append(self._openai_content_parser(message.role, content, call_id_to_id)) # type: ignore - if "content" in args or "tool_calls" in args: - all_messages.append(args) - return all_messages - - def _openai_content_parser( - self, - role: Role, - content: Contents, - call_id_to_id: dict[str, str], - ) -> dict[str, Any]: - """Parse contents into the openai format.""" - match content: - case FunctionCallContent(): - return { - "call_id": content.call_id, - "id": call_id_to_id[content.call_id], - "type": "function_call", - "name": content.name, - "arguments": content.arguments, - } - case FunctionResultContent(): - # call_id for the result needs to be the same as the call_id for the function call - args: dict[str, Any] = { - "call_id": content.call_id, - "type": "function_call_output", - } - if content.result: - args["output"] = prepare_function_call_results(content.result) - return args - case TextContent(): - return { - "type": "output_text" if role == Role.ASSISTANT else "input_text", - "text": content.text, - } - # TODO(peterychang): We'll probably need to specialize the other content types as well - case _: - return content.model_dump(exclude_none=True) - def _get_metadata_from_response(self, output: Any) -> dict[str, Any]: """Get metadata from a chat choice.""" if logprobs := getattr(output, "logprobs", None): diff --git a/python/packages/main/tests/main/test_logging.py b/python/packages/main/tests/main/test_logging.py index 2f7e40f17a..6565834596 100644 --- a/python/packages/main/tests/main/test_logging.py +++ b/python/packages/main/tests/main/test_logging.py @@ -22,7 +22,7 @@ def test_get_logger_custom_name(): def test_get_logger_invalid_name(): """Test that an exception is raised for an invalid logger name.""" - with pytest.raises(AgentFrameworkException, match="Logger 
name must start with 'agent_framework'."): + with pytest.raises(AgentFrameworkException): get_logger("invalid_name") diff --git a/python/packages/main/tests/main/test_tools.py b/python/packages/main/tests/main/test_tools.py index a3a39b9613..537b8cc94e 100644 --- a/python/packages/main/tests/main/test_tools.py +++ b/python/packages/main/tests/main/test_tools.py @@ -1,14 +1,24 @@ # Copyright (c) Microsoft. All rights reserved. +from typing import Any from unittest.mock import Mock, patch import pytest from pydantic import BaseModel -from agent_framework import AIFunction, HostedCodeInterpreterTool, ToolProtocol, ai_function +from agent_framework import ( + AIFunction, + HostedCodeInterpreterTool, + HostedMCPTool, + ToolProtocol, + ai_function, +) from agent_framework._tools import _parse_inputs +from agent_framework.exceptions import ToolException from agent_framework.telemetry import GenAIAttributes +# region AIFunction and ai_function decorator tests + def test_ai_function_decorator(): """Test the ai_function decorator.""" @@ -291,7 +301,7 @@ async def test_ai_function_invoke_invalid_pydantic_args(): await invalid_args_test.invoke(arguments=wrong_args) -# Tests for HostedCodeInterpreterTool and _parse_inputs +# region HostedCodeInterpreterTool and _parse_inputs def test_hosted_code_interpreter_tool_default(): @@ -507,3 +517,104 @@ def test_hosted_code_interpreter_tool_with_unknown_input(): """Test HostedCodeInterpreterTool with single unknown input.""" with pytest.raises(ValueError, match="Unsupported input type"): HostedCodeInterpreterTool(inputs={"hosted_file": "file-single"}) + + +# region HostedMCPTool tests + + +def test_hosted_mcp_tool_with_other_fields(): + """Test creating a HostedMCPTool with a specific approval dict, headers and additional properties.""" + tool = HostedMCPTool( + name="mcp-tool", + url="https://mcp.example", + description="A test MCP tool", + headers={"x": "y"}, + additional_properties={"p": 1}, + ) + + assert tool.name == "mcp-tool" + # 
pydantic AnyUrl preserves as string-like + assert str(tool.url).startswith("https://") + assert tool.headers == {"x": "y"} + assert tool.additional_properties == {"p": 1} + assert tool.description == "A test MCP tool" + + +@pytest.mark.parametrize( + "approval_mode", + [ + "always_require", + "never_require", + { + "always_require_approval": {"toolA"}, + "never_require_approval": {"toolB"}, + }, + { + "always_require_approval": ["toolA"], + "never_require_approval": ("toolB",), + }, + ], + ids=["always_require", "never_require", "specific", "specific_with_parsing"], +) +def test_hosted_mcp_tool_with_approval_mode(approval_mode: str | dict[str, Any]): + """Test creating a HostedMCPTool with a specific approval dict, headers and additional properties.""" + tool = HostedMCPTool(name="mcp-tool", url="https://mcp.example", approval_mode=approval_mode) + + assert tool.name == "mcp-tool" + # pydantic AnyUrl preserves as string-like + assert str(tool.url).startswith("https://") + if not isinstance(approval_mode, dict): + assert tool.approval_mode == approval_mode + else: + # approval_mode parsed to sets + assert isinstance(tool.approval_mode["always_require_approval"], set) + assert isinstance(tool.approval_mode["never_require_approval"], set) + assert "toolA" in tool.approval_mode["always_require_approval"] + assert "toolB" in tool.approval_mode["never_require_approval"] + + +def test_hosted_mcp_tool_invalid_approval_mode_raises(): + """Invalid approval_mode string should raise ServiceInitializationError.""" + with pytest.raises(ToolException): + HostedMCPTool(name="bad", url="https://x", approval_mode="invalid_mode") + + +@pytest.mark.parametrize( + "tools", + [ + {"toolA", "toolB"}, + ("toolA", "toolB"), + ["toolA", "toolB"], + ["toolA", "toolB", "toolA"], + ], + ids=[ + "set", + "tuple", + "list", + "list_with_duplicates", + ], +) +def test_hosted_mcp_tool_with_allowed_tools(tools: list[str] | tuple[str, ...] 
| set[str]): + """Test creating a HostedMCPTool with a list of allowed tools.""" + tool = HostedMCPTool( + name="mcp-tool", + url="https://mcp.example", + allowed_tools=tools, + ) + + assert tool.name == "mcp-tool" + # pydantic AnyUrl preserves as string-like + assert str(tool.url).startswith("https://") + # approval_mode parsed to set + assert isinstance(tool.allowed_tools, set) + assert tool.allowed_tools == {"toolA", "toolB"} + + +def test_hosted_mcp_tool_with_dict_of_allowed_tools(): + """Test creating a HostedMCPTool with a dict of allowed tools.""" + with pytest.raises(ToolException): + HostedMCPTool( + name="mcp-tool", + url="https://mcp.example", + allowed_tools={"toolA": "Tool A", "toolC": "Tool C"}, + ) diff --git a/python/packages/main/tests/main/test_types.py b/python/packages/main/tests/main/test_types.py index 40fd8e9f0e..094d4ea494 100644 --- a/python/packages/main/tests/main/test_types.py +++ b/python/packages/main/tests/main/test_types.py @@ -21,6 +21,8 @@ from agent_framework import ( DataContent, ErrorContent, FinishReason, + FunctionApprovalRequestContent, + FunctionApprovalResponseContent, FunctionCallContent, FunctionResultContent, GeneratedEmbeddings, @@ -38,6 +40,7 @@ from agent_framework import ( UsageDetails, ai_function, ) +from agent_framework.exceptions import AdditionItemMismatch @fixture @@ -296,9 +299,8 @@ def test_function_call_content_add_merging_and_errors(): # incompatible call ids a = FunctionCallContent(call_id="1", name="f", arguments="abc") b = FunctionCallContent(call_id="2", name="f", arguments="def") - from agent_framework.exceptions import AgentFrameworkException - with raises(AgentFrameworkException): + with raises(AdditionItemMismatch): _ = a + b @@ -379,6 +381,42 @@ def test_usage_details_add_with_none_and_type_errors(): u += 42 # type: ignore[arg-type] +# region UserInputRequest and Response + + +def test_function_approval_request_and_response_creation(): + """Test creating a FunctionApprovalRequestContent and 
producing a response.""" + fc = FunctionCallContent(call_id="call-1", name="do_something", arguments={"a": 1}) + req = FunctionApprovalRequestContent(id="req-1", function_call=fc) + + assert req.type == "function_approval_request" + assert req.function_call == fc + assert req.id == "req-1" + assert isinstance(req, BaseContent) + + resp = req.create_response(True) + + assert isinstance(resp, FunctionApprovalResponseContent) + assert resp.approved is True + assert resp.function_call == fc + assert resp.id == "req-1" + + +def test_function_approval_serialization_roundtrip(): + fc = FunctionCallContent(call_id="c2", name="f", arguments='{"x":1}') + req = FunctionApprovalRequestContent(id="id-2", function_call=fc, additional_properties={"meta": 1}) + + dumped = req.model_dump() + loaded = FunctionApprovalRequestContent.model_validate(dumped) + assert loaded == req + + class TestModel(BaseModel): + content: Contents + + test_item = TestModel.model_validate({"content": dumped}) + assert isinstance(test_item.content, FunctionApprovalRequestContent) + + # region BaseContent Serialization diff --git a/python/packages/main/tests/openai/test_openai_assistants_client.py b/python/packages/main/tests/openai/test_openai_assistants_client.py index c78a45d1fb..76ce1ce87e 100644 --- a/python/packages/main/tests/openai/test_openai_assistants_client.py +++ b/python/packages/main/tests/openai/test_openai_assistants_client.py @@ -1251,42 +1251,3 @@ async def test_openai_assistants_client_agent_level_tool_persistence(): assert second_response.text is not None # Should use the agent-level weather tool again assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"]) - - -@skip_if_openai_integration_tests_disabled -async def test_openai_assistants_client_run_level_tool_isolation(): - """Test that run-level tools are isolated to specific runs and don't persist with OpenAI Assistants Client.""" - # Counter to track how many times the weather tool is called - 
call_count = 0 - - @ai_function - async def get_weather_with_counter(location: Annotated[str, "The location as a city name"]) -> str: - """Get the current weather in a given location.""" - nonlocal call_count - call_count += 1 - return f"The weather in {location} is sunny and 72°F." - - async with ChatAgent( - chat_client=OpenAIAssistantsClient(), - instructions="You are a helpful assistant.", - ) as agent: - # First run - use run-level tool - first_response = await agent.run( - "What's the weather like in Chicago?", - tools=[get_weather_with_counter], # Run-level tool - ) - - assert isinstance(first_response, AgentRunResponse) - assert first_response.text is not None - # Should use the run-level weather tool (call count should be 1) - assert call_count == 1 - assert any(term in first_response.text.lower() for term in ["chicago", "sunny", "72"]) - - # Second run - run-level tool should NOT persist (key isolation test) - second_response = await agent.run("What's the weather like in Miami?") - - assert isinstance(second_response, AgentRunResponse) - assert second_response.text is not None - # Should NOT use the weather tool since it was only run-level in previous call - # Call count should still be 1 (no additional calls) - assert call_count == 1 diff --git a/python/packages/main/tests/openai/test_openai_chat_client.py b/python/packages/main/tests/openai/test_openai_chat_client.py index 35e75095a1..f57af48207 100644 --- a/python/packages/main/tests/openai/test_openai_chat_client.py +++ b/python/packages/main/tests/openai/test_openai_chat_client.py @@ -339,7 +339,7 @@ async def test_openai_chat_client_web_search() -> None: tools=[HostedWebSearchTool(additional_properties=additional_properties)], tool_choice="auto", ) - assert "Seattle" in response.text + assert response.text is not None @skip_if_openai_integration_tests_disabled @@ -392,7 +392,7 @@ async def test_openai_chat_client_web_search_streaming() -> None: for content in chunk.contents: if isinstance(content, 
TextContent) and content.text: full_message += content.text - assert "Seattle" in full_message + assert full_message is not None @skip_if_openai_integration_tests_disabled diff --git a/python/packages/main/tests/openai/test_openai_responses_client.py b/python/packages/main/tests/openai/test_openai_responses_client.py index f9e7c18ab6..6b89329f4e 100644 --- a/python/packages/main/tests/openai/test_openai_responses_client.py +++ b/python/packages/main/tests/openai/test_openai_responses_client.py @@ -18,11 +18,14 @@ from agent_framework import ( ChatMessage, ChatResponse, ChatResponseUpdate, + FunctionApprovalRequestContent, + FunctionApprovalResponseContent, FunctionCallContent, FunctionResultContent, HostedCodeInterpreterTool, HostedFileContent, HostedFileSearchTool, + HostedMCPTool, HostedVectorStoreContent, HostedWebSearchTool, Role, @@ -49,7 +52,7 @@ class OutputStruct(BaseModel): """A structured output for testing purposes.""" location: str - weather: str + weather: str | None = None async def create_vector_store(client: OpenAIResponsesClient) -> tuple[str, HostedVectorStoreContent]: @@ -644,6 +647,156 @@ def test_response_content_creation_with_function_call() -> None: assert function_call.arguments == '{"location": "Seattle"}' +def test_tools_to_response_tools_with_hosted_mcp() -> None: + """Test that HostedMCPTool is converted to the correct response tool dict.""" + client = OpenAIResponsesClient(ai_model_id="test-model", api_key="test-key") + + tool = HostedMCPTool( + name="My MCP", + url="https://mcp.example", + description="An MCP server", + approval_mode={"always_require_approval": ["tool_a", "tool_b"]}, + allowed_tools={"tool_a", "tool_b"}, + headers={"X-Test": "yes"}, + additional_properties={"custom": "value"}, + ) + + resp_tools = client._tools_to_response_tools([tool]) + assert isinstance(resp_tools, list) + assert len(resp_tools) == 1 + mcp = resp_tools[0] + assert isinstance(mcp, dict) + assert mcp["type"] == "mcp" + assert mcp["server_label"] == 
"My_MCP" + # server_url may be normalized to include a trailing slash by the client + assert str(mcp["server_url"]).rstrip("/") == "https://mcp.example" + assert mcp["server_description"] == "An MCP server" + assert mcp["headers"]["X-Test"] == "yes" + assert set(mcp["allowed_tools"]) == {"tool_a", "tool_b"} + # approval mapping created from approval_mode dict + assert "require_approval" in mcp + + +def test_create_response_content_with_mcp_approval_request() -> None: + """Test that a non-streaming mcp_approval_request is parsed into FunctionApprovalRequestContent.""" + client = OpenAIResponsesClient(ai_model_id="test-model", api_key="test-key") + + mock_response = MagicMock() + mock_response.output_parsed = None + mock_response.metadata = {} + mock_response.usage = None + mock_response.id = "resp-id" + mock_response.model = "test-model" + mock_response.created_at = 1000000000 + + mock_item = MagicMock() + mock_item.type = "mcp_approval_request" + mock_item.id = "approval-1" + mock_item.name = "do_sensitive_action" + mock_item.arguments = {"arg": 1} + mock_item.server_label = "My_MCP" + + mock_response.output = [mock_item] + + response = client._create_response_content(mock_response, chat_options=ChatOptions()) # type: ignore + + assert isinstance(response.messages[0].contents[0], FunctionApprovalRequestContent) + req = response.messages[0].contents[0] + assert req.id == "approval-1" + assert req.function_call.name == "do_sensitive_action" + assert req.function_call.arguments == {"arg": 1} + assert req.function_call.additional_properties["server_label"] == "My_MCP" + + +def test_create_streaming_response_content_with_mcp_approval_request() -> None: + """Test that a streaming mcp_approval_request event is parsed into FunctionApprovalRequestContent.""" + client = OpenAIResponsesClient(ai_model_id="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + + mock_event = MagicMock() + mock_event.type = 
"response.output_item.added" + mock_item = MagicMock() + mock_item.type = "mcp_approval_request" + mock_item.id = "approval-stream-1" + mock_item.name = "do_stream_action" + mock_item.arguments = {"x": 2} + mock_item.server_label = "My_MCP" + mock_event.item = mock_item + + update = client._create_streaming_response_content(mock_event, chat_options, function_call_ids) + assert any(isinstance(c, FunctionApprovalRequestContent) for c in update.contents) + fa = next(c for c in update.contents if isinstance(c, FunctionApprovalRequestContent)) + assert fa.id == "approval-stream-1" + assert fa.function_call.name == "do_stream_action" + + +def test_end_to_end_mcp_approval_flow() -> None: + """End-to-end mocked test: + model issues an mcp_approval_request, user approves, client sends mcp_approval_response. + """ + client = OpenAIResponsesClient(ai_model_id="test-model", api_key="test-key") + + # First mocked response: model issues an mcp_approval_request + mock_response1 = MagicMock() + mock_response1.output_parsed = None + mock_response1.metadata = {} + mock_response1.usage = None + mock_response1.id = "resp-1" + mock_response1.model = "test-model" + mock_response1.created_at = 1000000000 + + mock_item = MagicMock() + mock_item.type = "mcp_approval_request" + mock_item.id = "approval-1" + mock_item.name = "do_sensitive_action" + mock_item.arguments = {"arg": "value"} + mock_item.server_label = "My_MCP" + mock_response1.output = [mock_item] + + # Second mocked response: simple assistant acknowledgement after approval + mock_response2 = MagicMock() + mock_response2.output_parsed = None + mock_response2.metadata = {} + mock_response2.usage = None + mock_response2.id = "resp-2" + mock_response2.model = "test-model" + mock_response2.created_at = 1000000001 + mock_text_item = MagicMock() + mock_text_item.type = "message" + mock_text_content = MagicMock() + mock_text_content.type = "output_text" + mock_text_content.text = "Approved." 
+ mock_text_item.content = [mock_text_content] + mock_response2.output = [mock_text_item] + + # Patch the create call to return the two mocked responses in sequence + with patch.object(client.client.responses, "create", side_effect=[mock_response1, mock_response2]) as mock_create: + # First call: get the approval request + response = asyncio.run(client.get_response(messages=[ChatMessage(role="user", text="Trigger approval")])) + assert isinstance(response.messages[0].contents[0], FunctionApprovalRequestContent) + req = response.messages[0].contents[0] + assert req.id == "approval-1" + + # Build a user approval and send it (include required function_call) + approval = FunctionApprovalResponseContent(approved=True, id=req.id, function_call=req.function_call) + approval_message = ChatMessage(role="user", contents=[approval]) + _ = asyncio.run(client.get_response(messages=[approval_message])) + + # Ensure two calls were made and the second includes the mcp_approval_response + assert mock_create.call_count == 2 + _, kwargs = mock_create.call_args_list[1] + sent_input = kwargs.get("input") + assert isinstance(sent_input, list) + found = False + for item in sent_input: + if isinstance(item, dict) and item.get("type") == "mcp_approval_response": + assert item["approval_request_id"] == "approval-1" + assert item["approve"] is True + found = True + assert found + + def test_usage_details_basic() -> None: """Test _usage_details_from_openai without cached or reasoning tokens.""" client = OpenAIResponsesClient(ai_model_id="test-model", api_key="test-key") @@ -775,9 +928,10 @@ async def test_openai_responses_client_response() -> None: assert response is not None assert isinstance(response, ChatResponse) - output = OutputStruct.model_validate_json(response.text) + output = response.value + assert output is not None, "Response value is None" assert "seattle" in output.location.lower() - assert "sunny" in output.weather.lower() + assert output.weather is not None 
@skip_if_openai_integration_tests_disabled @@ -839,17 +993,11 @@ async def test_openai_responses_client_streaming() -> None: messages.append(ChatMessage(role="user", text="who are Emily and David?")) # Test that the client can be used to get a response - response = openai_responses_client.get_streaming_response(messages=messages) + response = await ChatResponse.from_chat_response_generator( + openai_responses_client.get_streaming_response(messages=messages) + ) - full_message: str = "" - async for chunk in response: - assert chunk is not None - assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - - assert "scientists" in full_message + assert "scientists" in response.text messages.clear() messages.append(ChatMessage(role="user", text="The weather in Seattle is sunny")) @@ -859,17 +1007,16 @@ async def test_openai_responses_client_streaming() -> None: messages=messages, response_format=OutputStruct, ) - full_message = "" + chunks = [] async for chunk in response: assert chunk is not None assert isinstance(chunk, ChatResponseUpdate) - for content in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text - - output = OutputStruct.model_validate_json(full_message) + chunks.append(chunk) + full_message = ChatResponse.from_chat_response_updates(chunks, output_format_type=OutputStruct) + output = full_message.value + assert output is not None, "Response value is None" assert "seattle" in output.location.lower() - assert "sunny" in output.weather.lower() + assert output.weather is not None @skip_if_openai_integration_tests_disabled @@ -906,15 +1053,15 @@ async def test_openai_responses_client_streaming_tools() -> None: tool_choice="auto", response_format=OutputStruct, ) - full_message = "" + chunks = [] async for chunk in response: assert chunk is not None assert isinstance(chunk, ChatResponseUpdate) - for content 
in chunk.contents: - if isinstance(content, TextContent) and content.text: - full_message += content.text + chunks.append(chunk) - output = OutputStruct.model_validate_json(full_message) + full_message = ChatResponse.from_chat_response_updates(chunks, output_format_type=OutputStruct) + output = full_message.value + assert output is not None, "Response value is None" assert "seattle" in output.location.lower() assert "sunny" in output.weather.lower() @@ -955,7 +1102,7 @@ async def test_openai_responses_client_web_search() -> None: tools=[HostedWebSearchTool(additional_properties=additional_properties)], tool_choice="auto", ) - assert "Seattle" in response.text + assert response.text is not None @skip_if_openai_integration_tests_disabled @@ -1008,7 +1155,7 @@ async def test_openai_responses_client_web_search_streaming() -> None: for content in chunk.contents: if isinstance(content, TextContent) and content.text: full_message += content.text - assert "Seattle" in full_message + assert full_message is not None @skip_if_openai_integration_tests_disabled diff --git a/python/samples/getting_started/agents/openai_responses_client/openai_responses_client_with_hosted_mcp.py b/python/samples/getting_started/agents/openai_responses_client/openai_responses_client_with_hosted_mcp.py new file mode 100644 index 0000000000..ee2140a72f --- /dev/null +++ b/python/samples/getting_started/agents/openai_responses_client/openai_responses_client_with_hosted_mcp.py @@ -0,0 +1,224 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import asyncio +from typing import TYPE_CHECKING, Any + +from agent_framework import ChatAgent, HostedMCPTool +from agent_framework.openai import OpenAIResponsesClient + +if TYPE_CHECKING: + from agent_framework import AgentProtocol, AgentThread + + +async def handle_approvals_without_thread(query: str, agent: "AgentProtocol"): + """When we don't have a thread, we need to ensure we return with the input, approval request and approval.""" + from agent_framework import ChatMessage + + result = await agent.run(query) + while len(result.user_input_requests) > 0: + new_inputs: list[Any] = [query] + for user_input_needed in result.user_input_requests: + print( + f"User Input Request for function from {agent.name}: {user_input_needed.function_call.name}" + f" with arguments: {user_input_needed.function_call.arguments}" + ) + new_inputs.append(ChatMessage(role="assistant", contents=[user_input_needed])) + user_approval = input("Approve function call? (y/n): ") + new_inputs.append( + ChatMessage(role="user", contents=[user_input_needed.create_response(user_approval.lower() == "y")]) + ) + + result = await agent.run(new_inputs) + return result + + +async def handle_approvals_with_thread(query: str, agent: "AgentProtocol", thread: "AgentThread"): + """Here we let the thread deal with the previous responses, and we just rerun with the approval.""" + from agent_framework import ChatMessage + + result = await agent.run(query, thread=thread, store=True) + while len(result.user_input_requests) > 0: + new_input: list[Any] = [] + for user_input_needed in result.user_input_requests: + print( + f"User Input Request for function from {agent.name}: {user_input_needed.function_call.name}" + f" with arguments: {user_input_needed.function_call.arguments}" + ) + user_approval = input("Approve function call? 
(y/n): ") + new_input.append( + ChatMessage( + role="user", + contents=[user_input_needed.create_response(user_approval.lower() == "y")], + ) + ) + result = await agent.run(new_input, thread=thread, store=True) + return result + + +async def handle_approvals_with_thread_streaming(query: str, agent: "AgentProtocol", thread: "AgentThread"): + """Stream a run on the thread, asking the user to approve each requested function call. + + The query is sent with the first run only; every follow-up run sends just the approval + responses collected during the previous run. Updates that carry no approval request are + yielded to the caller. + """ + from agent_framework import ChatMessage + + # Start with the query; afterwards only newly collected approvals are sent, + # mirroring handle_approvals_with_thread (re-sending the query would repeat the user turn). + new_input: list[ChatMessage] = [ChatMessage(role="user", text=query)] + while new_input: + # Hand the pending messages to this run and start a fresh list for its approvals. + current_input, new_input = new_input, [] + async for update in agent.run_stream(current_input, thread=thread, store=True): + if update.user_input_requests: + for user_input_needed in update.user_input_requests: + print( + f"User Input Request for function from {agent.name}: {user_input_needed.function_call.name}" + f" with arguments: {user_input_needed.function_call.arguments}" + ) + user_approval = input("Approve function call? 
(y/n): ") + new_input.append( + ChatMessage( + role="user", contents=[user_input_needed.create_response(user_approval.lower() == "y")] + ) + ) + else: + yield update + + +async def run_hosted_mcp_without_thread_and_specific_approval() -> None: + """Example showing Mcp Tools with approvals without using a thread.""" + print("=== Mcp with approvals and without thread ===") + + # Tools are provided when creating the agent + # The agent can use these tools for any query during its lifetime + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + name="DocsAgent", + instructions="You are a helpful assistant that can help with microsoft documentation questions.", + tools=HostedMCPTool( + name="Microsoft Learn MCP", + url="https://learn.microsoft.com/api/mcp", + # we don't require approval for microsoft_docs_search tool calls + # but we do for any other tool + approval_mode={"never_require_approval": ["microsoft_docs_search"]}, + ), + ) as agent: + # First query + query1 = "How to create an Azure storage account using az cli?" + print(f"User: {query1}") + result1 = await handle_approvals_without_thread(query1, agent) + print(f"{agent.name}: {result1}\n") + print("\n=======================================\n") + # Second query + query2 = "What is Microsoft Semantic Kernel?"
+ print(f"User: {query2}") + result2 = await handle_approvals_without_thread(query2, agent) + print(f"{agent.name}: {result2}\n") + + +async def run_hosted_mcp_without_approval() -> None: + """Example showing Mcp Tools without approvals.""" + print("=== Mcp without approvals ===") + + # Tools are provided when creating the agent + # The agent can use these tools for any query during its lifetime + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + name="DocsAgent", + instructions="You are a helpful assistant that can help with microsoft documentation questions.", + tools=HostedMCPTool( + name="Microsoft Learn MCP", + url="https://learn.microsoft.com/api/mcp", + # we don't require approval for any function calls + # this means we will not see the approval messages, + # it is fully handled by the service and a final response is returned. + approval_mode="never_require", + ), + ) as agent: + # First query + query1 = "How to create an Azure storage account using az cli?" + print(f"User: {query1}") + result1 = await handle_approvals_without_thread(query1, agent) + print(f"{agent.name}: {result1}\n") + print("\n=======================================\n") + # Second query + query2 = "What is Microsoft Semantic Kernel?" 
+ print(f"User: {query2}") + result2 = await handle_approvals_without_thread(query2, agent) + print(f"{agent.name}: {result2}\n") + + +async def run_hosted_mcp_with_thread() -> None: + """Example showing Mcp Tools with approvals using a thread.""" + print("=== Mcp with approvals and with thread ===") + + # Tools are provided when creating the agent + # The agent can use these tools for any query during its lifetime + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + name="DocsAgent", + instructions="You are a helpful assistant that can help with microsoft documentation questions.", + tools=HostedMCPTool( + name="Microsoft Learn MCP", + url="https://learn.microsoft.com/api/mcp", + # we require approval for all function calls + approval_mode="always_require", + ), + ) as agent: + # First query + thread = agent.get_new_thread() + query1 = "How to create an Azure storage account using az cli?" + print(f"User: {query1}") + result1 = await handle_approvals_with_thread(query1, agent, thread) + print(f"{agent.name}: {result1}\n") + print("\n=======================================\n") + # Second query + query2 = "What is Microsoft Semantic Kernel?" 
+ print(f"User: {query2}") + result2 = await handle_approvals_with_thread(query2, agent, thread) + print(f"{agent.name}: {result2}\n") + + +async def run_hosted_mcp_with_thread_streaming() -> None: + """Example showing Mcp Tools with approvals using a thread and streaming the responses.""" + # Distinct banner so this run can be told apart from run_hosted_mcp_with_thread in the console output. + print("=== Mcp with approvals, with thread and streaming ===") + + # Tools are provided when creating the agent + # The agent can use these tools for any query during its lifetime + async with ChatAgent( + chat_client=OpenAIResponsesClient(), + name="DocsAgent", + instructions="You are a helpful assistant that can help with microsoft documentation questions.", + tools=HostedMCPTool( + name="Microsoft Learn MCP", + url="https://learn.microsoft.com/api/mcp", + # we require approval for all function calls + approval_mode="always_require", + ), + ) as agent: + # First query + thread = agent.get_new_thread() + query1 = "How to create an Azure storage account using az cli?" + print(f"User: {query1}") + print(f"{agent.name}: ", end="") + async for update in handle_approvals_with_thread_streaming(query1, agent, thread): + print(update, end="") + print("\n") + print("\n=======================================\n") + # Second query + query2 = "What is Microsoft Semantic Kernel?"
+ print(f"User: {query2}") + print(f"{agent.name}: ", end="") + async for update in handle_approvals_with_thread_streaming(query2, agent, thread): + print(update, end="") + print("\n") + + +async def main() -> None: + print("=== OpenAI Responses Client Agent with Hosted Mcp Tools Examples ===\n") + + await run_hosted_mcp_without_approval() + await run_hosted_mcp_without_thread_and_specific_approval() + await run_hosted_mcp_with_thread() + await run_hosted_mcp_with_thread_streaming() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/tests/samples/getting_started/test_agents.py b/python/tests/samples/getting_started/test_agents.py index c4a94f10ff..32ae175d4a 100644 --- a/python/tests/samples/getting_started/test_agents.py +++ b/python/tests/samples/getting_started/test_agents.py @@ -66,7 +66,13 @@ from samples.getting_started.agents.foundry.foundry_with_explicit_settings impor main as foundry_with_explicit_settings, ) from samples.getting_started.agents.foundry.foundry_with_function_tools import ( - main as foundry_with_function_tools, + mixed_tools_example as foundry_with_function_tools_mixed, +) +from samples.getting_started.agents.foundry.foundry_with_function_tools import ( + tools_on_agent_level as foundry_with_function_tools_agent, +) +from samples.getting_started.agents.foundry.foundry_with_function_tools import ( + tools_on_run_level as foundry_with_function_tools_run, ) from samples.getting_started.agents.foundry.foundry_with_local_mcp import ( main as foundry_with_local_mcp, @@ -323,7 +329,25 @@ agent_samples = [ ], ), param( - foundry_with_function_tools, + foundry_with_function_tools_agent, + [], # Non-interactive sample + id="foundry_with_function_tools", + marks=[ + pytest.mark.foundry, + pytest.mark.skipif(os.getenv(RUN_SAMPLES_TESTS, None) is None, reason="Not running sample tests."), + ], + ), + param( + foundry_with_function_tools_run, + [], # Non-interactive sample + id="foundry_with_function_tools", + marks=[ +
pytest.mark.foundry, + pytest.mark.skipif(os.getenv(RUN_SAMPLES_TESTS, None) is None, reason="Not running sample tests."), + ], + ), + param( + foundry_with_function_tools_mixed, [], # Non-interactive sample id="foundry_with_function_tools", marks=[