Python: Support OpenAI and Gemini allowed_tools tool choice (#5322)

* Support OpenAI allowed_tools in ToolMode (#5309)

Add allowed_tools field to ToolMode TypedDict, enabling users to restrict
which tools the model may call via the OpenAI allowed_tools tool_choice
type. This preserves prompt caching by keeping all tools in the tools list
while limiting which ones the model can invoke.

- Add allowed_tools: list[str] to ToolMode TypedDict
- Add validation in validate_tool_mode() (only valid when mode == "auto")
- Convert to OpenAI API format in _prepare_options()
- Add tests for validation and API payload generation

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Python: Support OpenAI `allowed_tools` tool choice in Python SDK

Fixes #5309

* Fix #5309: Validate allowed_tools shape and add Chat Completions client support

- validate_tool_mode now checks allowed_tools is a non-string sequence of
  strings and normalizes to list[str], raising ContentError for invalid types
- Add missing allowed_tools branch in _chat_completion_client._prepare_options
  so allowed_tools is emitted as the OpenAI allowed_tools wire format instead
  of being silently dropped
- Add tests for invalid allowed_tools types (string, int, mixed), empty list,
  tuple normalization, and Chat Completions client payload generation

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: support allowed_tools with mode 'required' in addition to 'auto'

OpenAI's allowed_tools tool_choice type supports both mode 'auto' and
'required'. Update validation, client conversion, and tests to allow
both modes instead of restricting to 'auto' only.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: use Gemini VALIDATED mode for allowed_tools, warn in unsupported providers

- Use FunctionCallingConfigMode.VALIDATED instead of ANY when allowed_tools
  is set with auto mode in Gemini, preserving optional tool-call semantics.
- Handle allowed_tools in required mode with required_function_name precedence.
- Fix allowed_names guard to use identity check (is not None) so empty lists
  are preserved.
- Bump google-genai minimum to >=1.32.0 (VALIDATED added in that version).
- Add warnings in Anthropic and Bedrock when allowed_tools is set but not
  supported.
- Add Gemini unit tests for allowed_tools with auto, required, empty list,
  and required_function_name precedence scenarios.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: Chat Completions API does not support allowed_tools, add integration tests

- Chat Completions API (_chat_completion_client.py) now warns and falls
  back to plain mode when allowed_tools is set, since the /chat/completions
  endpoint does not support the allowed_tools type.
- Add allowed_tools integration test param to both OpenAIChatClient
  (Responses API) and OpenAIChatCompletionClient parametrized option tests.
- Update Chat Completions unit tests to reflect the warn-and-fallback
  behavior.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: remove unused walrus operator variable in chat completion client

Remove assigned-but-never-used variable 'allowed' flagged by ruff F841.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Giles Odigwe
2026-04-29 10:43:47 -07:00
committed by GitHub
Unverified
parent f5419b9f38
commit 570a4d54c2
11 changed files with 900 additions and 593 deletions
@@ -872,6 +872,8 @@ class RawAnthropicClient(
tool_mode = validate_tool_mode(options.get("tool_choice"))
if tool_mode is None:
return result or None
if "allowed_tools" in tool_mode:
logger.warning("allowed_tools is not supported by Anthropic; the setting will be ignored")
allow_multiple = options.get("allow_multiple_tool_calls")
match tool_mode.get("mode"):
case "auto":
@@ -405,6 +405,8 @@ class BedrockChatClient(
tool_config = self._prepare_tools(options.get("tools"))
if tool_mode := validate_tool_mode(options.get("tool_choice")):
if "allowed_tools" in tool_mode:
logger.warning("allowed_tools is not supported by Bedrock; the setting will be ignored")
match tool_mode.get("mode"):
case "none":
# Bedrock doesn't support toolChoice "none".
+14 -1
View File
@@ -3246,10 +3246,12 @@ class ToolMode(TypedDict, total=False):
Fields:
mode: One of "auto", "required", or "none".
required_function_name: Optional function name when `mode == "required"`.
allowed_tools: Optional list of tool names when `mode` is `"auto"` or `"required"`.
"""
mode: Literal["auto", "required", "none"]
required_function_name: str
allowed_tools: list[str]
# region TypedDict-based Chat Options
@@ -3482,7 +3484,7 @@ def validate_tool_mode(
Returns:
A ToolMode dict (contains keys: "mode", and optionally
"required_function_name"), or ``None`` when not provided.
"required_function_name" or "allowed_tools"), or ``None`` when not provided.
Raises:
ContentError: If the tool_choice string is invalid.
@@ -3499,6 +3501,17 @@ def validate_tool_mode(
raise ContentError(f"Invalid tool choice: {tool_choice['mode']}")
if tool_choice["mode"] != "required" and "required_function_name" in tool_choice:
raise ContentError("tool_choice with mode other than 'required' cannot have 'required_function_name'")
if tool_choice["mode"] not in ("auto", "required") and "allowed_tools" in tool_choice:
raise ContentError("tool_choice 'allowed_tools' is only valid when mode is 'auto' or 'required'")
if "allowed_tools" in tool_choice:
allowed_tools = tool_choice["allowed_tools"]
if isinstance(allowed_tools, str) or not isinstance(allowed_tools, Sequence):
raise ContentError("tool_choice 'allowed_tools' must be a non-string sequence of strings")
if not all(isinstance(tool_name, str) for tool_name in allowed_tools):
raise ContentError("tool_choice 'allowed_tools' must contain only strings")
normalized_tool_choice = dict(tool_choice)
normalized_tool_choice["allowed_tools"] = list(allowed_tools)
return cast(ToolMode, normalized_tool_choice)
return tool_choice
@@ -1087,16 +1087,20 @@ def test_chat_tool_mode():
required_any: ToolMode = {"mode": "required"}
required_mode: ToolMode = {"mode": "required", "required_function_name": "example_function"}
none_mode: ToolMode = {"mode": "none"}
allowed_mode: ToolMode = {"mode": "auto", "allowed_tools": ["get_weather", "search_docs"]}
# Check the type and content
assert auto_mode["mode"] == "auto"
assert "required_function_name" not in auto_mode
assert "allowed_tools" not in auto_mode
assert required_any["mode"] == "required"
assert "required_function_name" not in required_any
assert required_mode["mode"] == "required"
assert required_mode["required_function_name"] == "example_function"
assert none_mode["mode"] == "none"
assert "required_function_name" not in none_mode
assert allowed_mode["mode"] == "auto"
assert allowed_mode["allowed_tools"] == ["get_weather", "search_docs"]
# equality of dicts
assert {"mode": "required", "required_function_name": "example_function"} == {
@@ -1154,6 +1158,45 @@ def test_chat_options_tool_choice_validation():
with raises(ContentError):
validate_tool_mode({"mode": "auto", "required_function_name": "should_not_be_here"})
# Valid allowed_tools
assert validate_tool_mode({"mode": "auto", "allowed_tools": ["get_weather"]}) == {
"mode": "auto",
"allowed_tools": ["get_weather"],
}
assert validate_tool_mode({"mode": "auto", "allowed_tools": ["get_weather", "search_docs"]}) == {
"mode": "auto",
"allowed_tools": ["get_weather", "search_docs"],
}
# allowed_tools valid with required mode
assert validate_tool_mode({"mode": "required", "allowed_tools": ["get_weather"]}) == {
"mode": "required",
"allowed_tools": ["get_weather"],
}
# allowed_tools invalid with none mode
with raises(ContentError):
validate_tool_mode({"mode": "none", "allowed_tools": ["get_weather"]})
# allowed_tools must be a non-string sequence of strings
with raises(ContentError):
validate_tool_mode({"mode": "auto", "allowed_tools": "get_weather"})
with raises(ContentError):
validate_tool_mode({"mode": "auto", "allowed_tools": 123})
with raises(ContentError):
validate_tool_mode({"mode": "auto", "allowed_tools": ["get_weather", 123]})
# Empty list is valid (caller explicitly allows no tools)
assert validate_tool_mode({"mode": "auto", "allowed_tools": []}) == {
"mode": "auto",
"allowed_tools": [],
}
# Tuple is normalized to list
result = validate_tool_mode({"mode": "auto", "allowed_tools": ("get_weather",)})
assert result is not None
assert result["allowed_tools"] == ["get_weather"]
def test_chat_options_merge(tool_tool, ai_tool) -> None:
"""Test merge_chat_options utility function."""
@@ -823,19 +823,28 @@ class RawGeminiChatClient(
match tool_mode.get("mode"):
case "auto":
function_calling_mode, allowed_names = types.FunctionCallingConfigMode.AUTO, None
if "allowed_tools" in tool_mode:
function_calling_mode = types.FunctionCallingConfigMode.VALIDATED
allowed_names = list(tool_mode["allowed_tools"])
else:
function_calling_mode, allowed_names = types.FunctionCallingConfigMode.AUTO, None
case "none":
function_calling_mode, allowed_names = types.FunctionCallingConfigMode.NONE, None
case "required":
function_calling_mode = types.FunctionCallingConfigMode.ANY
name = tool_mode.get("required_function_name")
allowed_names = [name] if name else None
if name:
allowed_names = [name]
elif "allowed_tools" in tool_mode:
allowed_names = list(tool_mode["allowed_tools"])
else:
allowed_names = None
case unknown_mode:
logger.warning("Unsupported tool_choice mode for Gemini: %s", unknown_mode)
return None
function_calling_kwargs: dict[str, Any] = {"mode": function_calling_mode}
if allowed_names:
if allowed_names is not None:
function_calling_kwargs["allowed_function_names"] = allowed_names
return types.ToolConfig(function_calling_config=types.FunctionCallingConfig(**function_calling_kwargs))
@@ -1157,6 +1157,86 @@ async def test_unknown_tool_choice_mode_is_ignored() -> None:
assert not hasattr(config, "tool_config") or config.tool_config is None
async def test_tool_choice_auto_with_allowed_tools_uses_VALIDATED() -> None:
"""Maps auto + allowed_tools to FunctionCallingConfigMode.VALIDATED with allowed_function_names."""
tool = _make_dummy_tool()
client, mock = _make_gemini_client()
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
await client.get_response(
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
options={
"tools": [tool],
"tool_choice": {"mode": "auto", "allowed_tools": ["dummy", "other"]},
},
)
config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
function_calling_config = config.tool_config.function_calling_config
assert function_calling_config.mode == "VALIDATED"
assert function_calling_config.allowed_function_names == ["dummy", "other"]
async def test_tool_choice_auto_with_empty_allowed_tools_uses_VALIDATED() -> None:
"""Maps auto + empty allowed_tools to VALIDATED with empty allowed_function_names."""
tool = _make_dummy_tool()
client, mock = _make_gemini_client()
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
await client.get_response(
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
options={
"tools": [tool],
"tool_choice": {"mode": "auto", "allowed_tools": []},
},
)
config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
function_calling_config = config.tool_config.function_calling_config
assert function_calling_config.mode == "VALIDATED"
assert function_calling_config.allowed_function_names == []
async def test_tool_choice_required_with_allowed_tools_uses_ANY() -> None:
"""Maps required + allowed_tools to ANY with allowed_function_names."""
tool = _make_dummy_tool()
client, mock = _make_gemini_client()
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
await client.get_response(
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
options={
"tools": [tool],
"tool_choice": {"mode": "required", "allowed_tools": ["dummy"]},
},
)
config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
function_calling_config = config.tool_config.function_calling_config
assert function_calling_config.mode == "ANY"
assert function_calling_config.allowed_function_names == ["dummy"]
async def test_tool_choice_required_function_name_takes_precedence_over_allowed_tools() -> None:
"""When both required_function_name and allowed_tools are present, required_function_name wins."""
tool = _make_dummy_tool()
client, mock = _make_gemini_client()
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
await client.get_response(
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
options={
"tools": [tool],
"tool_choice": {"mode": "required", "required_function_name": "dummy", "allowed_tools": ["other"]},
},
)
config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
function_calling_config = config.tool_config.function_calling_config
assert function_calling_config.mode == "ANY"
assert function_calling_config.allowed_function_names == ["dummy"]
# built-in tool factories
@@ -1296,6 +1296,12 @@ class RawOpenAIChatClient( # type: ignore[misc]
"type": "function",
"name": func_name,
}
elif mode == "auto" and (allowed := tool_mode.get("allowed_tools")) is not None:
run_options["tool_choice"] = {
"type": "allowed_tools",
"mode": "auto",
"tools": [{"type": "function", "name": name} for name in allowed],
}
else:
run_options["tool_choice"] = mode
else:
@@ -662,6 +662,12 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
"type": "function",
"function": {"name": func_name},
}
elif mode in ("auto", "required") and tool_mode.get("allowed_tools") is not None:
logger.warning(
"allowed_tools is not supported by the Chat Completions API; "
"the setting will be ignored. Use OpenAIChatClient (Responses API) instead."
)
run_options["tool_choice"] = mode
else:
run_options["tool_choice"] = mode
@@ -4259,6 +4259,12 @@ def test_with_callable_api_key() -> None:
True,
id="tool_choice_required",
),
param(
"tool_choice",
{"mode": "auto", "allowed_tools": ["get_weather"]},
True,
id="tool_choice_allowed_tools",
),
param("response_format", OutputStruct, True, id="response_format_pydantic"),
param(
"response_format",
@@ -4813,6 +4819,90 @@ async def test_prepare_options_excludes_continuation_token() -> None:
assert run_options["background"] is True
async def test_prepare_options_allowed_tools() -> None:
"""Test that _prepare_options converts allowed_tools to OpenAI API format."""
client = OpenAIChatClient(model="test-model", api_key="test-key")
@tool
def get_weather(city: str) -> str:
"""Get the weather for a city."""
return f"Sunny in {city}"
@tool
def search_docs(query: str) -> str:
"""Search documentation."""
return f"Results for {query}"
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
options: dict[str, Any] = {
"model": "test-model",
"tools": [get_weather, search_docs],
"tool_choice": {"mode": "auto", "allowed_tools": ["get_weather"]},
}
run_options = await client._prepare_options(messages, options)
assert run_options["tool_choice"] == {
"type": "allowed_tools",
"mode": "auto",
"tools": [{"type": "function", "name": "get_weather"}],
}
async def test_prepare_options_allowed_tools_multiple() -> None:
"""Test that _prepare_options converts multiple allowed_tools correctly."""
client = OpenAIChatClient(model="test-model", api_key="test-key")
@tool
def get_weather(city: str) -> str:
"""Get the weather for a city."""
return f"Sunny in {city}"
@tool
def search_docs(query: str) -> str:
"""Search documentation."""
return f"Results for {query}"
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
options: dict[str, Any] = {
"model": "test-model",
"tools": [get_weather, search_docs],
"tool_choice": {"mode": "auto", "allowed_tools": ["get_weather", "search_docs"]},
}
run_options = await client._prepare_options(messages, options)
assert run_options["tool_choice"] == {
"type": "allowed_tools",
"mode": "auto",
"tools": [
{"type": "function", "name": "get_weather"},
{"type": "function", "name": "search_docs"},
],
}
async def test_prepare_options_auto_without_allowed_tools() -> None:
"""Test that auto mode without allowed_tools still returns plain 'auto' string."""
client = OpenAIChatClient(model="test-model", api_key="test-key")
@tool
def get_weather(city: str) -> str:
"""Get the weather for a city."""
return f"Sunny in {city}"
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
options: dict[str, Any] = {
"model": "test-model",
"tools": [get_weather],
"tool_choice": {"mode": "auto"},
}
run_options = await client._prepare_options(messages, options)
assert run_options["tool_choice"] == "auto"
# endregion
@@ -1430,6 +1430,57 @@ def test_tool_choice_required_with_function_name(
assert prepared_options["tool_choice"]["function"]["name"] == "get_weather"
def test_tool_choice_allowed_tools_falls_back_to_mode(
openai_unit_test_env: dict[str, str],
) -> None:
"""Test that tool_choice with allowed_tools falls back to plain mode (Chat Completions API unsupported)."""
client = OpenAIChatCompletionClient()
messages = [Message(role="user", contents=["test"])]
options = {
"tools": [get_weather],
"tool_choice": {"mode": "auto", "allowed_tools": ["get_weather"]},
}
prepared_options = client._prepare_options(messages, options)
assert prepared_options["tool_choice"] == "auto"
def test_tool_choice_allowed_tools_required_mode_falls_back(
openai_unit_test_env: dict[str, str],
) -> None:
"""Test that tool_choice with allowed_tools and required mode falls back to 'required'."""
client = OpenAIChatCompletionClient()
messages = [Message(role="user", contents=["test"])]
options = {
"tools": [get_weather],
"tool_choice": {"mode": "required", "allowed_tools": ["get_weather"]},
}
prepared_options = client._prepare_options(messages, options)
assert prepared_options["tool_choice"] == "required"
def test_tool_choice_auto_dict_without_allowed_tools(
openai_unit_test_env: dict[str, str],
) -> None:
"""Test that tool_choice dict with mode auto and no allowed_tools falls through to plain 'auto'."""
client = OpenAIChatCompletionClient()
messages = [Message(role="user", contents=["test"])]
options = {
"tools": [get_weather],
"tool_choice": {"mode": "auto"},
}
prepared_options = client._prepare_options(messages, options)
assert prepared_options["tool_choice"] == "auto"
def test_response_format_dict_passthrough(openai_unit_test_env: dict[str, str]) -> None:
"""Test that response_format as dict is passed through directly."""
client = OpenAIChatCompletionClient()
@@ -1590,6 +1641,12 @@ class OutputStruct(BaseModel):
False,
id="tool_choice_required",
),
param(
"tool_choice",
{"mode": "auto", "allowed_tools": ["get_weather"]},
False,
id="tool_choice_allowed_tools",
),
param("response_format", OutputStruct, True, id="response_format_pydantic"),
param(
"response_format",
+588 -589
View File
File diff suppressed because it is too large Load Diff