mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: moved prepare tools into class (#215)
* moved prepare tools into class * moved test * changed tool handling * fix test * second fix
This commit is contained in:
committed by
GitHub
Unverified
parent
84e5ee97b9
commit
27f7af2160
@@ -8,7 +8,7 @@ from typing import Any, ClassVar
|
||||
from agent_framework import (
|
||||
AFBaseSettings,
|
||||
AIContents,
|
||||
AITool,
|
||||
AIFunction,
|
||||
ChatClientBase,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
@@ -25,7 +25,7 @@ from agent_framework import (
|
||||
UsageDetails,
|
||||
use_tool_calling,
|
||||
)
|
||||
from agent_framework._clients import tool_to_json_schema_spec
|
||||
from agent_framework._clients import ai_function_to_json_schema_spec
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.ai.agents.models import (
|
||||
AgentsNamedToolChoice,
|
||||
@@ -455,13 +455,14 @@ class FoundryChatClient(ChatClientBase):
|
||||
run_options["parallel_tool_calls"] = chat_options.allow_multiple_tool_calls
|
||||
|
||||
if chat_options.tools is not None:
|
||||
# TODO (eavanvalkenburg): replace with _prepare_tools_and_tool_choice overload
|
||||
tool_definitions: list[MutableMapping[str, Any]] = []
|
||||
|
||||
for tool in chat_options.tools:
|
||||
if isinstance(tool, AITool):
|
||||
tool_definitions.append(tool_to_json_schema_spec(tool))
|
||||
if isinstance(tool, AIFunction):
|
||||
tool_definitions.append(ai_function_to_json_schema_spec(tool))
|
||||
else:
|
||||
tool_definitions.append(tool)
|
||||
tool_definitions.append(tool) # type: ignore
|
||||
|
||||
if len(tool_definitions) > 0:
|
||||
run_options["tools"] = tool_definitions
|
||||
|
||||
@@ -75,35 +75,18 @@ async def _auto_invoke_function(
|
||||
)
|
||||
|
||||
|
||||
def tool_to_json_schema_spec(tool: AITool) -> dict[str, Any]:
|
||||
"""Convert a AITool to the JSON Schema function specification format."""
|
||||
def ai_function_to_json_schema_spec(function: AIFunction[BaseModel, Any]) -> dict[str, Any]:
|
||||
"""Convert a AIFunction to the JSON Schema function specification format."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters(),
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"parameters": function.parameters(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _prepare_tools_and_tool_choice(chat_options: ChatOptions) -> None:
|
||||
"""Prepare the tools and tool choice for the chat options."""
|
||||
chat_tool_mode: ChatToolMode | None = chat_options.tool_choice # type: ignore
|
||||
if chat_tool_mode is None or chat_tool_mode == ChatToolMode.NONE:
|
||||
chat_options.tools = None
|
||||
chat_options.tool_choice = ChatToolMode.NONE.mode
|
||||
return
|
||||
chat_options.tools = [
|
||||
(tool_to_json_schema_spec(t) if isinstance(t, AITool) else t)
|
||||
for t in chat_options._ai_tools or [] # type: ignore[reportPrivateUsage]
|
||||
]
|
||||
if not chat_options.tools:
|
||||
chat_options.tool_choice = ChatToolMode.NONE.mode
|
||||
else:
|
||||
chat_options.tool_choice = chat_tool_mode.mode
|
||||
|
||||
|
||||
def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse:
|
||||
"""Decorate the internal _inner_get_response method to enable tool calls."""
|
||||
|
||||
@@ -163,7 +146,7 @@ def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse:
|
||||
|
||||
# Failsafe: give up on tools, ask model for plain answer
|
||||
chat_options.tool_choice = "none"
|
||||
_prepare_tools_and_tool_choice(chat_options=chat_options)
|
||||
self._prepare_tools_and_tool_choice(chat_options=chat_options) # type: ignore[reportPrivateUsage]
|
||||
response = await func(self, messages=messages, chat_options=chat_options)
|
||||
if fcc_messages:
|
||||
for msg in reversed(fcc_messages):
|
||||
@@ -231,7 +214,7 @@ def _tool_call_streaming(func: TInnerGetStreamingResponse) -> TInnerGetStreaming
|
||||
|
||||
# Failsafe: give up on tools, ask model for plain answer
|
||||
chat_options.tool_choice = "none"
|
||||
_prepare_tools_and_tool_choice(chat_options=chat_options)
|
||||
self._prepare_tools_and_tool_choice(chat_options=chat_options) # type: ignore[reportPrivateUsage]
|
||||
async for update in func(self, messages=messages, chat_options=chat_options, **kwargs):
|
||||
yield update
|
||||
|
||||
@@ -542,7 +525,7 @@ class ChatClientBase(AFBaseModel, ABC):
|
||||
additional_properties=additional_properties or {},
|
||||
)
|
||||
prepped_messages = self._prepare_messages(messages)
|
||||
_prepare_tools_and_tool_choice(chat_options=chat_options)
|
||||
self._prepare_tools_and_tool_choice(chat_options=chat_options)
|
||||
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)
|
||||
|
||||
async def get_streaming_response(
|
||||
@@ -623,12 +606,32 @@ class ChatClientBase(AFBaseModel, ABC):
|
||||
**kwargs,
|
||||
)
|
||||
prepped_messages = self._prepare_messages(messages)
|
||||
_prepare_tools_and_tool_choice(chat_options=chat_options)
|
||||
self._prepare_tools_and_tool_choice(chat_options=chat_options)
|
||||
async for update in self._inner_get_streaming_response(
|
||||
messages=prepped_messages, chat_options=chat_options, **kwargs
|
||||
):
|
||||
yield update
|
||||
|
||||
def _prepare_tools_and_tool_choice(self, chat_options: ChatOptions) -> None:
|
||||
"""Prepare the tools and tool choice for the chat options.
|
||||
|
||||
This function should be overridden by subclasses to customize tool handling.
|
||||
Because it currently parses only AIFunctions.
|
||||
"""
|
||||
chat_tool_mode: ChatToolMode | None = chat_options.tool_choice # type: ignore
|
||||
if chat_tool_mode is None or chat_tool_mode == ChatToolMode.NONE:
|
||||
chat_options.tools = None
|
||||
chat_options.tool_choice = ChatToolMode.NONE.mode
|
||||
return
|
||||
chat_options.tools = [
|
||||
(ai_function_to_json_schema_spec(t) if isinstance(t, AIFunction) else t) # type: ignore[reportUnknownArgumentType]
|
||||
for t in chat_options._ai_tools or [] # type: ignore[reportPrivateUsage]
|
||||
]
|
||||
if not chat_options.tools:
|
||||
chat_options.tool_choice = ChatToolMode.NONE.mode
|
||||
else:
|
||||
chat_options.tool_choice = chat_tool_mode.mode
|
||||
|
||||
|
||||
# region: Embedding Client
|
||||
|
||||
|
||||
@@ -305,3 +305,48 @@ async def test_base_client_with_streaming_function_calling_disabled(chat_client_
|
||||
updates.append(update)
|
||||
assert len(updates) == 1
|
||||
assert exec_counter == 0
|
||||
|
||||
|
||||
def test_chat_options_parsing_tools(chat_client_base, ai_function_tool) -> None:
|
||||
"""Test that chat options can parse tools correctly."""
|
||||
|
||||
def echo() -> str:
|
||||
"""Echo the input."""
|
||||
return "Echo"
|
||||
|
||||
dict_function = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Retrieves current weather for the given location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "City and country e.g. Bogotá, Colombia"},
|
||||
"units": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "Units the temperature will be returned in.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "units"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
|
||||
options = ChatOptions(tools=[ai_function_tool, echo, dict_function], tool_choice="auto")
|
||||
assert len(options.tools) == 3
|
||||
assert options.tools[0] == ai_function_tool
|
||||
assert options.tools[1] != echo
|
||||
assert options.tools[2] == dict_function
|
||||
# after prepare, the tools should be represented as dicts
|
||||
# while ai_tools is still the same.
|
||||
chat_client_base._prepare_tools_and_tool_choice(chat_options=options)
|
||||
assert options._ai_tools[0] == ai_function_tool
|
||||
assert options._ai_tools[2] == dict_function
|
||||
assert len(options.tools) == 3
|
||||
assert options.tools[0]["function"]["name"] == "simple_function"
|
||||
assert options.tools[1]["function"]["name"] == "echo"
|
||||
assert options.tools[2]["function"]["name"] == "get_weather"
|
||||
|
||||
@@ -539,54 +539,6 @@ def test_chat_options_and(ai_function_tool, ai_tool) -> None:
|
||||
assert options3.tools == [ai_function_tool, ai_tool]
|
||||
|
||||
|
||||
def test_chat_options_parsing_tools(ai_function_tool, ai_tool) -> None:
|
||||
from agent_framework._clients import _prepare_tools_and_tool_choice
|
||||
|
||||
def echo() -> str:
|
||||
"""Echo the input."""
|
||||
return "Echo"
|
||||
|
||||
dict_function = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Retrieves current weather for the given location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "City and country e.g. Bogotá, Colombia"},
|
||||
"units": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "Units the temperature will be returned in.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "units"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
|
||||
options = ChatOptions(tools=[ai_function_tool, ai_tool, echo, dict_function], tool_choice="auto")
|
||||
assert len(options.tools) == 4
|
||||
assert options.tools[0] == ai_function_tool
|
||||
assert options.tools[1] == ai_tool
|
||||
assert options.tools[2] != echo
|
||||
assert options.tools[3] == dict_function
|
||||
# after prepare, the tools should be represented as dicts
|
||||
# while ai_tools is still the same.
|
||||
_prepare_tools_and_tool_choice(options)
|
||||
assert options._ai_tools[0] == ai_function_tool
|
||||
assert options._ai_tools[1] == ai_tool
|
||||
assert options._ai_tools[3] == dict_function
|
||||
assert len(options.tools) == 4
|
||||
assert options.tools[0]["function"]["name"] == "simple_function"
|
||||
assert options.tools[1]["function"]["name"] == "generic_tool"
|
||||
assert options.tools[2]["function"]["name"] == "echo"
|
||||
assert options.tools[3]["function"]["name"] == "get_weather"
|
||||
|
||||
|
||||
# region Agent Response Fixtures
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user