mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Create/Get Agent API for OpenAI Assistants (#3208)
* Added provider implementation * Added example with response format * Small improvements
This commit is contained in:
committed by
GitHub
Unverified
parent
dd3e2b6e53
commit
b5ca0c8eda
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
|
||||
from ._assistant_provider import * # noqa: F403
|
||||
from ._assistants_client import * # noqa: F403
|
||||
from ._chat_client import * # noqa: F403
|
||||
from ._exceptions import * # noqa: F403
|
||||
|
||||
@@ -0,0 +1,563 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, MutableMapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypedDict, cast
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.beta.assistant import Assistant
|
||||
from pydantic import BaseModel, SecretStr, ValidationError
|
||||
|
||||
from .._agents import ChatAgent
|
||||
from .._memory import ContextProvider
|
||||
from .._middleware import Middleware
|
||||
from .._tools import AIFunction, ToolProtocol
|
||||
from .._types import normalize_tools
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ._assistants_client import OpenAIAssistantsClient
|
||||
from ._shared import OpenAISettings, from_assistant_tools, to_assistant_tools
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._assistants_client import OpenAIAssistantsOptions
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import Self, TypeVar # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self, TypeVar # pragma: no cover
|
||||
|
||||
|
||||
__all__ = ["OpenAIAssistantProvider"]
|
||||
|
||||
# Type variable for options - allows typed ChatAgent[TOptions] returns
|
||||
# Default matches OpenAIAssistantsClient's default options type
|
||||
TOptions_co = TypeVar(
|
||||
"TOptions_co",
|
||||
bound=TypedDict, # type: ignore[valid-type]
|
||||
default="OpenAIAssistantsOptions",
|
||||
covariant=True,
|
||||
)
|
||||
|
||||
_ToolsType = (
|
||||
ToolProtocol
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
)
|
||||
|
||||
|
||||
class OpenAIAssistantProvider(Generic[TOptions_co]):
|
||||
"""Provider for creating ChatAgent instances from OpenAI Assistants API.
|
||||
|
||||
This provider allows you to create, retrieve, and wrap OpenAI Assistants
|
||||
as ChatAgent instances for use in the agent framework.
|
||||
|
||||
Examples:
|
||||
Basic usage with automatic client creation:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework.openai import OpenAIAssistantProvider
|
||||
|
||||
# Uses OPENAI_API_KEY environment variable
|
||||
provider = OpenAIAssistantProvider()
|
||||
|
||||
# Create a new assistant
|
||||
agent = await provider.create_agent(
|
||||
name="MyAssistant",
|
||||
model="gpt-4",
|
||||
instructions="You are a helpful assistant.",
|
||||
tools=[my_function],
|
||||
)
|
||||
|
||||
result = await agent.run("Hello!")
|
||||
|
||||
Using an existing client:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from agent_framework.openai import OpenAIAssistantProvider
|
||||
|
||||
client = AsyncOpenAI()
|
||||
provider = OpenAIAssistantProvider(client)
|
||||
|
||||
# Get an existing assistant by ID
|
||||
agent = await provider.get_agent(
|
||||
assistant_id="asst_123",
|
||||
tools=[my_function], # Provide implementations for function tools
|
||||
)
|
||||
|
||||
Wrapping an SDK Assistant object:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Fetch assistant directly via SDK
|
||||
assistant = await client.beta.assistants.retrieve("asst_123")
|
||||
|
||||
# Wrap without additional HTTP call
|
||||
agent = provider.as_agent(assistant, tools=[my_function])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: AsyncOpenAI | None = None,
|
||||
*,
|
||||
api_key: str | SecretStr | Callable[[], str | Awaitable[str]] | None = None,
|
||||
org_id: str | None = None,
|
||||
base_url: str | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the OpenAI Assistant Provider.
|
||||
|
||||
Args:
|
||||
client: An existing AsyncOpenAI client to use. If not provided,
|
||||
a new client will be created using the other parameters.
|
||||
|
||||
Keyword Args:
|
||||
api_key: OpenAI API key. Can also be set via OPENAI_API_KEY env var.
|
||||
org_id: OpenAI organization ID. Can also be set via OPENAI_ORG_ID env var.
|
||||
base_url: Base URL for the OpenAI API. Can also be set via OPENAI_BASE_URL env var.
|
||||
env_file_path: Path to .env file for configuration.
|
||||
env_file_encoding: Encoding of the .env file.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If no client is provided and API key is missing.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
# Using environment variables
|
||||
provider = OpenAIAssistantProvider()
|
||||
|
||||
# Using explicit API key
|
||||
provider = OpenAIAssistantProvider(api_key="sk-...")
|
||||
|
||||
# Using existing client
|
||||
client = AsyncOpenAI()
|
||||
provider = OpenAIAssistantProvider(client)
|
||||
"""
|
||||
self._client: AsyncOpenAI | None = client
|
||||
self._should_close_client: bool = client is None
|
||||
|
||||
if client is None:
|
||||
# Load settings and create client
|
||||
try:
|
||||
settings = OpenAISettings(
|
||||
api_key=api_key, # type: ignore[reportArgumentType]
|
||||
org_id=org_id,
|
||||
base_url=base_url,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
|
||||
|
||||
if not settings.api_key:
|
||||
raise ServiceInitializationError(
|
||||
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
|
||||
)
|
||||
|
||||
# Get API key value
|
||||
api_key_value: str | Callable[[], str | Awaitable[str]] | None
|
||||
if isinstance(settings.api_key, SecretStr):
|
||||
api_key_value = settings.api_key.get_secret_value()
|
||||
else:
|
||||
api_key_value = settings.api_key
|
||||
|
||||
# Create client
|
||||
client_args: dict[str, Any] = {"api_key": api_key_value}
|
||||
if settings.org_id:
|
||||
client_args["organization"] = settings.org_id
|
||||
if settings.base_url:
|
||||
client_args["base_url"] = settings.base_url
|
||||
|
||||
self._client = AsyncOpenAI(**client_args)
|
||||
|
||||
async def __aenter__(self) -> "Self":
|
||||
"""Async context manager entry."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
|
||||
"""Async context manager exit."""
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the provider and clean up resources.
|
||||
|
||||
If the provider created its own client, it will be closed.
|
||||
If an external client was provided, it will not be closed.
|
||||
"""
|
||||
if self._should_close_client and self._client is not None:
|
||||
await self._client.close()
|
||||
|
||||
async def create_agent(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
description: str | None = None,
|
||||
tools: _ToolsType | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
default_options: TOptions_co | None = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
context_provider: ContextProvider | None = None,
|
||||
) -> "ChatAgent[TOptions_co]":
|
||||
"""Create a new assistant on OpenAI and return a ChatAgent.
|
||||
|
||||
This method creates a new assistant on the OpenAI service and wraps it
|
||||
in a ChatAgent instance. The assistant will persist on OpenAI until deleted.
|
||||
|
||||
Keyword Args:
|
||||
name: The name of the assistant (required).
|
||||
model: The model ID to use, e.g., "gpt-4", "gpt-4o" (required).
|
||||
instructions: System instructions for the assistant.
|
||||
description: A description of the assistant.
|
||||
tools: Tools available to the assistant. Can include:
|
||||
- AIFunction instances or callables decorated with @ai_function
|
||||
- HostedCodeInterpreterTool for code execution
|
||||
- HostedFileSearchTool for vector store search
|
||||
- Raw tool dictionaries
|
||||
metadata: Metadata to attach to the assistant (max 16 key-value pairs).
|
||||
default_options: A TypedDict containing default chat options for the agent.
|
||||
These options are applied to every run unless overridden.
|
||||
Include ``response_format`` here for structured output responses.
|
||||
middleware: Middleware for the ChatAgent.
|
||||
context_provider: Context provider for the ChatAgent.
|
||||
|
||||
Returns:
|
||||
A ChatAgent instance wrapping the created assistant.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If assistant creation fails.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
provider = OpenAIAssistantProvider()
|
||||
|
||||
# Create with function tools
|
||||
agent = await provider.create_agent(
|
||||
name="WeatherBot",
|
||||
model="gpt-4",
|
||||
instructions="You are a helpful weather assistant.",
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
# Create with structured output
|
||||
agent = await provider.create_agent(
|
||||
name="StructuredBot",
|
||||
model="gpt-4",
|
||||
default_options={"response_format": MyPydanticModel},
|
||||
)
|
||||
"""
|
||||
# Normalize tools
|
||||
normalized_tools = normalize_tools(tools)
|
||||
api_tools = to_assistant_tools(normalized_tools) if normalized_tools else []
|
||||
|
||||
# Extract response_format from default_options if present
|
||||
opts = dict(default_options) if default_options else {}
|
||||
response_format = opts.get("response_format")
|
||||
|
||||
# Build assistant creation parameters
|
||||
create_params: dict[str, Any] = {
|
||||
"model": model,
|
||||
"name": name,
|
||||
}
|
||||
|
||||
if instructions is not None:
|
||||
create_params["instructions"] = instructions
|
||||
if description is not None:
|
||||
create_params["description"] = description
|
||||
if api_tools:
|
||||
create_params["tools"] = api_tools
|
||||
if metadata is not None:
|
||||
create_params["metadata"] = metadata
|
||||
|
||||
# Handle response format for OpenAI API
|
||||
if response_format is not None and isinstance(response_format, type) and issubclass(response_format, BaseModel):
|
||||
create_params["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": response_format.__name__,
|
||||
"schema": response_format.model_json_schema(),
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
|
||||
# Create the assistant
|
||||
if not self._client:
|
||||
raise ServiceInitializationError("OpenAI client is not initialized.")
|
||||
|
||||
assistant = await self._client.beta.assistants.create(**create_params)
|
||||
|
||||
# Create ChatAgent - pass default_options which contains response_format
|
||||
return self._create_chat_agent_from_assistant(
|
||||
assistant=assistant,
|
||||
tools=normalized_tools,
|
||||
instructions=instructions,
|
||||
middleware=middleware,
|
||||
context_provider=context_provider,
|
||||
default_options=default_options,
|
||||
)
|
||||
|
||||
async def get_agent(
|
||||
self,
|
||||
assistant_id: str,
|
||||
*,
|
||||
tools: _ToolsType | None = None,
|
||||
instructions: str | None = None,
|
||||
default_options: TOptions_co | None = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
context_provider: ContextProvider | None = None,
|
||||
) -> "ChatAgent[TOptions_co]":
|
||||
"""Retrieve an existing assistant by ID and return a ChatAgent.
|
||||
|
||||
This method fetches an existing assistant from OpenAI by its ID
|
||||
and wraps it in a ChatAgent instance.
|
||||
|
||||
Args:
|
||||
assistant_id: The ID of the assistant to retrieve (e.g., "asst_123").
|
||||
|
||||
Keyword Args:
|
||||
tools: Function tools to make available. IMPORTANT: If the assistant
|
||||
was created with function tools, you MUST provide matching
|
||||
implementations here. Hosted tools (code_interpreter, file_search)
|
||||
are automatically included.
|
||||
instructions: Override the assistant's instructions (optional).
|
||||
default_options: A TypedDict containing default chat options for the agent.
|
||||
These options are applied to every run unless overridden.
|
||||
middleware: Middleware for the ChatAgent.
|
||||
context_provider: Context provider for the ChatAgent.
|
||||
|
||||
Returns:
|
||||
A ChatAgent instance wrapping the retrieved assistant.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If the assistant cannot be retrieved.
|
||||
ValueError: If required function tools are missing.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
provider = OpenAIAssistantProvider()
|
||||
|
||||
# Get assistant without function tools
|
||||
agent = await provider.get_agent(assistant_id="asst_123")
|
||||
|
||||
# Get assistant with function tools
|
||||
agent = await provider.get_agent(
|
||||
assistant_id="asst_456",
|
||||
tools=[get_weather, search_database], # Implementations required!
|
||||
)
|
||||
"""
|
||||
# Fetch the assistant
|
||||
if not self._client:
|
||||
raise ServiceInitializationError("OpenAI client is not initialized.")
|
||||
|
||||
assistant = await self._client.beta.assistants.retrieve(assistant_id)
|
||||
|
||||
# Use as_agent to wrap it
|
||||
return self.as_agent(
|
||||
assistant=assistant,
|
||||
tools=tools,
|
||||
instructions=instructions,
|
||||
default_options=default_options,
|
||||
middleware=middleware,
|
||||
context_provider=context_provider,
|
||||
)
|
||||
|
||||
def as_agent(
|
||||
self,
|
||||
assistant: Assistant,
|
||||
*,
|
||||
tools: _ToolsType | None = None,
|
||||
instructions: str | None = None,
|
||||
default_options: TOptions_co | None = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
context_provider: ContextProvider | None = None,
|
||||
) -> "ChatAgent[TOptions_co]":
|
||||
"""Wrap an existing SDK Assistant object as a ChatAgent.
|
||||
|
||||
This method does NOT make any HTTP calls. It simply wraps an already-
|
||||
fetched Assistant object in a ChatAgent.
|
||||
|
||||
Args:
|
||||
assistant: The OpenAI Assistant SDK object to wrap.
|
||||
|
||||
Keyword Args:
|
||||
tools: Function tools to make available. If the assistant has
|
||||
function tools defined, you MUST provide matching implementations.
|
||||
Hosted tools (code_interpreter, file_search) are automatically included.
|
||||
instructions: Override the assistant's instructions (optional).
|
||||
default_options: A TypedDict containing default chat options for the agent.
|
||||
These options are applied to every run unless overridden.
|
||||
middleware: Middleware for the ChatAgent.
|
||||
context_provider: Context provider for the ChatAgent.
|
||||
|
||||
Returns:
|
||||
A ChatAgent instance wrapping the assistant.
|
||||
|
||||
Raises:
|
||||
ValueError: If required function tools are missing.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
client = AsyncOpenAI()
|
||||
provider = OpenAIAssistantProvider(client)
|
||||
|
||||
# Fetch assistant via SDK
|
||||
assistant = await client.beta.assistants.retrieve("asst_123")
|
||||
|
||||
# Wrap without additional HTTP call
|
||||
agent = provider.as_agent(
|
||||
assistant,
|
||||
tools=[my_function],
|
||||
instructions="Custom instructions override",
|
||||
)
|
||||
"""
|
||||
# Validate that required function tools are provided
|
||||
self._validate_function_tools(assistant.tools or [], tools)
|
||||
|
||||
# Merge hosted tools with user-provided function tools
|
||||
merged_tools = self._merge_tools(assistant.tools or [], tools)
|
||||
|
||||
# Create ChatAgent
|
||||
return self._create_chat_agent_from_assistant(
|
||||
assistant=assistant,
|
||||
tools=merged_tools,
|
||||
instructions=instructions,
|
||||
default_options=default_options,
|
||||
middleware=middleware,
|
||||
context_provider=context_provider,
|
||||
)
|
||||
|
||||
def _validate_function_tools(
|
||||
self,
|
||||
assistant_tools: list[Any],
|
||||
provided_tools: _ToolsType | None,
|
||||
) -> None:
|
||||
"""Validate that required function tools are provided.
|
||||
|
||||
Args:
|
||||
assistant_tools: Tools defined on the assistant.
|
||||
provided_tools: Tools provided by the user.
|
||||
|
||||
Raises:
|
||||
ValueError: If a required function tool is missing.
|
||||
"""
|
||||
# Get function tool names from assistant
|
||||
required_functions: set[str] = set()
|
||||
for tool in assistant_tools:
|
||||
if (
|
||||
hasattr(tool, "type")
|
||||
and tool.type == "function"
|
||||
and hasattr(tool, "function")
|
||||
and hasattr(tool.function, "name")
|
||||
):
|
||||
required_functions.add(tool.function.name)
|
||||
|
||||
if not required_functions:
|
||||
return # No function tools required
|
||||
|
||||
# Get provided function names using normalize_tools
|
||||
provided_functions: set[str] = set()
|
||||
if provided_tools is not None:
|
||||
normalized = normalize_tools(provided_tools)
|
||||
for tool in normalized:
|
||||
if isinstance(tool, AIFunction):
|
||||
provided_functions.add(tool.name)
|
||||
elif isinstance(tool, MutableMapping) and "function" in tool:
|
||||
func_spec = tool.get("function", {})
|
||||
if isinstance(func_spec, dict):
|
||||
func_dict = cast(dict[str, Any], func_spec)
|
||||
if "name" in func_dict:
|
||||
provided_functions.add(str(func_dict["name"]))
|
||||
|
||||
# Check for missing functions
|
||||
missing = required_functions - provided_functions
|
||||
if missing:
|
||||
missing_list = ", ".join(sorted(missing))
|
||||
raise ValueError(
|
||||
f"Assistant requires function tool(s) '{missing_list}' but no implementation was provided. "
|
||||
f"Please pass the function implementation(s) in the 'tools' parameter."
|
||||
)
|
||||
|
||||
def _merge_tools(
|
||||
self,
|
||||
assistant_tools: list[Any],
|
||||
user_tools: _ToolsType | None,
|
||||
) -> list[ToolProtocol | MutableMapping[str, Any]]:
|
||||
"""Merge hosted tools from assistant with user-provided function tools.
|
||||
|
||||
Args:
|
||||
assistant_tools: Tools defined on the assistant.
|
||||
user_tools: Tools provided by the user.
|
||||
|
||||
Returns:
|
||||
A list of all tools (hosted tools + user function implementations).
|
||||
"""
|
||||
merged: list[ToolProtocol | MutableMapping[str, Any]] = []
|
||||
|
||||
# Add hosted tools from assistant using shared conversion
|
||||
hosted_tools = from_assistant_tools(assistant_tools)
|
||||
merged.extend(hosted_tools)
|
||||
|
||||
# Add user-provided tools (normalized)
|
||||
if user_tools is not None:
|
||||
normalized_user_tools = normalize_tools(user_tools)
|
||||
merged.extend(normalized_user_tools)
|
||||
|
||||
return merged
|
||||
|
||||
def _create_chat_agent_from_assistant(
|
||||
self,
|
||||
assistant: Assistant,
|
||||
tools: list[ToolProtocol | MutableMapping[str, Any]] | None,
|
||||
instructions: str | None,
|
||||
middleware: Sequence[Middleware] | None,
|
||||
context_provider: ContextProvider | None,
|
||||
default_options: TOptions_co | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "ChatAgent[TOptions_co]":
|
||||
"""Create a ChatAgent from an Assistant.
|
||||
|
||||
Args:
|
||||
assistant: The OpenAI Assistant object.
|
||||
tools: Tools for the agent.
|
||||
instructions: Instructions override.
|
||||
middleware: Middleware for the agent.
|
||||
context_provider: Context provider for the agent.
|
||||
default_options: Default chat options for the agent (may include response_format).
|
||||
**kwargs: Additional arguments passed to ChatAgent.
|
||||
|
||||
Returns:
|
||||
A configured ChatAgent instance.
|
||||
"""
|
||||
# Create the chat client with the assistant
|
||||
chat_client = OpenAIAssistantsClient(
|
||||
model_id=assistant.model,
|
||||
assistant_id=assistant.id,
|
||||
assistant_name=assistant.name,
|
||||
assistant_description=assistant.description,
|
||||
async_client=self._client,
|
||||
)
|
||||
|
||||
# Use instructions from assistant if not overridden
|
||||
final_instructions = instructions if instructions is not None else assistant.instructions
|
||||
|
||||
# Create and return ChatAgent
|
||||
return ChatAgent(
|
||||
chat_client=chat_client,
|
||||
id=assistant.id,
|
||||
name=assistant.name,
|
||||
description=assistant.description,
|
||||
instructions=final_instructions,
|
||||
tools=tools if tools else None,
|
||||
middleware=middleware,
|
||||
context_provider=context_provider,
|
||||
default_options=default_options, # type: ignore[arg-type]
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
|
||||
from copy import copy
|
||||
from typing import Any, ClassVar, Union
|
||||
|
||||
@@ -24,6 +24,7 @@ from .._logging import get_logger
|
||||
from .._pydantic import AFBaseSettings
|
||||
from .._serialization import SerializationMixin
|
||||
from .._telemetry import APP_INFO, USER_AGENT_KEY, prepend_agent_framework_to_user_agent
|
||||
from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, ToolProtocol
|
||||
from ..exceptions import ServiceInitializationError
|
||||
|
||||
logger: logging.Logger = get_logger("agent_framework.openai")
|
||||
@@ -275,3 +276,74 @@ class OpenAIConfigMixin(OpenAIBase):
|
||||
# Ensure additional_properties and middleware are passed through kwargs to BaseChatClient
|
||||
# These are consumed by BaseChatClient.__init__ via kwargs
|
||||
super().__init__(**args, **kwargs)
|
||||
|
||||
|
||||
def to_assistant_tools(
|
||||
tools: Sequence[ToolProtocol | MutableMapping[str, Any]] | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert Agent Framework tools to OpenAI Assistants API format.
|
||||
|
||||
Args:
|
||||
tools: Normalized tools (from ChatOptions.tools).
|
||||
|
||||
Returns:
|
||||
List of tool definitions for OpenAI Assistants API.
|
||||
"""
|
||||
if not tools:
|
||||
return []
|
||||
|
||||
tool_definitions: list[dict[str, Any]] = []
|
||||
|
||||
for tool in tools:
|
||||
if isinstance(tool, AIFunction):
|
||||
tool_definitions.append(tool.to_json_schema_spec())
|
||||
elif isinstance(tool, HostedCodeInterpreterTool):
|
||||
tool_definitions.append({"type": "code_interpreter"})
|
||||
elif isinstance(tool, HostedFileSearchTool):
|
||||
params: dict[str, Any] = {"type": "file_search"}
|
||||
if tool.max_results is not None:
|
||||
params["file_search"] = {"max_num_results": tool.max_results}
|
||||
tool_definitions.append(params)
|
||||
elif isinstance(tool, MutableMapping):
|
||||
# Pass through raw dict definitions
|
||||
tool_definitions.append(dict(tool))
|
||||
|
||||
return tool_definitions
|
||||
|
||||
|
||||
def from_assistant_tools(
|
||||
assistant_tools: list[Any] | None,
|
||||
) -> list[ToolProtocol]:
|
||||
"""Convert OpenAI Assistant tools to Agent Framework format.
|
||||
|
||||
This converts hosted tools (code_interpreter, file_search) from an OpenAI
|
||||
Assistant definition back to Agent Framework tool instances.
|
||||
|
||||
Note: Function tools are skipped - user must provide implementations separately.
|
||||
|
||||
Args:
|
||||
assistant_tools: Tools from OpenAI Assistant object (assistant.tools).
|
||||
|
||||
Returns:
|
||||
List of Agent Framework tool instances for hosted tools.
|
||||
"""
|
||||
if not assistant_tools:
|
||||
return []
|
||||
|
||||
tools: list[ToolProtocol] = []
|
||||
|
||||
for tool in assistant_tools:
|
||||
if hasattr(tool, "type"):
|
||||
tool_type = tool.type
|
||||
elif isinstance(tool, dict):
|
||||
tool_type = tool.get("type")
|
||||
else:
|
||||
tool_type = None
|
||||
|
||||
if tool_type == "code_interpreter":
|
||||
tools.append(HostedCodeInterpreterTool())
|
||||
elif tool_type == "file_search":
|
||||
tools.append(HostedFileSearchTool())
|
||||
# Skip function tools - user must provide implementations
|
||||
|
||||
return tools
|
||||
|
||||
@@ -0,0 +1,814 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import os
|
||||
from typing import Annotated, Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from openai.types.beta.assistant import Assistant
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agent_framework import ChatAgent, HostedCodeInterpreterTool, HostedFileSearchTool, ai_function, normalize_tools
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.openai import OpenAIAssistantProvider
|
||||
from agent_framework.openai._shared import from_assistant_tools, to_assistant_tools
|
||||
|
||||
# region Test Helpers
|
||||
|
||||
|
||||
def create_mock_assistant(
|
||||
assistant_id: str = "asst_test123",
|
||||
name: str = "TestAssistant",
|
||||
model: str = "gpt-4",
|
||||
instructions: str | None = "You are a helpful assistant.",
|
||||
description: str | None = None,
|
||||
tools: list[Any] | None = None,
|
||||
) -> Assistant:
|
||||
"""Create a mock Assistant object."""
|
||||
mock = MagicMock(spec=Assistant)
|
||||
mock.id = assistant_id
|
||||
mock.name = name
|
||||
mock.model = model
|
||||
mock.instructions = instructions
|
||||
mock.description = description
|
||||
mock.tools = tools or []
|
||||
return mock
|
||||
|
||||
|
||||
def create_function_tool(name: str, description: str = "A test function") -> MagicMock:
|
||||
"""Create a mock FunctionTool."""
|
||||
mock = MagicMock()
|
||||
mock.type = "function"
|
||||
mock.function = MagicMock()
|
||||
mock.function.name = name
|
||||
mock.function.description = description
|
||||
return mock
|
||||
|
||||
|
||||
def create_code_interpreter_tool() -> MagicMock:
|
||||
"""Create a mock CodeInterpreterTool."""
|
||||
mock = MagicMock()
|
||||
mock.type = "code_interpreter"
|
||||
return mock
|
||||
|
||||
|
||||
def create_file_search_tool() -> MagicMock:
|
||||
"""Create a mock FileSearchTool."""
|
||||
mock = MagicMock()
|
||||
mock.type = "file_search"
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_openai() -> MagicMock:
|
||||
"""Mock AsyncOpenAI client."""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock beta.assistants
|
||||
mock_client.beta.assistants.create = AsyncMock(
|
||||
return_value=create_mock_assistant(assistant_id="asst_created123", name="CreatedAssistant")
|
||||
)
|
||||
mock_client.beta.assistants.retrieve = AsyncMock(
|
||||
return_value=create_mock_assistant(assistant_id="asst_retrieved123", name="RetrievedAssistant")
|
||||
)
|
||||
mock_client.beta.assistants.delete = AsyncMock()
|
||||
|
||||
# Mock close method
|
||||
mock_client.close = AsyncMock()
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
# Test function for tool validation
|
||||
def get_weather(location: Annotated[str, Field(description="The location")]) -> str:
|
||||
"""Get the weather for a location."""
|
||||
return f"Weather in {location}: sunny"
|
||||
|
||||
|
||||
def search_database(query: Annotated[str, Field(description="Search query")]) -> str:
|
||||
"""Search the database."""
|
||||
return f"Results for: {query}"
|
||||
|
||||
|
||||
# Pydantic model for structured output tests
|
||||
class WeatherResponse(BaseModel):
|
||||
location: str
|
||||
temperature: float
|
||||
conditions: str
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Initialization Tests
|
||||
|
||||
|
||||
class TestOpenAIAssistantProviderInit:
|
||||
"""Tests for provider initialization."""
|
||||
|
||||
def test_init_with_client(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test initialization with existing AsyncOpenAI client."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
assert provider._client is mock_async_openai # type: ignore[reportPrivateUsage]
|
||||
assert provider._should_close_client is False # type: ignore[reportPrivateUsage]
|
||||
|
||||
def test_init_without_client_creates_one(self, openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test initialization creates client from settings."""
|
||||
provider = OpenAIAssistantProvider()
|
||||
|
||||
assert provider._client is not None # type: ignore[reportPrivateUsage]
|
||||
assert provider._should_close_client is True # type: ignore[reportPrivateUsage]
|
||||
|
||||
def test_init_with_api_key(self) -> None:
|
||||
"""Test initialization with explicit API key."""
|
||||
provider = OpenAIAssistantProvider(api_key="sk-test-key")
|
||||
|
||||
assert provider._client is not None # type: ignore[reportPrivateUsage]
|
||||
assert provider._should_close_client is True # type: ignore[reportPrivateUsage]
|
||||
|
||||
def test_init_fails_without_api_key(self) -> None:
|
||||
"""Test initialization fails without API key when settings return None."""
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock OpenAISettings to return None for api_key
|
||||
with patch("agent_framework.openai._assistant_provider.OpenAISettings") as mock_settings:
|
||||
mock_settings.return_value.api_key = None
|
||||
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
OpenAIAssistantProvider()
|
||||
|
||||
assert "API key is required" in str(exc_info.value)
|
||||
|
||||
def test_init_with_org_id_and_base_url(self) -> None:
|
||||
"""Test initialization with organization ID and base URL."""
|
||||
provider = OpenAIAssistantProvider(
|
||||
api_key="sk-test-key",
|
||||
org_id="org-123",
|
||||
base_url="https://custom.openai.com",
|
||||
)
|
||||
|
||||
assert provider._client is not None # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
class TestOpenAIAssistantProviderContextManager:
|
||||
"""Tests for async context manager."""
|
||||
|
||||
async def test_context_manager_enter_exit(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test async context manager entry and exit."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
async with provider as p:
|
||||
assert p is provider
|
||||
|
||||
async def test_context_manager_closes_owned_client(self, openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test that owned client is closed on exit."""
|
||||
provider = OpenAIAssistantProvider()
|
||||
client = provider._client # type: ignore[reportPrivateUsage]
|
||||
assert client is not None
|
||||
client.close = AsyncMock()
|
||||
|
||||
async with provider:
|
||||
pass
|
||||
|
||||
client.close.assert_called_once()
|
||||
|
||||
async def test_context_manager_does_not_close_external_client(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that external client is not closed on exit."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
async with provider:
|
||||
pass
|
||||
|
||||
mock_async_openai.close.assert_not_called()
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region create_agent Tests
|
||||
|
||||
|
||||
class TestOpenAIAssistantProviderCreateAgent:
|
||||
"""Tests for create_agent method."""
|
||||
|
||||
async def test_create_agent_basic(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test basic assistant creation."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
agent = await provider.create_agent(
|
||||
name="TestAgent",
|
||||
model="gpt-4",
|
||||
instructions="You are helpful.",
|
||||
)
|
||||
|
||||
assert isinstance(agent, ChatAgent)
|
||||
assert agent.name == "CreatedAssistant"
|
||||
mock_async_openai.beta.assistants.create.assert_called_once()
|
||||
|
||||
# Verify create was called with correct parameters
|
||||
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
|
||||
assert call_kwargs["name"] == "TestAgent"
|
||||
assert call_kwargs["model"] == "gpt-4"
|
||||
assert call_kwargs["instructions"] == "You are helpful."
|
||||
|
||||
async def test_create_agent_with_description(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test assistant creation with description."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
await provider.create_agent(
|
||||
name="TestAgent",
|
||||
model="gpt-4",
|
||||
description="A test agent description",
|
||||
)
|
||||
|
||||
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
|
||||
assert call_kwargs["description"] == "A test agent description"
|
||||
|
||||
async def test_create_agent_with_function_tools(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test assistant creation with function tools."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
agent = await provider.create_agent(
|
||||
name="WeatherAgent",
|
||||
model="gpt-4",
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
assert isinstance(agent, ChatAgent)
|
||||
|
||||
# Verify tools were passed to create
|
||||
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
|
||||
assert "tools" in call_kwargs
|
||||
assert len(call_kwargs["tools"]) == 1
|
||||
assert call_kwargs["tools"][0]["type"] == "function"
|
||||
assert call_kwargs["tools"][0]["function"]["name"] == "get_weather"
|
||||
|
||||
async def test_create_agent_with_ai_function(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test assistant creation with AIFunction."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
@ai_function
|
||||
def my_function(x: int) -> int:
|
||||
"""Double a number."""
|
||||
return x * 2
|
||||
|
||||
await provider.create_agent(
|
||||
name="TestAgent",
|
||||
model="gpt-4",
|
||||
tools=[my_function],
|
||||
)
|
||||
|
||||
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
|
||||
assert call_kwargs["tools"][0]["function"]["name"] == "my_function"
|
||||
|
||||
async def test_create_agent_with_code_interpreter(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test assistant creation with code interpreter."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
await provider.create_agent(
|
||||
name="CodeAgent",
|
||||
model="gpt-4",
|
||||
tools=[HostedCodeInterpreterTool()],
|
||||
)
|
||||
|
||||
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
|
||||
assert {"type": "code_interpreter"} in call_kwargs["tools"]
|
||||
|
||||
async def test_create_agent_with_file_search(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test assistant creation with file search."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
await provider.create_agent(
|
||||
name="SearchAgent",
|
||||
model="gpt-4",
|
||||
tools=[HostedFileSearchTool()],
|
||||
)
|
||||
|
||||
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
|
||||
assert any(t["type"] == "file_search" for t in call_kwargs["tools"])
|
||||
|
||||
async def test_create_agent_with_file_search_max_results(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test assistant creation with file search and max_results."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
await provider.create_agent(
|
||||
name="SearchAgent",
|
||||
model="gpt-4",
|
||||
tools=[HostedFileSearchTool(max_results=10)],
|
||||
)
|
||||
|
||||
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
|
||||
file_search_tool = next(t for t in call_kwargs["tools"] if t["type"] == "file_search")
|
||||
assert file_search_tool.get("file_search", {}).get("max_num_results") == 10
|
||||
|
||||
async def test_create_agent_with_mixed_tools(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test assistant creation with multiple tool types."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
await provider.create_agent(
|
||||
name="MultiToolAgent",
|
||||
model="gpt-4",
|
||||
tools=[get_weather, HostedCodeInterpreterTool(), HostedFileSearchTool()],
|
||||
)
|
||||
|
||||
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
|
||||
assert len(call_kwargs["tools"]) == 3
|
||||
|
||||
async def test_create_agent_with_metadata(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test assistant creation with metadata."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
await provider.create_agent(
|
||||
name="TestAgent",
|
||||
model="gpt-4",
|
||||
metadata={"env": "test", "version": "1.0"},
|
||||
)
|
||||
|
||||
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
|
||||
assert call_kwargs["metadata"] == {"env": "test", "version": "1.0"}
|
||||
|
||||
async def test_create_agent_with_response_format_pydantic(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test assistant creation with Pydantic response format via default_options."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
await provider.create_agent(
|
||||
name="StructuredAgent",
|
||||
model="gpt-4",
|
||||
default_options={"response_format": WeatherResponse},
|
||||
)
|
||||
|
||||
call_kwargs = mock_async_openai.beta.assistants.create.call_args.kwargs
|
||||
assert call_kwargs["response_format"]["type"] == "json_schema"
|
||||
assert call_kwargs["response_format"]["json_schema"]["name"] == "WeatherResponse"
|
||||
|
||||
async def test_create_agent_returns_chat_agent(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that create_agent returns a ChatAgent instance."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
agent = await provider.create_agent(
|
||||
name="TestAgent",
|
||||
model="gpt-4",
|
||||
)
|
||||
|
||||
assert isinstance(agent, ChatAgent)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region get_agent Tests
|
||||
|
||||
|
||||
class TestOpenAIAssistantProviderGetAgent:
|
||||
"""Tests for get_agent method."""
|
||||
|
||||
async def test_get_agent_basic(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test retrieving an existing assistant."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
agent = await provider.get_agent(assistant_id="asst_123")
|
||||
|
||||
assert isinstance(agent, ChatAgent)
|
||||
mock_async_openai.beta.assistants.retrieve.assert_called_once_with("asst_123")
|
||||
|
||||
async def test_get_agent_with_instructions_override(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test retrieving assistant with instruction override."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
agent = await provider.get_agent(
|
||||
assistant_id="asst_123",
|
||||
instructions="Custom instructions",
|
||||
)
|
||||
|
||||
# Agent should be created successfully with the custom instructions
|
||||
assert isinstance(agent, ChatAgent)
|
||||
assert agent.id == "asst_retrieved123"
|
||||
|
||||
async def test_get_agent_with_function_tools(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test retrieving assistant with function tools provided."""
|
||||
# Setup assistant with function tool
|
||||
assistant = create_mock_assistant(tools=[create_function_tool("get_weather")])
|
||||
mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant)
|
||||
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
agent = await provider.get_agent(
|
||||
assistant_id="asst_123",
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
assert isinstance(agent, ChatAgent)
|
||||
|
||||
async def test_get_agent_validates_missing_function_tools(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that missing function tools raise ValueError."""
|
||||
# Setup assistant with function tool
|
||||
assistant = create_mock_assistant(tools=[create_function_tool("get_weather")])
|
||||
mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant)
|
||||
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await provider.get_agent(assistant_id="asst_123")
|
||||
|
||||
assert "get_weather" in str(exc_info.value)
|
||||
assert "no implementation was provided" in str(exc_info.value)
|
||||
|
||||
async def test_get_agent_validates_multiple_missing_function_tools(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test validation with multiple missing function tools."""
|
||||
assistant = create_mock_assistant(
|
||||
tools=[create_function_tool("get_weather"), create_function_tool("search_database")]
|
||||
)
|
||||
mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant)
|
||||
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await provider.get_agent(assistant_id="asst_123")
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "get_weather" in error_msg or "search_database" in error_msg
|
||||
|
||||
async def test_get_agent_merges_hosted_tools(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that hosted tools are automatically included."""
|
||||
assistant = create_mock_assistant(tools=[create_code_interpreter_tool(), create_file_search_tool()])
|
||||
mock_async_openai.beta.assistants.retrieve = AsyncMock(return_value=assistant)
|
||||
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
|
||||
agent = await provider.get_agent(assistant_id="asst_123")
|
||||
|
||||
# Hosted tools should be merged automatically
|
||||
assert isinstance(agent, ChatAgent)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region as_agent Tests
|
||||
|
||||
|
||||
class TestOpenAIAssistantProviderAsAgent:
|
||||
"""Tests for as_agent method."""
|
||||
|
||||
def test_as_agent_no_http_call(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that as_agent doesn't make HTTP calls."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant = create_mock_assistant()
|
||||
|
||||
agent = provider.as_agent(assistant)
|
||||
|
||||
assert isinstance(agent, ChatAgent)
|
||||
# Verify no HTTP calls were made
|
||||
mock_async_openai.beta.assistants.create.assert_not_called()
|
||||
mock_async_openai.beta.assistants.retrieve.assert_not_called()
|
||||
|
||||
def test_as_agent_wraps_assistant(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test wrapping an SDK Assistant object."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant = create_mock_assistant(
|
||||
assistant_id="asst_wrap123",
|
||||
name="WrappedAssistant",
|
||||
instructions="Original instructions",
|
||||
)
|
||||
|
||||
agent = provider.as_agent(assistant)
|
||||
|
||||
assert agent.id == "asst_wrap123"
|
||||
assert agent.name == "WrappedAssistant"
|
||||
# Instructions are passed to ChatOptions, not exposed as attribute
|
||||
assert isinstance(agent, ChatAgent)
|
||||
|
||||
def test_as_agent_with_instructions_override(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test as_agent with instruction override."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant = create_mock_assistant(instructions="Original")
|
||||
|
||||
agent = provider.as_agent(assistant, instructions="Override")
|
||||
|
||||
# Agent should be created successfully with override instructions
|
||||
assert isinstance(agent, ChatAgent)
|
||||
|
||||
def test_as_agent_validates_function_tools(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that missing function tools raise ValueError."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant = create_mock_assistant(tools=[create_function_tool("get_weather")])
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
provider.as_agent(assistant)
|
||||
|
||||
assert "get_weather" in str(exc_info.value)
|
||||
|
||||
def test_as_agent_with_function_tools_provided(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test as_agent with function tools provided."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant = create_mock_assistant(tools=[create_function_tool("get_weather")])
|
||||
|
||||
agent = provider.as_agent(assistant, tools=[get_weather])
|
||||
|
||||
assert isinstance(agent, ChatAgent)
|
||||
|
||||
def test_as_agent_merges_hosted_tools(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that hosted tools are merged automatically."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant = create_mock_assistant(tools=[create_code_interpreter_tool()])
|
||||
|
||||
agent = provider.as_agent(assistant)
|
||||
|
||||
assert isinstance(agent, ChatAgent)
|
||||
|
||||
def test_as_agent_hosted_tools_not_required(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that hosted tools don't require user implementations."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant = create_mock_assistant(tools=[create_code_interpreter_tool(), create_file_search_tool()])
|
||||
|
||||
# Should not raise - hosted tools don't need implementations
|
||||
agent = provider.as_agent(assistant)
|
||||
|
||||
assert isinstance(agent, ChatAgent)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Tool Conversion Tests
|
||||
|
||||
|
||||
class TestToolConversion:
|
||||
"""Tests for tool conversion utilities (shared functions)."""
|
||||
|
||||
def test_to_assistant_tools_ai_function(self) -> None:
|
||||
"""Test AIFunction conversion to API format."""
|
||||
|
||||
@ai_function
|
||||
def test_func(x: int) -> int:
|
||||
"""Test function."""
|
||||
return x
|
||||
|
||||
# Normalize tools first, then convert
|
||||
normalized = normalize_tools([test_func])
|
||||
api_tools = to_assistant_tools(normalized)
|
||||
|
||||
assert len(api_tools) == 1
|
||||
assert api_tools[0]["type"] == "function"
|
||||
assert api_tools[0]["function"]["name"] == "test_func"
|
||||
|
||||
def test_to_assistant_tools_callable(self) -> None:
|
||||
"""Test raw callable conversion via normalize_tools."""
|
||||
# normalize_tools converts callables to AIFunction
|
||||
normalized = normalize_tools([get_weather])
|
||||
api_tools = to_assistant_tools(normalized)
|
||||
|
||||
assert len(api_tools) == 1
|
||||
assert api_tools[0]["type"] == "function"
|
||||
assert api_tools[0]["function"]["name"] == "get_weather"
|
||||
|
||||
def test_to_assistant_tools_code_interpreter(self) -> None:
|
||||
"""Test HostedCodeInterpreterTool conversion."""
|
||||
api_tools = to_assistant_tools([HostedCodeInterpreterTool()])
|
||||
|
||||
assert len(api_tools) == 1
|
||||
assert api_tools[0] == {"type": "code_interpreter"}
|
||||
|
||||
def test_to_assistant_tools_file_search(self) -> None:
|
||||
"""Test HostedFileSearchTool conversion."""
|
||||
api_tools = to_assistant_tools([HostedFileSearchTool()])
|
||||
|
||||
assert len(api_tools) == 1
|
||||
assert api_tools[0]["type"] == "file_search"
|
||||
|
||||
def test_to_assistant_tools_file_search_with_max_results(self) -> None:
|
||||
"""Test HostedFileSearchTool with max_results conversion."""
|
||||
api_tools = to_assistant_tools([HostedFileSearchTool(max_results=5)])
|
||||
|
||||
assert api_tools[0]["file_search"]["max_num_results"] == 5
|
||||
|
||||
def test_to_assistant_tools_dict(self) -> None:
|
||||
"""Test raw dict tool passthrough."""
|
||||
raw_tool = {"type": "function", "function": {"name": "custom", "description": "Custom tool"}}
|
||||
|
||||
api_tools = to_assistant_tools([raw_tool])
|
||||
|
||||
assert len(api_tools) == 1
|
||||
assert api_tools[0] == raw_tool
|
||||
|
||||
def test_to_assistant_tools_empty(self) -> None:
|
||||
"""Test conversion with no tools."""
|
||||
api_tools = to_assistant_tools(None)
|
||||
|
||||
assert api_tools == []
|
||||
|
||||
def test_from_assistant_tools_code_interpreter(self) -> None:
|
||||
"""Test converting code_interpreter tool from OpenAI format."""
|
||||
assistant_tools = [create_code_interpreter_tool()]
|
||||
|
||||
tools = from_assistant_tools(assistant_tools)
|
||||
|
||||
assert len(tools) == 1
|
||||
assert isinstance(tools[0], HostedCodeInterpreterTool)
|
||||
|
||||
def test_from_assistant_tools_file_search(self) -> None:
|
||||
"""Test converting file_search tool from OpenAI format."""
|
||||
assistant_tools = [create_file_search_tool()]
|
||||
|
||||
tools = from_assistant_tools(assistant_tools)
|
||||
|
||||
assert len(tools) == 1
|
||||
assert isinstance(tools[0], HostedFileSearchTool)
|
||||
|
||||
def test_from_assistant_tools_function_skipped(self) -> None:
|
||||
"""Test that function tools are skipped (no implementations)."""
|
||||
assistant_tools = [create_function_tool("test_func")]
|
||||
|
||||
tools = from_assistant_tools(assistant_tools)
|
||||
|
||||
assert len(tools) == 0 # Function tools are skipped
|
||||
|
||||
def test_from_assistant_tools_empty(self) -> None:
|
||||
"""Test conversion with no tools."""
|
||||
tools = from_assistant_tools(None)
|
||||
|
||||
assert tools == []
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Tool Validation Tests
|
||||
|
||||
|
||||
class TestToolValidation:
|
||||
"""Tests for tool validation."""
|
||||
|
||||
def test_validate_missing_function_tool_raises(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that missing function tools raise ValueError."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant_tools = [create_function_tool("my_function")]
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
provider._validate_function_tools(assistant_tools, None) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert "my_function" in str(exc_info.value)
|
||||
|
||||
def test_validate_all_tools_provided_passes(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that validation passes when all tools provided."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant_tools = [create_function_tool("get_weather")]
|
||||
|
||||
# Should not raise
|
||||
provider._validate_function_tools(assistant_tools, [get_weather]) # type: ignore[reportPrivateUsage]
|
||||
|
||||
def test_validate_hosted_tools_not_required(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that hosted tools don't require implementations."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant_tools = [create_code_interpreter_tool(), create_file_search_tool()]
|
||||
|
||||
# Should not raise
|
||||
provider._validate_function_tools(assistant_tools, None) # type: ignore[reportPrivateUsage]
|
||||
|
||||
def test_validate_with_ai_function(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test validation with AIFunction."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant_tools = [create_function_tool("get_weather")]
|
||||
|
||||
wrapped = ai_function(get_weather)
|
||||
|
||||
# Should not raise
|
||||
provider._validate_function_tools(assistant_tools, [wrapped]) # type: ignore[reportPrivateUsage]
|
||||
|
||||
def test_validate_partial_tools_raises(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test that partial tool provision raises error."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant_tools = [
|
||||
create_function_tool("get_weather"),
|
||||
create_function_tool("search_database"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
provider._validate_function_tools(assistant_tools, [get_weather]) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert "search_database" in str(exc_info.value)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Tool Merging Tests
|
||||
|
||||
|
||||
class TestToolMerging:
|
||||
"""Tests for tool merging."""
|
||||
|
||||
def test_merge_code_interpreter(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test merging code interpreter tool."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant_tools = [create_code_interpreter_tool()]
|
||||
|
||||
merged = provider._merge_tools(assistant_tools, None) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert len(merged) == 1
|
||||
assert isinstance(merged[0], HostedCodeInterpreterTool)
|
||||
|
||||
def test_merge_file_search(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test merging file search tool."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant_tools = [create_file_search_tool()]
|
||||
|
||||
merged = provider._merge_tools(assistant_tools, None) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert len(merged) == 1
|
||||
assert isinstance(merged[0], HostedFileSearchTool)
|
||||
|
||||
def test_merge_with_user_tools(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test merging hosted and user tools."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant_tools = [create_code_interpreter_tool()]
|
||||
|
||||
merged = provider._merge_tools(assistant_tools, [get_weather]) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert len(merged) == 2
|
||||
assert isinstance(merged[0], HostedCodeInterpreterTool)
|
||||
|
||||
def test_merge_multiple_hosted_tools(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test merging multiple hosted tools."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant_tools = [create_code_interpreter_tool(), create_file_search_tool()]
|
||||
|
||||
merged = provider._merge_tools(assistant_tools, None) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert len(merged) == 2
|
||||
|
||||
def test_merge_single_user_tool(self, mock_async_openai: MagicMock) -> None:
|
||||
"""Test merging with single user tool (not list)."""
|
||||
provider = OpenAIAssistantProvider(mock_async_openai)
|
||||
assistant_tools: list[Any] = []
|
||||
|
||||
merged = provider._merge_tools(assistant_tools, get_weather) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert len(merged) == 1
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Integration Tests
|
||||
|
||||
|
||||
skip_if_openai_integration_tests_disabled = pytest.mark.skipif(
|
||||
os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true"
|
||||
or os.getenv("OPENAI_API_KEY", "") in ("", "test-dummy-key"),
|
||||
reason="No real OPENAI_API_KEY provided; skipping integration tests."
|
||||
if os.getenv("RUN_INTEGRATION_TESTS", "false").lower() == "true"
|
||||
else "Integration tests are disabled.",
|
||||
)
|
||||
|
||||
|
||||
@skip_if_openai_integration_tests_disabled
|
||||
class TestOpenAIAssistantProviderIntegration:
|
||||
"""Integration tests requiring real OpenAI API."""
|
||||
|
||||
async def test_create_and_run_agent(self) -> None:
|
||||
"""End-to-end test of creating and running an agent."""
|
||||
provider = OpenAIAssistantProvider()
|
||||
|
||||
agent = await provider.create_agent(
|
||||
name="IntegrationTestAgent",
|
||||
model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
|
||||
instructions="You are a helpful assistant. Respond briefly.",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await agent.run("Say 'hello' and nothing else.")
|
||||
result_text = str(result)
|
||||
assert "hello" in result_text.lower()
|
||||
finally:
|
||||
# Clean up the assistant
|
||||
await provider._client.beta.assistants.delete(agent.id) # type: ignore[reportPrivateUsage, union-attr]
|
||||
|
||||
async def test_create_agent_with_function_tools_integration(self) -> None:
|
||||
"""Integration test with function tools."""
|
||||
provider = OpenAIAssistantProvider()
|
||||
|
||||
def get_current_time() -> str:
|
||||
"""Get the current time."""
|
||||
from datetime import datetime
|
||||
|
||||
return datetime.now().strftime("%H:%M")
|
||||
|
||||
agent = await provider.create_agent(
|
||||
name="TimeAgent",
|
||||
model=os.environ.get("OPENAI_CHAT_MODEL_ID", "gpt-4"),
|
||||
instructions="You are a helpful assistant.",
|
||||
tools=[get_current_time],
|
||||
)
|
||||
|
||||
try:
|
||||
result = await agent.run("What time is it? Use the get_current_time function.")
|
||||
result_text = str(result)
|
||||
# The response should contain time information
|
||||
assert ":" in result_text or "time" in result_text.lower()
|
||||
finally:
|
||||
await provider._client.beta.assistants.delete(agent.id) # type: ignore[reportPrivateUsage, union-attr]
|
||||
|
||||
|
||||
# endregion
|
||||
Reference in New Issue
Block a user