mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
91c5414836
* added poe setup and docs * smaller bandit exclude * updated poe * updated naming * added something in samples * exclude docs from bandit * updated readme * removed ds_store * updated readme
133 lines
4.9 KiB
Python
133 lines
4.9 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
import functools
|
|
import inspect
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import Any, Generic, Protocol, TypeVar, runtime_checkable
|
|
|
|
from pydantic import BaseModel, create_model
|
|
|
|
|
|
@runtime_checkable
|
|
class AITool(Protocol):
|
|
"""Represents a tool that can be specified to an AI service."""
|
|
|
|
name: str
|
|
"""The name of the tool."""
|
|
description: str | None = None
|
|
"""A description of the tool, suitable for use in describing the purpose to a model."""
|
|
additional_properties: dict[str, Any] | None = None
|
|
"""Additional properties associated with the tool."""
|
|
|
|
def __str__(self) -> str:
|
|
"""Return a string representation of the tool."""
|
|
...
|
|
|
|
|
|
# Type parameter for the value produced by a wrapped function.
ReturnT = TypeVar("ReturnT")
# Type parameter for the Pydantic model describing a wrapped function's inputs.
ArgsT = TypeVar("ArgsT", bound=BaseModel)
|
|
|
|
|
|
class AIFunction(Generic[ArgsT, ReturnT]):
    """A tool that represents a function that can be called by an AI service.

    Bundles a callable (sync, or returning an awaitable) with a name, a description and a
    Pydantic model describing its inputs.
    """

    def __init__(
        self,
        func: Callable[..., Awaitable[ReturnT] | ReturnT],
        name: str,
        description: str,
        input_model: type[ArgsT],
        **kwargs: Any,
    ) -> None:
        """Initialize an AIFunction.

        Args:
            func: The function to wrap. It may return a plain value or an awaitable.
            name: The name of the tool.
            description: A description of the tool.
            input_model: A Pydantic model that defines the input parameters for the function.
            **kwargs: Additional properties to set on the tool; they are stored as a dict in
                ``additional_properties``.
        """
        self.name = name
        self.description = description
        self.input_model = input_model
        self.additional_properties: dict[str, Any] | None = kwargs
        self._func = func

    def model_json_schema(self) -> dict[str, Any]:
        """Return the JSON schema of the input model."""
        return self.input_model.model_json_schema()

    def __call__(self, *args: Any, **kwargs: Any) -> ReturnT | Awaitable[ReturnT]:
        """Call the wrapped function with the provided arguments.

        The result is returned unchanged, so for an async function this is an awaitable.
        """
        return self._func(*args, **kwargs)

    async def invoke(
        self,
        *,
        arguments: ArgsT | None = None,
        **kwargs: Any,
    ) -> ReturnT:
        """Run the AI function with the provided arguments as a Pydantic model.

        Args:
            arguments: An instance of ``input_model`` containing the arguments for the function.
            kwargs: Keyword arguments to pass to the function; ignored when ``arguments`` is provided.

        Returns:
            The function's result, awaited first if the function returned an awaitable.

        Raises:
            TypeError: If ``arguments`` is not an instance of ``input_model``.
        """
        if arguments is not None:
            if not isinstance(arguments, self.input_model):
                raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}")
            kwargs = self._arguments_to_kwargs(arguments)
        res = self.__call__(**kwargs)
        if inspect.isawaitable(res):
            return await res
        return res

    @staticmethod
    def _arguments_to_kwargs(arguments: BaseModel) -> dict[str, Any]:
        """Convert a validated argument model into keyword arguments for the wrapped function.

        Values are read as attributes instead of via ``model_dump`` so that nested Pydantic
        models (and other rich values) reach the function with the types its signature
        declares, rather than being converted to plain dicts.

        A ``None`` value is dropped for optional fields so the function's own default applies.
        It is kept for required fields, where dropping it would make the call fail with a
        missing-argument error.
        """
        kwargs: dict[str, Any] = {}
        for field_name, field_info in type(arguments).model_fields.items():
            value = getattr(arguments, field_name)
            if value is None and not field_info.is_required():
                continue
            kwargs[field_name] = value
        # Extra (undeclared) values are only present when the input model allows extras;
        # forward them as well, skipping None like the optional fields above.
        for extra_name, extra_value in (arguments.model_extra or {}).items():
            if extra_value is not None:
                kwargs[extra_name] = extra_value
        return kwargs
|
|
|
|
|
|
def ai_function(
    func: Callable[..., ReturnT | Awaitable[ReturnT]] | None = None,
    *,
    name: str | None = None,
    description: str | None = None,
    additional_properties: dict[str, Any] | None = None,
) -> AIFunction[Any, ReturnT] | Callable[[Callable[..., ReturnT | Awaitable[ReturnT]]], AIFunction[Any, ReturnT]]:
    """Decorate a function to turn it into an AIFunction that can be passed to models.

    Usable bare (``@ai_function``) or with options (``@ai_function(name=...)``).
    An input model is generated from the function signature. Each named parameter
    (except ``self``/``cls`` and ``*args``/``**kwargs``) becomes a field typed by its
    annotation, or ``str`` when unannotated. A field is required unless the parameter
    has a default.

    Args:
        func: The function to wrap. If None, returns a decorator.
        name: The name of the tool. Defaults to the function's name.
        description: A description of the tool. Defaults to the function's docstring.
        additional_properties: Additional properties to set on the tool.

    Returns:
        An AIFunction wrapping ``func`` when it is given, otherwise a decorator producing one.
    """
    # Variadic parameters cannot receive a value passed as a single keyword argument of the
    # same name (which is how AIFunction.invoke calls the function), so they are not modelled.
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    def wrapper(f: Callable[..., ReturnT | Awaitable[ReturnT]]) -> AIFunction[Any, ReturnT]:
        tool_name: str = name or getattr(f, "__name__", "unknown_function")  # type: ignore[assignment]
        tool_desc: str = description or (f.__doc__ or "")
        sig = inspect.signature(f)
        fields = {
            pname: (
                param.annotation if param.annotation is not inspect.Parameter.empty else str,
                param.default if param.default is not inspect.Parameter.empty else ...,
            )
            for pname, param in sig.parameters.items()
            if pname not in {"self", "cls"} and param.kind not in variadic_kinds
        }
        input_model: Any = create_model(f"{tool_name}_input", **fields)  # type: ignore[call-overload]
        if not issubclass(input_model, BaseModel):
            raise TypeError(f"Input model for {tool_name} must be a subclass of BaseModel, got {input_model}")

        # Copy __name__, __doc__, __module__, etc. from f onto the AIFunction instance.
        return functools.update_wrapper(  # type: ignore[return-value]
            AIFunction[Any, ReturnT](
                func=f,
                name=tool_name,
                description=tool_desc,
                input_model=input_model,
                **(additional_properties if additional_properties is not None else {}),
            ),
            f,
        )

    # Compare with None explicitly: a callable object may define __bool__/__len__ as falsy.
    return wrapper(func) if func is not None else wrapper
|