Python: Create/Get Agent API for OpenAI Assistants (#3208)

* Added provider implementation

* Added example with response format

* Small improvements
This commit is contained in:
Dmytro Struk
2026-01-15 14:52:32 -08:00
committed by GitHub
Unverified
parent dd3e2b6e53
commit b5ca0c8eda
14 changed files with 1940 additions and 126 deletions
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
from ._assistant_provider import * # noqa: F403
from ._assistants_client import * # noqa: F403
from ._chat_client import * # noqa: F403
from ._exceptions import * # noqa: F403
@@ -0,0 +1,563 @@
# Copyright (c) Microsoft. All rights reserved.
import sys
from collections.abc import Awaitable, Callable, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, TypedDict, cast
from openai import AsyncOpenAI
from openai.types.beta.assistant import Assistant
from pydantic import BaseModel, SecretStr, ValidationError
from .._agents import ChatAgent
from .._memory import ContextProvider
from .._middleware import Middleware
from .._tools import AIFunction, ToolProtocol
from .._types import normalize_tools
from ..exceptions import ServiceInitializationError
from ._assistants_client import OpenAIAssistantsClient
from ._shared import OpenAISettings, from_assistant_tools, to_assistant_tools
if TYPE_CHECKING:
from ._assistants_client import OpenAIAssistantsOptions
if sys.version_info >= (3, 13):
from typing import Self, TypeVar # pragma: no cover
else:
from typing_extensions import Self, TypeVar # pragma: no cover
__all__ = ["OpenAIAssistantProvider"]
# Type variable for options - allows typed ChatAgent[TOptions] returns
# Default matches OpenAIAssistantsClient's default options type
TOptions_co = TypeVar(
"TOptions_co",
bound=TypedDict, # type: ignore[valid-type]
default="OpenAIAssistantsOptions",
covariant=True,
)
_ToolsType = (
ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
)
class OpenAIAssistantProvider(Generic[TOptions_co]):
"""Provider for creating ChatAgent instances from OpenAI Assistants API.
This provider allows you to create, retrieve, and wrap OpenAI Assistants
as ChatAgent instances for use in the agent framework.
Examples:
Basic usage with automatic client creation:
.. code-block:: python
from agent_framework.openai import OpenAIAssistantProvider
# Uses OPENAI_API_KEY environment variable
provider = OpenAIAssistantProvider()
# Create a new assistant
agent = await provider.create_agent(
name="MyAssistant",
model="gpt-4",
instructions="You are a helpful assistant.",
tools=[my_function],
)
result = await agent.run("Hello!")
Using an existing client:
.. code-block:: python
from openai import AsyncOpenAI
from agent_framework.openai import OpenAIAssistantProvider
client = AsyncOpenAI()
provider = OpenAIAssistantProvider(client)
# Get an existing assistant by ID
agent = await provider.get_agent(
assistant_id="asst_123",
tools=[my_function], # Provide implementations for function tools
)
Wrapping an SDK Assistant object:
.. code-block:: python
# Fetch assistant directly via SDK
assistant = await client.beta.assistants.retrieve("asst_123")
# Wrap without additional HTTP call
agent = provider.as_agent(assistant, tools=[my_function])
"""
def __init__(
self,
client: AsyncOpenAI | None = None,
*,
api_key: str | SecretStr | Callable[[], str | Awaitable[str]] | None = None,
org_id: str | None = None,
base_url: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
"""Initialize the OpenAI Assistant Provider.
Args:
client: An existing AsyncOpenAI client to use. If not provided,
a new client will be created using the other parameters.
Keyword Args:
api_key: OpenAI API key. Can also be set via OPENAI_API_KEY env var.
org_id: OpenAI organization ID. Can also be set via OPENAI_ORG_ID env var.
base_url: Base URL for the OpenAI API. Can also be set via OPENAI_BASE_URL env var.
env_file_path: Path to .env file for configuration.
env_file_encoding: Encoding of the .env file.
Raises:
ServiceInitializationError: If no client is provided and API key is missing.
Examples:
.. code-block:: python
# Using environment variables
provider = OpenAIAssistantProvider()
# Using explicit API key
provider = OpenAIAssistantProvider(api_key="sk-...")
# Using existing client
client = AsyncOpenAI()
provider = OpenAIAssistantProvider(client)
"""
self._client: AsyncOpenAI | None = client
self._should_close_client: bool = client is None
if client is None:
# Load settings and create client
try:
settings = OpenAISettings(
api_key=api_key, # type: ignore[reportArgumentType]
org_id=org_id,
base_url=base_url,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
if not settings.api_key:
raise ServiceInitializationError(
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
)
# Get API key value
api_key_value: str | Callable[[], str | Awaitable[str]] | None
if isinstance(settings.api_key, SecretStr):
api_key_value = settings.api_key.get_secret_value()
else:
api_key_value = settings.api_key
# Create client
client_args: dict[str, Any] = {"api_key": api_key_value}
if settings.org_id:
client_args["organization"] = settings.org_id
if settings.base_url:
client_args["base_url"] = settings.base_url
self._client = AsyncOpenAI(**client_args)
async def __aenter__(self) -> "Self":
"""Async context manager entry."""
return self
async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
"""Async context manager exit."""
await self.close()
async def close(self) -> None:
"""Close the provider and clean up resources.
If the provider created its own client, it will be closed.
If an external client was provided, it will not be closed.
"""
if self._should_close_client and self._client is not None:
await self._client.close()
async def create_agent(
self,
*,
name: str,
model: str,
instructions: str | None = None,
description: str | None = None,
tools: _ToolsType | None = None,
metadata: dict[str, str] | None = None,
default_options: TOptions_co | None = None,
middleware: Sequence[Middleware] | None = None,
context_provider: ContextProvider | None = None,
) -> "ChatAgent[TOptions_co]":
"""Create a new assistant on OpenAI and return a ChatAgent.
This method creates a new assistant on the OpenAI service and wraps it
in a ChatAgent instance. The assistant will persist on OpenAI until deleted.
Keyword Args:
name: The name of the assistant (required).
model: The model ID to use, e.g., "gpt-4", "gpt-4o" (required).
instructions: System instructions for the assistant.
description: A description of the assistant.
tools: Tools available to the assistant. Can include:
- AIFunction instances or callables decorated with @ai_function
- HostedCodeInterpreterTool for code execution
- HostedFileSearchTool for vector store search
- Raw tool dictionaries
metadata: Metadata to attach to the assistant (max 16 key-value pairs).
default_options: A TypedDict containing default chat options for the agent.
These options are applied to every run unless overridden.
Include ``response_format`` here for structured output responses.
middleware: Middleware for the ChatAgent.
context_provider: Context provider for the ChatAgent.
Returns:
A ChatAgent instance wrapping the created assistant.
Raises:
ServiceInitializationError: If assistant creation fails.
Examples:
.. code-block:: python
provider = OpenAIAssistantProvider()
# Create with function tools
agent = await provider.create_agent(
name="WeatherBot",
model="gpt-4",
instructions="You are a helpful weather assistant.",
tools=[get_weather],
)
# Create with structured output
agent = await provider.create_agent(
name="StructuredBot",
model="gpt-4",
default_options={"response_format": MyPydanticModel},
)
"""
# Normalize tools
normalized_tools = normalize_tools(tools)
api_tools = to_assistant_tools(normalized_tools) if normalized_tools else []
# Extract response_format from default_options if present
opts = dict(default_options) if default_options else {}
response_format = opts.get("response_format")
# Build assistant creation parameters
create_params: dict[str, Any] = {
"model": model,
"name": name,
}
if instructions is not None:
create_params["instructions"] = instructions
if description is not None:
create_params["description"] = description
if api_tools:
create_params["tools"] = api_tools
if metadata is not None:
create_params["metadata"] = metadata
# Handle response format for OpenAI API
if response_format is not None and isinstance(response_format, type) and issubclass(response_format, BaseModel):
create_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": response_format.__name__,
"schema": response_format.model_json_schema(),
"strict": True,
},
}
# Create the assistant
if not self._client:
raise ServiceInitializationError("OpenAI client is not initialized.")
assistant = await self._client.beta.assistants.create(**create_params)
# Create ChatAgent - pass default_options which contains response_format
return self._create_chat_agent_from_assistant(
assistant=assistant,
tools=normalized_tools,
instructions=instructions,
middleware=middleware,
context_provider=context_provider,
default_options=default_options,
)
async def get_agent(
self,
assistant_id: str,
*,
tools: _ToolsType | None = None,
instructions: str | None = None,
default_options: TOptions_co | None = None,
middleware: Sequence[Middleware] | None = None,
context_provider: ContextProvider | None = None,
) -> "ChatAgent[TOptions_co]":
"""Retrieve an existing assistant by ID and return a ChatAgent.
This method fetches an existing assistant from OpenAI by its ID
and wraps it in a ChatAgent instance.
Args:
assistant_id: The ID of the assistant to retrieve (e.g., "asst_123").
Keyword Args:
tools: Function tools to make available. IMPORTANT: If the assistant
was created with function tools, you MUST provide matching
implementations here. Hosted tools (code_interpreter, file_search)
are automatically included.
instructions: Override the assistant's instructions (optional).
default_options: A TypedDict containing default chat options for the agent.
These options are applied to every run unless overridden.
middleware: Middleware for the ChatAgent.
context_provider: Context provider for the ChatAgent.
Returns:
A ChatAgent instance wrapping the retrieved assistant.
Raises:
ServiceInitializationError: If the assistant cannot be retrieved.
ValueError: If required function tools are missing.
Examples:
.. code-block:: python
provider = OpenAIAssistantProvider()
# Get assistant without function tools
agent = await provider.get_agent(assistant_id="asst_123")
# Get assistant with function tools
agent = await provider.get_agent(
assistant_id="asst_456",
tools=[get_weather, search_database], # Implementations required!
)
"""
# Fetch the assistant
if not self._client:
raise ServiceInitializationError("OpenAI client is not initialized.")
assistant = await self._client.beta.assistants.retrieve(assistant_id)
# Use as_agent to wrap it
return self.as_agent(
assistant=assistant,
tools=tools,
instructions=instructions,
default_options=default_options,
middleware=middleware,
context_provider=context_provider,
)
def as_agent(
self,
assistant: Assistant,
*,
tools: _ToolsType | None = None,
instructions: str | None = None,
default_options: TOptions_co | None = None,
middleware: Sequence[Middleware] | None = None,
context_provider: ContextProvider | None = None,
) -> "ChatAgent[TOptions_co]":
"""Wrap an existing SDK Assistant object as a ChatAgent.
This method does NOT make any HTTP calls. It simply wraps an already-
fetched Assistant object in a ChatAgent.
Args:
assistant: The OpenAI Assistant SDK object to wrap.
Keyword Args:
tools: Function tools to make available. If the assistant has
function tools defined, you MUST provide matching implementations.
Hosted tools (code_interpreter, file_search) are automatically included.
instructions: Override the assistant's instructions (optional).
default_options: A TypedDict containing default chat options for the agent.
These options are applied to every run unless overridden.
middleware: Middleware for the ChatAgent.
context_provider: Context provider for the ChatAgent.
Returns:
A ChatAgent instance wrapping the assistant.
Raises:
ValueError: If required function tools are missing.
Examples:
.. code-block:: python
client = AsyncOpenAI()
provider = OpenAIAssistantProvider(client)
# Fetch assistant via SDK
assistant = await client.beta.assistants.retrieve("asst_123")
# Wrap without additional HTTP call
agent = provider.as_agent(
assistant,
tools=[my_function],
instructions="Custom instructions override",
)
"""
# Validate that required function tools are provided
self._validate_function_tools(assistant.tools or [], tools)
# Merge hosted tools with user-provided function tools
merged_tools = self._merge_tools(assistant.tools or [], tools)
# Create ChatAgent
return self._create_chat_agent_from_assistant(
assistant=assistant,
tools=merged_tools,
instructions=instructions,
default_options=default_options,
middleware=middleware,
context_provider=context_provider,
)
def _validate_function_tools(
self,
assistant_tools: list[Any],
provided_tools: _ToolsType | None,
) -> None:
"""Validate that required function tools are provided.
Args:
assistant_tools: Tools defined on the assistant.
provided_tools: Tools provided by the user.
Raises:
ValueError: If a required function tool is missing.
"""
# Get function tool names from assistant
required_functions: set[str] = set()
for tool in assistant_tools:
if (
hasattr(tool, "type")
and tool.type == "function"
and hasattr(tool, "function")
and hasattr(tool.function, "name")
):
required_functions.add(tool.function.name)
if not required_functions:
return # No function tools required
# Get provided function names using normalize_tools
provided_functions: set[str] = set()
if provided_tools is not None:
normalized = normalize_tools(provided_tools)
for tool in normalized:
if isinstance(tool, AIFunction):
provided_functions.add(tool.name)
elif isinstance(tool, MutableMapping) and "function" in tool:
func_spec = tool.get("function", {})
if isinstance(func_spec, dict):
func_dict = cast(dict[str, Any], func_spec)
if "name" in func_dict:
provided_functions.add(str(func_dict["name"]))
# Check for missing functions
missing = required_functions - provided_functions
if missing:
missing_list = ", ".join(sorted(missing))
raise ValueError(
f"Assistant requires function tool(s) '{missing_list}' but no implementation was provided. "
f"Please pass the function implementation(s) in the 'tools' parameter."
)
def _merge_tools(
self,
assistant_tools: list[Any],
user_tools: _ToolsType | None,
) -> list[ToolProtocol | MutableMapping[str, Any]]:
"""Merge hosted tools from assistant with user-provided function tools.
Args:
assistant_tools: Tools defined on the assistant.
user_tools: Tools provided by the user.
Returns:
A list of all tools (hosted tools + user function implementations).
"""
merged: list[ToolProtocol | MutableMapping[str, Any]] = []
# Add hosted tools from assistant using shared conversion
hosted_tools = from_assistant_tools(assistant_tools)
merged.extend(hosted_tools)
# Add user-provided tools (normalized)
if user_tools is not None:
normalized_user_tools = normalize_tools(user_tools)
merged.extend(normalized_user_tools)
return merged
def _create_chat_agent_from_assistant(
self,
assistant: Assistant,
tools: list[ToolProtocol | MutableMapping[str, Any]] | None,
instructions: str | None,
middleware: Sequence[Middleware] | None,
context_provider: ContextProvider | None,
default_options: TOptions_co | None = None,
**kwargs: Any,
) -> "ChatAgent[TOptions_co]":
"""Create a ChatAgent from an Assistant.
Args:
assistant: The OpenAI Assistant object.
tools: Tools for the agent.
instructions: Instructions override.
middleware: Middleware for the agent.
context_provider: Context provider for the agent.
default_options: Default chat options for the agent (may include response_format).
**kwargs: Additional arguments passed to ChatAgent.
Returns:
A configured ChatAgent instance.
"""
# Create the chat client with the assistant
chat_client = OpenAIAssistantsClient(
model_id=assistant.model,
assistant_id=assistant.id,
assistant_name=assistant.name,
assistant_description=assistant.description,
async_client=self._client,
)
# Use instructions from assistant if not overridden
final_instructions = instructions if instructions is not None else assistant.instructions
# Create and return ChatAgent
return ChatAgent(
chat_client=chat_client,
id=assistant.id,
name=assistant.name,
description=assistant.description,
instructions=final_instructions,
tools=tools if tools else None,
middleware=middleware,
context_provider=context_provider,
default_options=default_options, # type: ignore[arg-type]
**kwargs,
)
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
from collections.abc import Awaitable, Callable, Mapping
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
from copy import copy
from typing import Any, ClassVar, Union
@@ -24,6 +24,7 @@ from .._logging import get_logger
from .._pydantic import AFBaseSettings
from .._serialization import SerializationMixin
from .._telemetry import APP_INFO, USER_AGENT_KEY, prepend_agent_framework_to_user_agent
from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, ToolProtocol
from ..exceptions import ServiceInitializationError
logger: logging.Logger = get_logger("agent_framework.openai")
@@ -275,3 +276,74 @@ class OpenAIConfigMixin(OpenAIBase):
# Ensure additional_properties and middleware are passed through kwargs to BaseChatClient
# These are consumed by BaseChatClient.__init__ via kwargs
super().__init__(**args, **kwargs)
def to_assistant_tools(
tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None,
) -> list[dict[str, Any]]:
"""Convert Agent Framework tools to OpenAI Assistants API format.
Args:
tools: Normalized tools (from ChatOptions.tools).
Returns:
List of tool definitions for OpenAI Assistants API.
"""
if not tools:
return []
tool_definitions: list[dict[str, Any]] = []
for tool in tools:
if isinstance(tool, AIFunction):
tool_definitions.append(tool.to_json_schema_spec())
elif isinstance(tool, HostedCodeInterpreterTool):
tool_definitions.append({"type": "code_interpreter"})
elif isinstance(tool, HostedFileSearchTool):
params: dict[str, Any] = {"type": "file_search"}
if tool.max_results is not None:
params["file_search"] = {"max_num_results": tool.max_results}
tool_definitions.append(params)
elif isinstance(tool, MutableMapping):
# Pass through raw dict definitions
tool_definitions.append(dict(tool))
return tool_definitions
def from_assistant_tools(
assistant_tools: list[Any] | None,
) -> list[ToolProtocol]:
"""Convert OpenAI Assistant tools to Agent Framework format.
This converts hosted tools (code_interpreter, file_search) from an OpenAI
Assistant definition back to Agent Framework tool instances.
Note: Function tools are skipped - user must provide implementations separately.
Args:
assistant_tools: Tools from OpenAI Assistant object (assistant.tools).
Returns:
List of Agent Framework tool instances for hosted tools.
"""
if not assistant_tools:
return []
tools: list[ToolProtocol] = []
for tool in assistant_tools:
if hasattr(tool, "type"):
tool_type = tool.type
elif isinstance(tool, dict):
tool_type = tool.get("type")
else:
tool_type = None
if tool_type == "code_interpreter":
tools.append(HostedCodeInterpreterTool())
elif tool_type == "file_search":
tools.append(HostedFileSearchTool())
# Skip function tools - user must provide implementations
return tools
@@ -0,0 +1,814 @@
# Copyright (c) Microsoft. All rights reserved.
import os
from typing import Annotated, Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from openai.types.beta.assistant import Assistant
from pydantic import BaseModel, Field
from agent_framework import ChatAgent, HostedCodeInterpreterTool, HostedFileSearchTool, ai_function, normalize_tools
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.openai import OpenAIAssistantProvider
from agent_framework.openai._shared import from_assistant_tools, to_assistant_tools
# region Test Helpers
def create_mock_assistant(
assistant_id: str = "asst_test123",
name: str = "TestAssistant",
model: str = "gpt-4",
instructions: str | None = "You are a helpful assistant.",
description: str | None = None,
tools: list[Any] | None = None,
) -> Assistant:
"""Create a mock Assistant object."""
mock = MagicMock(spec=Assistant)
mock.id = assistant_id
mock.name = name
mock.model = model
mock.instructions = instructions
mock.description = description
mock.tools = tools or []
return mock
def create_function_tool(name: str, description: str = "A test function") -> MagicMock:
"""Create a mock FunctionTool."""
mock = MagicMock()
mock.type = "function"
mock.function = MagicMock()
mock.function.name = name
mock.function.description = description
return mock
def create_code_interpreter_tool() -> MagicMock:
"""Create a mock CodeInterpreterTool."""
mock = MagicMock()
mock.type = "code_interpreter"
return mock
def create_file_search_tool() -> MagicMock:
"""Create a mock FileSearchTool."""
mock = MagicMock()
mock.type = "file_search"
return mock
@pytest.fixture
def mock_async_openai() -> MagicMock:
"""Mock AsyncOpenAI client."""
mock_client = MagicMock()
# Mock beta.assistants
mock_client.beta.assistants.create = AsyncMock(
return_value=create_mock_assistant(assistant_id="asst_created123", name="CreatedAssistant")
)
mock_client.beta.assistants.retrieve = AsyncMock(
return_value=create_mock_assistant(assistant_id="asst_retrieved123", name="RetrievedAssistant")
)
mock_client.beta.assistants.delete = AsyncMock()
# Mock close method
mock_client.close = AsyncMock()
return mock_client
# Test function for tool validation
def get_weather(location: Annotated[str, Field(description="The location")]) -> str:
"""Get the weather for a location."""
return f"Weather in {location}: sunny"
def search_database(query: Annotated[str, Field(description="Search query")]) -> str:
"""Search the database."""
return f"Results for: {query}"
# Pydantic model for structured output tests
class WeatherResponse(BaseModel):
location: str
temperature: float
conditions: str
# endregion
# region Initialization Tests
class TestOpenAIAssistantProviderInit:
"""Tests for provider initialization."""
def test_init_with_client(self, mock_async_openai: MagicMock) -> None:
"""Test initialization with existing AsyncOpenAI client."""
provider = OpenAIAssistantProvider(mock_async_openai)
assert provider._client is mock_async_openai # type: ignore[reportPrivateUsage]
assert provider._should_close_client is False # type: ignore[reportPrivateUsage]
def test_init_without_client_creates_one(self, openai_unit_test_env: dict[str, str]) -> None:
"""Test initialization creates client from settings."""
provider = OpenAIAssistantProvider()
assert provider._client is not None # type: ignore[reportPrivateUsage]
assert provider._should_close_client is True # type: ignore[reportPrivateUsage]
def test_init_with_api_key(self) -> None:
"""Test initialization with explicit API key."""
provider = OpenAIAssistantProvider(api_key="sk-test-key")
assert provider._client is not None # type: ignore[reportPrivateUsage]
assert provider._should_close_client is True # type: ignore[reportPrivateUsage]
def test_init_fails_without_api_key(self) -> None:
"""Test initialization fails without API key when settings return None."""
from unittest.mock import patch
# Mock OpenAISettings to return None for api_key
with patch("agent_framework.openai._assistant_provider.OpenAISettings") as mock_settings:
mock_settings.return_value.api_key = None
with pytest.raises(ServiceInitializationError) as exc_info:
OpenAIAssistantProvider()
assert "API key is required" in str(exc_info.value)
def test_init_with_org_id_and_base_url(self) -> None:
"""Test initialization with organization ID and base URL."""
provider = OpenAIAssistantProvider(
api_key="sk-test-key",
org_id="org-123",
base_url="https://custom.openai.com",
)
assert provider._client is not None # type: ignore[reportPrivateUsage]
class TestOpenAIAssistantProviderContextManager:
"""Tests for async context manager."""
async def test_context_manager_enter_exit(self, mock_async_openai: MagicMock) -> None:
"""Test async context manager entry and exit."""
provider = OpenAIAssistantProvider(mock_async_openai)
async with provider as p:
assert p is provider
async def test_context_manager_closes_owned_client(self, openai_unit_test_env: dict[str, str]) -> None:
"""Test that owned client is closed on exit."""
provider = OpenAIAssistantProvider()
client = provider._client # type: ignore[reportPrivateUsage]
assert client is not None
client.close = AsyncMock()
async with provider:
pass
client.close.assert_called_once()
async def test_context_manager_does_not_close_external_client(self, mock_async_openai: MagicMock) -> None:
"""Test that external client is not closed on exit."""
provider = OpenAIAssistantProvider(mock_async_openai)
async with provider:
pass
mock_async_openai.close.assert_not_called()
# endregion
# region create_agent Tests
class TestOpenAIAssistantProviderCreateAgent:
"""Tests for create_agent method."""
async def test_create_agent_basic(self, mock_async_openai: MagicMock) -> None:
"""Test basic assistant creation."""
provider = OpenAIAssistantProvider(mock_async_openai)
agent = await provider.create_agent(
name="TestAgent",
model="gpt-4",
instructions="You are helpful.",
)
assert isinstance(agent, ChatAgent)
assert agent.name == "CreatedAssistant"
mock_async_openai.beta.assistants.create.assert_called_once()
# Verify create was called with correct parameters
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
assert call_kwargs["name"] == "TestAgent"
assert call_kwargs["model"] == "gpt-4"
assert call_kwargs["instructions"] == "You are helpful."
async def test_create_agent_with_description(self, mock_async_openai: MagicMock) -> None:
"""Test assistant creation with description."""
provider = OpenAIAssistantProvider(mock_async_openai)
await provider.create_agent(
name="TestAgent",
model="gpt-4",
description="A test agent description",
)
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
assert call_kwargs["description"] == "A test agent description"
async def test_create_agent_with_function_tools(self, mock_async_openai: MagicMock) -> None:
"""Test assistant creation with function tools."""
provider = OpenAIAssistantProvider(mock_async_openai)
agent = await provider.create_agent(
name="WeatherAgent",
model="gpt-4",
tools=[get_weather],
)
assert isinstance(agent, ChatAgent)
# Verify tools were passed to create
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
assert "tools" in call_kwargs
assert len(call_kwargs["tools"]) == 1
assert call_kwargs["tools"][0]["type"] == "function"
assert call_kwargs["tools"][0]["function"]["name"] == "get_weather"
async def test_create_agent_with_ai_function(self, mock_async_openai: MagicMock) -> None:
"""Test assistant creation with AIFunction."""
provider = OpenAIAssistantProvider(mock_async_openai)
@ai_function
def my_function(x: int) -> int:
"""Double a number."""
return x * 2
await provider.create_agent(
name="TestAgent",
model="gpt-4",
tools=[my_function],
)
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
assert call_kwargs["tools"][0]["function"]["name"] == "my_function"
async def test_create_agent_with_code_interpreter(self, mock_async_openai: MagicMock) -> None:
"""Test assistant creation with code interpreter."""
provider = OpenAIAssistantProvider(mock_async_openai)
await provider.create_agent(
name="CodeAgent",
model="gpt-4",
tools=[HostedCodeInterpreterTool()],
)
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
assert {"type": "code_interpreter"} in call_kwargs["tools"]
async def test_create_agent_with_file_search(self, mock_async_openai: MagicMock) -> None:
"""Test assistant creation with file search."""
provider = OpenAIAssistantProvider(mock_async_openai)
await provider.create_agent(
name="SearchAgent",
model="gpt-4",
tools=[HostedFileSearchTool()],
)
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
assert any(t["type"] == "file_search" for t in call_kwargs["tools"])
async def test_create_agent_with_file_search_max_results(self, mock_async_openai: MagicMock) -> None:
"""Test assistant creation with file search and max_results."""
provider = OpenAIAssistantProvider(mock_async_openai)
await provider.create_agent(
name="SearchAgent",
model="gpt-4",
tools=[HostedFileSearchTool(max_results=10)],
)
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
file_search_tool = next(t for t in call_kwargs["tools"] if t["type"] == "file_search")
assert file_search_tool.get("file_search", {}).get("max_num_results") == 10
async def test_create_agent_with_mixed_tools(self, mock_async_openai: MagicMock) -> None:
"""Test assistant creation with multiple tool types."""
provider = OpenAIAssistantProvider(mock_async_openai)
await provider.create_agent(
name="MultiToolAgent",
model="gpt-4",
tools=[get_weather, HostedCodeInterpreterTool(), HostedFileSearchTool()],
)
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
assert len(call_kwargs["tools"]) == 3
async def test_create_agent_with_metadata(self, mock_async_openai: MagicMock) -> None:
"""Test assistant creation with metadata."""
provider = OpenAIAssistantProvider(mock_async_openai)
await provider.create_agent(
name="TestAgent",
model="gpt-4",
metadata={"env": "test", "version": "1.0"},
)
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
assert call_kwargs["metadata"] == {"env": "test", "version": "1.0"}
async def test_create_agent_with_response_format_pydantic(self, mock_async_openai: MagicMock) -> None:
"""Test assistant creation with Pydantic response format via default_options."""
provider = OpenAIAssistantProvider(mock_async_openai)
await provider.create_agent(
name="StructuredAgent",
model="gpt-4",
default_options={"response_format": WeatherResponse},
)
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
assert call_kwargs["response_format"]["type"] == "json_schema"
assert call_kwargs["response_format"]["json_schema"]["name"] == "WeatherResponse"
async def test_create_agent_returns_chat_agent(self, mock_async_openai: MagicMock) -> None:
"""Test that create_agent returns a ChatAgent instance."""
provider = OpenAIAssistantProvider(mock_async_openai)
agent = await provider.create_agent(
name="TestAgent",
model="gpt-4",
)
assert isinstance(agent, ChatAgent)
# endregion
# region get_agent Tests
class TestOpenAIAssistantProviderGetAgent:
"""Tests for get_agent method."""
async def test_get_agent_basic(self, mock_async_openai: MagicMock) -> None:
"""Test retrieving an existing assistant."""
provider = OpenAIAssistantProvider(mock_async_openai)
agent = await provider.get_agent(assistant_id="asst_123")
assert isinstance(agent, ChatAgent)
mock_async_openai.beta.assistants.retrieve.assert_called_once_with("asst_123")
async def test_get_agent_with_instructions_override(self, mock_async_openai: MagicMock) -> None:
"""Test retrieving assistant with instruction override."""
provider = OpenAIAssistantProvider(mock_async_openai)
agent = await provider.get_agent(
assistant_id="asst_123",
instructions="Custom instructions",
)
# Agent should be created successfully with the custom instructions
assert isinstance(agent, ChatAgent)
assert agent.id == "asst_retrieved123"
async def test_get_agent_with_function_tools(self, mock_async_openai: MagicMock) -> None:
"""Test retrieving assistant with function tools provided."""
# Setup assistant with function tool
assistant = create_mock_assistant(tools=[create_function_tool("get_weather")])
mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant)
provider = OpenAIAssistantProvider(mock_async_openai)
agent = await provider.get_agent(
assistant_id="asst_123",
tools=[get_weather],
)
assert isinstance(agent, ChatAgent)
async def test_get_agent_validates_missing_function_tools(self, mock_async_openai: MagicMock) -> None:
"""Test that missing function tools raise ValueError."""
# Setup assistant with function tool
assistant = create_mock_assistant(tools=[create_function_tool("get_weather")])
mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant)
provider = OpenAIAssistantProvider(mock_async_openai)
with pytest.raises(ValueError) as exc_info:
await provider.get_agent(assistant_id="asst_123")
assert "get_weather" in str(exc_info.value)
assert "no implementation was provided" in str(exc_info.value)
async def test_get_agent_validates_multiple_missing_function_tools(self, mock_async_openai: MagicMock) -> None:
"""Test validation with multiple missing function tools."""
assistant = create_mock_assistant(
tools=[create_function_tool("get_weather"), create_function_tool("search_database")]
)
mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant)
provider = OpenAIAssistantProvider(mock_async_openai)
with pytest.raises(ValueError) as exc_info:
await provider.get_agent(assistant_id="asst_123")
error_msg = str(exc_info.value)
assert "get_weather" in error_msg or "search_database" in error_msg
async def test_get_agent_merges_hosted_tools(self, mock_async_openai: MagicMock) -> None:
"""Test that hosted tools are automatically included."""
assistant = create_mock_assistant(tools=[create_code_interpreter_tool(), create_file_search_tool()])
mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant)
provider = OpenAIAssistantProvider(mock_async_openai)
agent = await provider.get_agent(assistant_id="asst_123")
# Hosted tools should be merged automatically
assert isinstance(agent, ChatAgent)
# endregion
# region as_agent Tests
class TestOpenAIAssistantProviderAsAgent:
"""Tests for as_agent method."""
def test_as_agent_no_http_call(self, mock_async_openai: MagicMock) -> None:
"""Test that as_agent doesn't make HTTP calls."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant = create_mock_assistant()
agent = provider.as_agent(assistant)
assert isinstance(agent, ChatAgent)
# Verify no HTTP calls were made
mock_async_openai.beta.assistants.create.assert_not_called()
mock_async_openai.beta.assistants.retrieve.assert_not_called()
def test_as_agent_wraps_assistant(self, mock_async_openai: MagicMock) -> None:
"""Test wrapping an SDK Assistant object."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant = create_mock_assistant(
assistant_id="asst_wrap123",
name="WrappedAssistant",
instructions="Original instructions",
)
agent = provider.as_agent(assistant)
assert agent.id == "asst_wrap123"
assert agent.name == "WrappedAssistant"
# Instructions are passed to ChatOptions, not exposed as attribute
assert isinstance(agent, ChatAgent)
def test_as_agent_with_instructions_override(self, mock_async_openai: MagicMock) -> None:
"""Test as_agent with instruction override."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant = create_mock_assistant(instructions="Original")
agent = provider.as_agent(assistant, instructions="Override")
# Agent should be created successfully with override instructions
assert isinstance(agent, ChatAgent)
def test_as_agent_validates_function_tools(self, mock_async_openai: MagicMock) -> None:
"""Test that missing function tools raise ValueError."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant = create_mock_assistant(tools=[create_function_tool("get_weather")])
with pytest.raises(ValueError) as exc_info:
provider.as_agent(assistant)
assert "get_weather" in str(exc_info.value)
def test_as_agent_with_function_tools_provided(self, mock_async_openai: MagicMock) -> None:
"""Test as_agent with function tools provided."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant = create_mock_assistant(tools=[create_function_tool("get_weather")])
agent = provider.as_agent(assistant, tools=[get_weather])
assert isinstance(agent, ChatAgent)
def test_as_agent_merges_hosted_tools(self, mock_async_openai: MagicMock) -> None:
"""Test that hosted tools are merged automatically."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant = create_mock_assistant(tools=[create_code_interpreter_tool()])
agent = provider.as_agent(assistant)
assert isinstance(agent, ChatAgent)
def test_as_agent_hosted_tools_not_required(self, mock_async_openai: MagicMock) -> None:
"""Test that hosted tools don't require user implementations."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant = create_mock_assistant(tools=[create_code_interpreter_tool(), create_file_search_tool()])
# Should not raise - hosted tools don't need implementations
agent = provider.as_agent(assistant)
assert isinstance(agent, ChatAgent)
# endregion
# region Tool Conversion Tests
class TestToolConversion:
"""Tests for tool conversion utilities (shared functions)."""
def test_to_assistant_tools_ai_function(self) -> None:
"""Test AIFunction conversion to API format."""
@ai_function
def test_func(x: int) -> int:
"""Test function."""
return x
# Normalize tools first, then convert
normalized = normalize_tools([test_func])
api_tools = to_assistant_tools(normalized)
assert len(api_tools) == 1
assert api_tools[0]["type"] == "function"
assert api_tools[0]["function"]["name"] == "test_func"
def test_to_assistant_tools_callable(self) -> None:
"""Test raw callable conversion via normalize_tools."""
# normalize_tools converts callables to AIFunction
normalized = normalize_tools([get_weather])
api_tools = to_assistant_tools(normalized)
assert len(api_tools) == 1
assert api_tools[0]["type"] == "function"
assert api_tools[0]["function"]["name"] == "get_weather"
def test_to_assistant_tools_code_interpreter(self) -> None:
"""Test HostedCodeInterpreterTool conversion."""
api_tools = to_assistant_tools([HostedCodeInterpreterTool()])
assert len(api_tools) == 1
assert api_tools[0] == {"type": "code_interpreter"}
def test_to_assistant_tools_file_search(self) -> None:
"""Test HostedFileSearchTool conversion."""
api_tools = to_assistant_tools([HostedFileSearchTool()])
assert len(api_tools) == 1
assert api_tools[0]["type"] == "file_search"
def test_to_assistant_tools_file_search_with_max_results(self) -> None:
"""Test HostedFileSearchTool with max_results conversion."""
api_tools = to_assistant_tools([HostedFileSearchTool(max_results=5)])
assert api_tools[0]["file_search"]["max_num_results"] == 5
def test_to_assistant_tools_dict(self) -> None:
"""Test raw dict tool passthrough."""
raw_tool = {"type": "function", "function": {"name": "custom", "description": "Custom tool"}}
api_tools = to_assistant_tools([raw_tool])
assert len(api_tools) == 1
assert api_tools[0] == raw_tool
def test_to_assistant_tools_empty(self) -> None:
"""Test conversion with no tools."""
api_tools = to_assistant_tools(None)
assert api_tools == []
def test_from_assistant_tools_code_interpreter(self) -> None:
"""Test converting code_interpreter tool from OpenAI format."""
assistant_tools = [create_code_interpreter_tool()]
tools = from_assistant_tools(assistant_tools)
assert len(tools) == 1
assert isinstance(tools[0], HostedCodeInterpreterTool)
def test_from_assistant_tools_file_search(self) -> None:
"""Test converting file_search tool from OpenAI format."""
assistant_tools = [create_file_search_tool()]
tools = from_assistant_tools(assistant_tools)
assert len(tools) == 1
assert isinstance(tools[0], HostedFileSearchTool)
def test_from_assistant_tools_function_skipped(self) -> None:
"""Test that function tools are skipped (no implementations)."""
assistant_tools = [create_function_tool("test_func")]
tools = from_assistant_tools(assistant_tools)
assert len(tools) == 0 # Function tools are skipped
def test_from_assistant_tools_empty(self) -> None:
"""Test conversion with no tools."""
tools = from_assistant_tools(None)
assert tools == []
# endregion
# region Tool Validation Tests
class TestToolValidation:
"""Tests for tool validation."""
def test_validate_missing_function_tool_raises(self, mock_async_openai: MagicMock) -> None:
"""Test that missing function tools raise ValueError."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant_tools = [create_function_tool("my_function")]
with pytest.raises(ValueError) as exc_info:
provider._validate_function_tools(assistant_tools, None) # type: ignore[reportPrivateUsage]
assert "my_function" in str(exc_info.value)
def test_validate_all_tools_provided_passes(self, mock_async_openai: MagicMock) -> None:
"""Test that validation passes when all tools provided."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant_tools = [create_function_tool("get_weather")]
# Should not raise
provider._validate_function_tools(assistant_tools, [get_weather]) # type: ignore[reportPrivateUsage]
def test_validate_hosted_tools_not_required(self, mock_async_openai: MagicMock) -> None:
"""Test that hosted tools don't require implementations."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant_tools = [create_code_interpreter_tool(), create_file_search_tool()]
# Should not raise
provider._validate_function_tools(assistant_tools, None) # type: ignore[reportPrivateUsage]
def test_validate_with_ai_function(self, mock_async_openai: MagicMock) -> None:
"""Test validation with AIFunction."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant_tools = [create_function_tool("get_weather")]
wrapped = ai_function(get_weather)
# Should not raise
provider._validate_function_tools(assistant_tools, [wrapped]) # type: ignore[reportPrivateUsage]
def test_validate_partial_tools_raises(self, mock_async_openai: MagicMock) -> None:
"""Test that partial tool provision raises error."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant_tools = [
create_function_tool("get_weather"),
create_function_tool("search_database"),
]
with pytest.raises(ValueError) as exc_info:
provider._validate_function_tools(assistant_tools, [get_weather]) # type: ignore[reportPrivateUsage]
assert "search_database" in str(exc_info.value)
# endregion
# region Tool Merging Tests
class TestToolMerging:
"""Tests for tool merging."""
def test_merge_code_interpreter(self, mock_async_openai: MagicMock) -> None:
"""Test merging code interpreter tool."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant_tools = [create_code_interpreter_tool()]
merged = provider._merge_tools(assistant_tools, None) # type: ignore[reportPrivateUsage]
assert len(merged) == 1
assert isinstance(merged[0], HostedCodeInterpreterTool)
def test_merge_file_search(self, mock_async_openai: MagicMock) -> None:
"""Test merging file search tool."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant_tools = [create_file_search_tool()]
merged = provider._merge_tools(assistant_tools, None) # type: ignore[reportPrivateUsage]
assert len(merged) == 1
assert isinstance(merged[0], HostedFileSearchTool)
def test_merge_with_user_tools(self, mock_async_openai: MagicMock) -> None:
"""Test merging hosted and user tools."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant_tools = [create_code_interpreter_tool()]
merged = provider._merge_tools(assistant_tools, [get_weather]) # type: ignore[reportPrivateUsage]
assert len(merged) == 2
assert isinstance(merged[0], HostedCodeInterpreterTool)
def test_merge_multiple_hosted_tools(self, mock_async_openai: MagicMock) -> None:
"""Test merging multiple hosted tools."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant_tools = [create_code_interpreter_tool(), create_file_search_tool()]
merged = provider._merge_tools(assistant_tools, None) # type: ignore[reportPrivateUsage]
assert len(merged) == 2
def test_merge_single_user_tool(self, mock_async_openai: MagicMock) -> None:
"""Test merging with single user tool (not list)."""
provider = OpenAIAssistantProvider(mock_async_openai)
assistant_tools: list[Any] = []
merged = provider._merge_tools(assistant_tools, get_weather) # type: ignore[reportPrivateUsage]
assert len(merged) == 1
# endregion
# region Integration Tests
skip_if_openai_integration_tests_disabled = pytest.mark.skipif(
os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true"
or os.getenv("OPENAI_API_KEY", "") in ("", "test-dummy-key"),
reason="No real OPENAI_API_KEY provided; skipping integration tests."
if os.getenv("RUN_INTEGRATION_TESTS", "false").lower() == "true"
else "Integration tests are disabled.",
)
@skip_if_openai_integration_tests_disabled
class TestOpenAIAssistantProviderIntegration:
"""Integration tests requiring real OpenAI API."""
async def test_create_and_run_agent(self) -> None:
"""End-to-end test of creating and running an agent."""
provider = OpenAIAssistantProvider()
agent = await provider.create_agent(
name="IntegrationTestAgent",
model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
instructions="You are a helpful assistant. Respond briefly.",
)
try:
result = await agent.run("Say 'hello' and nothing else.")
result_text = str(result)
assert "hello" in result_text.lower()
finally:
# Clean up the assistant
await provider._client.beta.assistants.delete(agent.id) # type: ignore[reportPrivateUsage, union-attr]
async def test_create_agent_with_function_tools_integration(self) -> None:
"""Integration test with function tools."""
provider = OpenAIAssistantProvider()
def get_current_time() -> str:
"""Get the current time."""
from datetime import datetime
return datetime.now().strftime("%H:%M")
agent = await provider.create_agent(
name="TimeAgent",
model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
instructions="You are a helpful assistant.",
tools=[get_current_time],
)
try:
result = await agent.run("What time is it? Use the get_current_time function.")
result_text = str(result)
# The response should contain time information
assert ":" in result_text or "time" in result_text.lower()
finally:
await provider._client.beta.assistants.delete(agent.id) # type: ignore[reportPrivateUsage, union-attr]
# endregion