mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
570a4d54c2
* Support OpenAI allowed_tools in ToolMode (#5309) Add allowed_tools field to ToolMode TypedDict, enabling users to restrict which tools the model may call via the OpenAI allowed_tools tool_choice type. This preserves prompt caching by keeping all tools in the tools list while limiting which ones the model can invoke. - Add allowed_tools: list[str] to ToolMode TypedDict - Add validation in validate_tool_mode() (only valid when mode == "auto") - Convert to OpenAI API format in _prepare_options() - Add tests for validation and API payload generation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: Support OpenAI `allowed_tools` tool choice in Python SDK Fixes #5309 * Fix #5309: Validate allowed_tools shape and add Chat Completions client support - validate_tool_mode now checks allowed_tools is a non-string sequence of strings and normalizes to list[str], raising ContentError for invalid types - Add missing allowed_tools branch in _chat_completion_client._prepare_options so allowed_tools is emitted as the OpenAI allowed_tools wire format instead of being silently dropped - Add tests for invalid allowed_tools types (string, int, mixed), empty list, tuple normalization, and Chat Completions client payload generation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: support allowed_tools with mode 'required' in addition to 'auto' OpenAI's allowed_tools tool_choice type supports both mode 'auto' and 'required'. Update validation, client conversion, and tests to allow both modes instead of restricting to 'auto' only. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: use Gemini VALIDATED mode for allowed_tools, warn in unsupported providers - Use FunctionCallingConfigMode.VALIDATED instead of ANY when allowed_tools is set with auto mode in Gemini, preserving optional tool-call semantics. - Handle allowed_tools in required mode with required_function_name precedence. 
- Fix allowed_names guard to use identity check (is not None) so empty lists are preserved. - Bump google-genai minimum to >=1.32.0 (VALIDATED added in that version). - Add warnings in Anthropic and Bedrock when allowed_tools is set but not supported. - Add Gemini unit tests for allowed_tools with auto, required, empty list, and required_function_name precedence scenarios. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: Chat Completions API does not support allowed_tools, add integration tests - Chat Completions API (_chat_completion_client.py) now warns and falls back to plain mode when allowed_tools is set, since the /chat/completions endpoint does not support the allowed_tools type. - Add allowed_tools integration test param to both OpenAIChatClient (Responses API) and OpenAIChatCompletionClient parametrized option tests. - Update Chat Completions unit tests to reflect the warn-and-fallback behavior. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: remove unused walrus operator variable in chat completion client Remove assigned-but-never-used variable 'allowed' flagged by ruff F841. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1698 lines
67 KiB
Python
1698 lines
67 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
from __future__ import annotations
|
|
|
|
import datetime
|
|
import logging
|
|
import os
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from agent_framework import Content, FunctionTool, Message
|
|
from google.genai import types
|
|
from pydantic import BaseModel
|
|
|
|
from agent_framework_gemini import GeminiChatClient, GeminiChatOptions, ThinkingConfig
|
|
|
|
|
|
def _has_gemini_integration_credentials() -> bool:
    """Return whether integration credentials for either Gemini API or Vertex AI appear to be configured.

    Returns:
        True when either API key variable is non-empty, or when the Vertex AI flag is
        enabled and a project or application-credentials variable is non-empty.
    """
    # Either API key variable is enough on its own, regardless of the Vertex AI flag.
    if os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"):
        return True

    if os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() in {"true", "1", "yes", "on"}:
        # GOOGLE_API_KEY is not re-checked here: a non-empty value already returned above.
        return bool(os.getenv("GOOGLE_CLOUD_PROJECT") or os.getenv("GOOGLE_APPLICATION_CREDENTIALS"))

    return False
|
|
|
|
|
|
# Skip marker evaluated once at import time from the current environment.
skip_if_no_credentials = pytest.mark.skipif(
    not _has_gemini_integration_credentials(),
    reason="Gemini Developer API or Vertex AI credentials not set; skipping integration tests.",
)

# Resolved from GOOGLE_MODEL, then GEMINI_MODEL, then a built-in default; empty values
# fall through to the next source so an empty GEMINI_MODEL cannot yield "".
_TEST_MODEL = os.getenv("GOOGLE_MODEL") or os.getenv("GEMINI_MODEL") or "gemini-2.5-flash-lite"
|
|
|
|
# stub helpers
|
|
|
|
|
|
def _make_part(
    *,
    text: str | None = None,
    thought: bool = False,
    function_call: tuple[str, str, dict[str, Any]] | None = None,
    executable_code: str | None = None,
    code_execution_result: str | None = None,
) -> MagicMock:
    """Build a mock types.Part.

    Args:
        text: Text content of the part.
        thought: Whether this is a thinking/reasoning part.
        function_call: Tuple of (id, name, args) if this is a function call part.
        executable_code: Source code string for a code execution part.
        code_execution_result: Output string for a code execution result part.

    Returns:
        A MagicMock whose part fields are set explicitly; fields not requested are None.
    """
    part = MagicMock()
    part.text = text
    part.thought = thought
    # Pin every optional field to None up front: an untouched MagicMock attribute
    # would otherwise be an auto-created (truthy) child mock.
    part.function_response = None
    part.function_call = None
    part.executable_code = None
    part.code_execution_result = None

    if function_call:
        mock_function_call = MagicMock()
        mock_function_call.id, mock_function_call.name, mock_function_call.args = function_call
        part.function_call = mock_function_call

    if executable_code is not None:
        mock_exec = MagicMock()
        mock_exec.code = executable_code
        part.executable_code = mock_exec

    if code_execution_result is not None:
        mock_result = MagicMock()
        mock_result.output = code_execution_result
        part.code_execution_result = mock_result

    return part
|
|
|
|
|
|
def _make_response(
    parts: list[MagicMock],
    *,
    finish_reason: str | None = "STOP",
    model_version: str = "gemini-2.5-flash-001",
    prompt_tokens: int | None = 10,
    output_tokens: int | None = 5,
    total_tokens: int | None = 15,
) -> MagicMock:
    """Build a mock types.GenerateContentResponse.

    Args:
        parts: Parts placed on the single candidate's content.
        finish_reason: Name exposed via ``candidate.finish_reason.name``; a falsy value
            sets ``candidate.finish_reason`` to None instead.
        model_version: Value of ``response.model_version``.
        prompt_tokens: Value of ``usage_metadata.prompt_token_count``.
        output_tokens: Value of ``usage_metadata.candidates_token_count``.
        total_tokens: Value of ``usage_metadata.total_token_count``.

    Returns:
        A MagicMock with one candidate. ``usage_metadata`` is None only when both
        ``prompt_tokens`` and ``output_tokens`` are None (``total_tokens`` is not consulted).
    """
    response = MagicMock()
    candidate = MagicMock()
    candidate.content.parts = parts

    if finish_reason:
        candidate.finish_reason.name = finish_reason
    else:
        candidate.finish_reason = None

    response.candidates = [candidate]
    response.finish_reason = finish_reason
    response.model_version = model_version

    if prompt_tokens is not None or output_tokens is not None:
        usage = MagicMock()
        usage.prompt_token_count = prompt_tokens
        usage.candidates_token_count = output_tokens
        usage.total_token_count = total_tokens
        response.usage_metadata = usage
    else:
        response.usage_metadata = None

    return response
|
|
|
|
|
|
async def _async_iter(items: list[Any]):
    """Async generator used to simulate generate_content_stream results.

    Args:
        items: Values to yield, in order.
    """
    for item in items:
        yield item
|
|
|
|
|
|
def _make_gemini_client(
    model: str = "gemini-2.5-flash",
    mock_client: MagicMock | None = None,
) -> tuple[GeminiChatClient, MagicMock]:
    """Return a (GeminiChatClient, mock_genai_client) pair.

    Args:
        model: Model identifier passed to the client constructor.
        mock_client: Existing mock to wrap; a fresh MagicMock is created when omitted.
    """
    mock = mock_client or MagicMock()
    # Set the private attributes as non-Vertex with the Gemini Developer API base URL.
    mock._api_client.vertexai = False
    mock._api_client._http_options.base_url = "https://generativelanguage.googleapis.com/"
    client = GeminiChatClient(client=mock, model=model)
    return client, mock
|
|
|
|
|
|
# settings & initialisation
|
|
|
|
|
|
def test_model_stored_on_instance() -> None:
    """Stores the model identifier on the instance so it can be read back."""
    client, _ = _make_gemini_client(model="gemini-2.5-pro")
    assert client.model == "gemini-2.5-pro"
|
|
|
|
|
|
def test_client_created_from_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """Initialises successfully when the API key is supplied via environment variable."""
    # Clear ambient GOOGLE_* settings so a developer's or CI environment (for example a
    # partially configured Vertex AI setup) cannot change which code path is exercised.
    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
    monkeypatch.delenv("GOOGLE_GENAI_USE_VERTEXAI", raising=False)
    monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
    monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
    monkeypatch.setenv("GEMINI_API_KEY", "test-key-123")
    client = GeminiChatClient(model="gemini-2.5-flash")
    assert client.model == "gemini-2.5-flash"
|
|
|
|
|
|
def test_client_created_from_google_api_key_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Initialises successfully when the SDK-standard Google API key environment variable is set."""
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    monkeypatch.delenv("GEMINI_MODEL", raising=False)
    monkeypatch.delenv("GOOGLE_GENAI_USE_VERTEXAI", raising=False)
    monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
    monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
    monkeypatch.setenv("GOOGLE_API_KEY", "test-key-123")
    monkeypatch.setenv("GOOGLE_MODEL", "gemini-2.5-flash-lite")

    mock_client = MagicMock()
    mock_client._api_client.vertexai = False
    mock_client._api_client._http_options.base_url = "https://generativelanguage.googleapis.com/"

    # Patch the SDK client factory so the constructor arguments can be inspected.
    with patch("agent_framework_gemini._chat_client.genai.Client") as client_factory:
        client_factory.return_value = mock_client
        client = GeminiChatClient()

    assert client_factory.call_args.kwargs["api_key"] == "test-key-123"
    assert "vertexai" not in client_factory.call_args.kwargs
    assert client.model == "gemini-2.5-flash-lite"
    # The mocked base URL has a trailing slash; service_url() is expected without it.
    assert client.service_url() == "https://generativelanguage.googleapis.com"
|
|
|
|
|
|
def test_client_created_from_vertex_ai_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Initialises a Vertex AI client when the SDK-standard Vertex AI environment variables are set."""
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
    monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project")
    monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "global")

    mock_client = MagicMock()
    mock_client._api_client.vertexai = True
    mock_client._api_client._http_options.base_url = "https://aiplatform.googleapis.com/"

    with patch("agent_framework_gemini._chat_client.genai.Client", return_value=mock_client) as client_factory:
        client = GeminiChatClient()

    assert client_factory.call_args.kwargs["vertexai"] is True
    assert client_factory.call_args.kwargs["project"] == "test-project"
    assert client_factory.call_args.kwargs["location"] == "global"
    # With no API key in the environment, none may be forwarded to the SDK client.
    assert "api_key" not in client_factory.call_args.kwargs
    assert client.service_url() == "https://aiplatform.googleapis.com"
|
|
|
|
|
|
def test_google_settings_take_precedence_over_gemini_aliases(monkeypatch: pytest.MonkeyPatch) -> None:
    """Prefers SDK-standard ``GOOGLE_*`` settings when both env families are present."""
    monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
    monkeypatch.setenv("GEMINI_MODEL", "gemini-model")
    monkeypatch.setenv("GOOGLE_API_KEY", "google-key")
    monkeypatch.setenv("GOOGLE_MODEL", "google-model")
    monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "google-project")
    monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "global")

    mock_client = MagicMock()
    mock_client._api_client.vertexai = True
    mock_client._api_client._http_options.base_url = "https://aiplatform.googleapis.com/"

    with patch("agent_framework_gemini._chat_client.genai.Client", return_value=mock_client) as client_factory:
        client = GeminiChatClient()

    assert client_factory.call_args.kwargs["vertexai"] is True
    assert client_factory.call_args.kwargs["project"] == "google-project"
    assert client_factory.call_args.kwargs["location"] == "global"
    # Project and location are configured, so the API key must not be forwarded.
    assert "api_key" not in client_factory.call_args.kwargs
    assert client.model == "google-model"
    assert client.service_url() == "https://aiplatform.googleapis.com"
|
|
|
|
|
|
def test_missing_api_key_raises_when_no_client_injected(monkeypatch: pytest.MonkeyPatch) -> None:
    """Raises ValueError at construction when neither Gemini API nor Vertex AI settings are available."""
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    monkeypatch.delenv("GEMINI_MODEL", raising=False)
    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
    monkeypatch.delenv("GOOGLE_GENAI_USE_VERTEXAI", raising=False)
    monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
    monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)

    with pytest.raises(ValueError, match="requires an API key when Vertex AI is not enabled"):
        GeminiChatClient(model="gemini-2.5-flash")
|
|
|
|
|
|
def test_vertex_ai_express_mode_uses_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """Passes the API key in Vertex AI express mode when no project/location pair is configured."""
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    monkeypatch.delenv("GEMINI_MODEL", raising=False)
    monkeypatch.setenv("GOOGLE_API_KEY", "test-key-123")
    monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
    monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
    monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)

    mock_client = MagicMock()
    mock_client._api_client.vertexai = True
    mock_client._api_client._http_options.base_url = "https://aiplatform.googleapis.com/"

    with patch("agent_framework_gemini._chat_client.genai.Client", return_value=mock_client) as client_factory:
        client = GeminiChatClient(model="gemini-2.5-flash-lite")

    assert client_factory.call_args.kwargs["vertexai"] is True
    assert client_factory.call_args.kwargs["api_key"] == "test-key-123"
    # Neither project nor location is configured, so neither may be forwarded.
    assert "project" not in client_factory.call_args.kwargs
    assert "location" not in client_factory.call_args.kwargs
    assert client.service_url() == "https://aiplatform.googleapis.com"
|
|
|
|
|
|
def test_vertex_ai_requires_configuration(monkeypatch: pytest.MonkeyPatch) -> None:
    """Raises a deterministic error when Vertex AI is enabled without any auth configuration."""
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
    monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
    monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
    monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
    # NOTE(review): GOOGLE_APPLICATION_CREDENTIALS is not cleared here; confirm whether the
    # client consults it, in which case an ambient value could affect this test.

    with pytest.raises(ValueError, match="requires Vertex AI credentials or configuration"):
        GeminiChatClient(model="gemini-2.5-flash")
|
|
|
|
|
|
def test_vertex_ai_requires_project_and_location_together(monkeypatch: pytest.MonkeyPatch) -> None:
    """Raises a deterministic error when only one Vertex AI location setting is present."""
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
    monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
    # Project is set but location is removed, producing the half-configured state.
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project")
    monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)

    with pytest.raises(ValueError, match="requires both GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION"):
        GeminiChatClient(model="gemini-2.5-flash")
|
|
|
|
|
|
async def test_missing_model_raises_on_get_response(monkeypatch: pytest.MonkeyPatch) -> None:
    """Raises ValueError at call time when no model is set on the client or in options."""
    # Remove both model variables so no model can be picked up from the environment.
    monkeypatch.delenv("GEMINI_MODEL", raising=False)
    monkeypatch.delenv("GOOGLE_MODEL", raising=False)
    client, mock = _make_gemini_client(model=None)  # type: ignore[arg-type]
    mock.aio.models.generate_content = AsyncMock()

    with pytest.raises(ValueError, match="model"):
        await client.get_response(messages=[Message(role="user", contents=[Content.from_text("hi")])])
|
|
|
|
|
|
# text response
|
|
|
|
|
|
async def test_get_response_returns_text() -> None:
    """Returns the model's text reply in the first message of the response."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hello!")]))

    response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])])

    assert response.messages[0].text == "Hello!"
|
|
|
|
|
|
async def test_get_response_model_from_response() -> None:
    """Populates ChatResponse.model from the model_version field in the API response."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(
        return_value=_make_response([_make_part(text="Hi")], model_version="gemini-2.5-pro-002")
    )

    response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])])

    assert response.model == "gemini-2.5-pro-002"
|
|
|
|
|
|
async def test_get_response_uses_model_from_options() -> None:
    """Uses the model specified in options, overriding the client's default."""
    client, mock = _make_gemini_client(model="gemini-2.5-flash")
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

    await client.get_response(
        messages=[Message(role="user", contents=[Content.from_text("Hi")])],
        options={"model": "gemini-2.5-pro"},
    )

    # Inspect what was actually sent to the SDK rather than the returned response.
    call_kwargs = mock.aio.models.generate_content.call_args.kwargs
    assert call_kwargs["model"] == "gemini-2.5-pro"
|
|
|
|
|
|
async def test_get_response_usage_details() -> None:
    """Surfaces input, output, and total token counts from the API usage metadata."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(
        return_value=_make_response(
            [_make_part(text="Hi")],
            prompt_tokens=20,
            output_tokens=8,
            total_tokens=28,
        )
    )

    response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])])

    assert response.usage_details is not None
    assert response.usage_details["input_token_count"] == 20
    assert response.usage_details["output_token_count"] == 8
    assert response.usage_details["total_token_count"] == 28
|
|
|
|
|
|
async def test_get_response_no_usage_when_metadata_absent() -> None:
    """Returns no usage_details when the API response includes no usage metadata."""
    client, mock = _make_gemini_client()
    # Passing None for both counts makes _make_response set usage_metadata to None.
    mock.aio.models.generate_content = AsyncMock(
        return_value=_make_response([_make_part(text="Hi")], prompt_tokens=None, output_tokens=None)
    )

    response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])])

    # Falsy check: accepts either None or an empty mapping.
    assert not response.usage_details
|
|
|
|
|
|
# finish reasons
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("gemini_reason", "expected"),
    [
        ("STOP", "stop"),
        ("MAX_TOKENS", "length"),
        ("SAFETY", "content_filter"),
        ("RECITATION", "content_filter"),
        ("BLOCKLIST", "content_filter"),
        ("PROHIBITED_CONTENT", "content_filter"),
        ("SPII", "content_filter"),
        ("MALFORMED_FUNCTION_CALL", "tool_calls"),
        ("OTHER", None),
    ],
)
async def test_finish_reason_mapping(gemini_reason: str, expected: str | None) -> None:
    """Maps Gemini finish reason strings to the correct FinishReasonLiteral values."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(
        return_value=_make_response([_make_part(text="Hi")], finish_reason=gemini_reason)
    )

    response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Hi")])])

    assert response.finish_reason == expected
|
|
|
|
|
|
# message conversion
|
|
|
|
|
|
async def test_system_message_extracted_to_system_instruction() -> None:
    """Extracts a system role message from the conversation and sends it as the system instruction."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

    await client.get_response(
        messages=[
            Message(role="system", contents=[Content.from_text("You are concise.")]),
            Message(role="user", contents=[Content.from_text("Hi")]),
        ]
    )

    config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
    assert config.system_instruction == "You are concise."
|
|
|
|
|
|
async def test_multiple_system_messages_concatenated() -> None:
    """Joins multiple system messages into a single system instruction string."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

    await client.get_response(
        messages=[
            Message(role="system", contents=[Content.from_text("Be concise.")]),
            Message(role="system", contents=[Content.from_text("Use bullet points.")]),
            Message(role="user", contents=[Content.from_text("Hi")]),
        ]
    )

    config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
    # Only presence is checked; the separator and ordering are not pinned by this test.
    assert "Be concise." in config.system_instruction
    assert "Use bullet points." in config.system_instruction
|
|
|
|
|
|
async def test_instructions_option_merged_with_system_instruction() -> None:
    """Includes both the instructions option and the system message in the system instruction."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

    await client.get_response(
        messages=[
            Message(role="system", contents=[Content.from_text("Be concise.")]),
            Message(role="user", contents=[Content.from_text("Hi")]),
        ],
        options={"instructions": "Always respond in French."},
    )

    config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
    # Only presence of both texts is asserted; their relative order is not pinned here.
    assert "Always respond in French." in config.system_instruction
    assert "Be concise." in config.system_instruction
|
|
|
|
|
|
async def test_instructions_option_without_system_message() -> None:
    """Uses the instructions option as the sole system instruction when no system message is present."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

    await client.get_response(
        messages=[Message(role="user", contents=[Content.from_text("Hi")])],
        options={"instructions": "Be helpful."},
    )

    config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
    assert config.system_instruction == "Be helpful."
|
|
|
|
|
|
async def test_assistant_role_mapped_to_model() -> None:
    """Maps the framework 'assistant' role to the 'model' role expected by the Gemini API."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Sure")]))

    await client.get_response(
        messages=[
            Message(role="user", contents=[Content.from_text("Hello")]),
            Message(role="assistant", contents=[Content.from_text("Hi there")]),
            Message(role="user", contents=[Content.from_text("Follow up")]),
        ]
    )

    contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"]
    roles = [c.role for c in contents]
    assert roles == ["user", "model", "user"]
|
|
|
|
|
|
async def test_tool_messages_collapsed_into_single_user_message() -> None:
    """Consecutive tool messages must be collapsed into one role='user' message
    with multiple functionResponse parts (parallel tool call pattern).
    """
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")]))

    await client.get_response(
        messages=[
            Message(role="user", contents=[Content.from_text("Run both")]),
            Message(
                role="assistant",
                contents=[
                    Content.from_function_call(call_id="c1", name="tool_a", arguments={}),
                    Content.from_function_call(call_id="c2", name="tool_b", arguments={}),
                ],
            ),
            Message(role="tool", contents=[Content.from_function_result(call_id="c1", result="res_a")]),
            Message(role="tool", contents=[Content.from_function_result(call_id="c2", result="res_b")]),
        ]
    )

    contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"]
    # user, model (with 2 function calls), user (with 2 function responses)
    assert contents[-1].role == "user"
    assert len(contents[-1].parts) == 2
|
|
|
|
|
|
async def test_function_result_name_resolved_from_call_history() -> None:
    """function_result name must come from the matching function_call in history."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")]))

    await client.get_response(
        messages=[
            Message(role="user", contents=[Content.from_text("Go")]),
            Message(
                role="assistant",
                contents=[Content.from_function_call(call_id="call-42", name="get_weather", arguments={})],
            ),
            # The result carries only the call id; the name has to be looked up from the call above.
            Message(role="tool", contents=[Content.from_function_result(call_id="call-42", result="sunny")]),
        ]
    )

    contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"]
    tool_user_msg = contents[-1]
    assert tool_user_msg.role == "user"
    function_response = tool_user_msg.parts[0].function_response
    assert function_response.name == "get_weather"
    assert function_response.id == "call-42"
|
|
|
|
|
|
async def test_function_result_resolved_when_call_id_was_generated() -> None:
    """When a function_call has no call_id and a fallback is generated, the subsequent
    function_result referencing that generated ID must still resolve the function name.
    """
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")]))

    generated_id = "tool-call-generated-123"
    # Pin the generated fallback id so the tool result below can reference it.
    with patch.object(client, "_generate_tool_call_id", return_value=generated_id):
        await client.get_response(
            messages=[
                Message(role="user", contents=[Content.from_text("Go")]),
                Message(
                    role="assistant",
                    contents=[Content.from_function_call(call_id=None, name="get_weather", arguments={})],  # type: ignore[arg-type]
                ),
                Message(
                    role="tool",
                    contents=[Content.from_function_result(call_id=generated_id, result="sunny")],
                ),
            ]
        )

    contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"]
    # Locate the user turn that carries the function response.
    tool_turn = next(c for c in contents if c.role == "user" and any(p.function_response for p in c.parts))
    assert tool_turn.parts[0].function_response.name == "get_weather"
    assert tool_turn.parts[0].function_response.id == generated_id
|
|
|
|
|
|
async def test_function_result_without_matching_call_is_skipped(caplog: pytest.LogCaptureFixture) -> None:
    """A function_result with no prior function_call in history should be skipped with a warning."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")]))

    with caplog.at_level(logging.WARNING, logger="agent_framework.gemini"):
        await client.get_response(
            messages=[
                Message(role="user", contents=[Content.from_text("Go")]),
                Message(
                    role="tool",
                    contents=[Content.from_function_result(call_id="unknown-id", result="oops")],
                ),
                Message(role="user", contents=[Content.from_text("What happened?")]),
            ]
        )

    # Accept any captured record that mentions the orphaned call id or "function_result".
    assert any("unknown-id" in r.message or "function_result" in r.message.lower() for r in caplog.records)
|
|
|
|
|
|
async def test_message_with_only_unsupported_content_type_is_skipped() -> None:
    """A user message whose contents produce no convertible parts is dropped from the request."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")]))

    await client.get_response(
        messages=[
            # A function result inside a user-role message is the unsupported case under test.
            Message(role="user", contents=[Content.from_function_result(call_id="x", result="y")]),
            Message(role="user", contents=[Content.from_text("Follow up")]),
        ]
    )

    contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"]
    assert len(contents) == 1
    assert contents[0].parts[0].text == "Follow up"
|
|
|
|
|
|
async def test_non_function_result_content_in_tool_message_is_skipped() -> None:
    """Unexpected content types inside a tool message are silently ignored."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")]))

    await client.get_response(
        messages=[
            Message(role="user", contents=[Content.from_text("Hi")]),
            Message(role="tool", contents=[Content.from_text("unexpected")]),
        ]
    )

    contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"]
    # Only the user message is expected in the request; the tool message contributes nothing.
    assert len(contents) == 1
|
|
|
|
|
|
# thinking parts
|
|
|
|
|
|
async def test_thinking_parts_are_silently_skipped() -> None:
    """Excludes thought-summary parts from ChatResponse.contents, returning only the final answer."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(
        return_value=_make_response([
            _make_part(text="I should think first...", thought=True),
            _make_part(text="The answer is 42."),
        ])
    )

    response = await client.get_response(
        messages=[Message(role="user", contents=[Content.from_text("What is the answer?")])]
    )

    assert len(response.messages[0].contents) == 1
    assert response.messages[0].text == "The answer is 42."
|
|
|
|
|
|
def test_function_call_part_preserves_thought_signature_from_raw_part() -> None:
    """Reuses the original Gemini Part so tool loops retain thought_signature metadata."""
    client, _ = _make_gemini_client()
    # A real types.Part (not a mock) so thought_signature is carried as genuine SDK data.
    raw_part = types.Part(
        function_call=types.FunctionCall(id="call-1", name="get_weather", args={"location": "Paris"}),
        thought_signature=b"sig-123",
    )
    content = Content.from_function_call(
        call_id="call-1",
        name="get_weather",
        arguments={"location": "Paris"},
        raw_representation=raw_part,
    )

    parts = client._convert_message_contents([content], {})

    assert len(parts) == 1
    assert parts[0].thought_signature == b"sig-123"
    assert parts[0].function_call is not None
    assert parts[0].function_call.id == "call-1"
    assert parts[0].function_call.name == "get_weather"
    assert parts[0].function_call.args == {"location": "Paris"}
|
|
|
|
|
|
# code execution parts
|
|
|
|
|
|
async def test_executable_code_part_is_included_as_text() -> None:
    """executable_code and code_execution_result parts are surfaced as text content."""
    client, mock = _make_gemini_client()
    mock.aio.models.generate_content = AsyncMock(
        return_value=_make_response([
            _make_part(executable_code="print(sum(range(10)))"),
            _make_part(code_execution_result="45"),
            _make_part(text="The sum of 0 through 9 is 45."),
        ])
    )

    response = await client.get_response(messages=[Message(role="user", contents=[Content.from_text("Sum 0 to 9")])])

    # Collect the non-empty text of every content item in the first message.
    texts = [c.text for c in response.messages[0].contents if c.text]
    assert "print(sum(range(10)))" in texts
    assert "45" in texts
    assert "The sum of 0 through 9 is 45." in texts
|
|
|
|
|
|
async def test_unknown_part_type_is_skipped() -> None:
    """A part with none of the recognised fields set adds nothing to the response contents."""
    gemini, sdk = _make_gemini_client()
    # Every field the test knows about is explicitly unset on this part.
    blank_part = MagicMock(
        thought=False,
        text=None,
        function_call=None,
        function_response=None,
        executable_code=None,
        code_execution_result=None,
    )
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([blank_part, _make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    response = await gemini.get_response(messages=[prompt])

    reply = response.messages[0]
    assert len(reply.contents) == 1
    assert reply.text == "Hi"
|
|
|
|
|
|
async def test_empty_executable_code_part_is_skipped() -> None:
    """An executable_code part whose code string is empty yields no Content entry."""
    gemini, sdk = _make_gemini_client()
    # executable_code is present, but its code attribute is the empty string.
    empty_code_part = MagicMock(
        text=None,
        thought=False,
        function_call=None,
        function_response=None,
        code_execution_result=None,
        executable_code=MagicMock(code=""),
    )
    sdk.aio.models.generate_content = AsyncMock(
        return_value=_make_response([empty_code_part, _make_part(text="Done.")])
    )
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    response = await gemini.get_response(messages=[prompt])

    reply = response.messages[0]
    assert len(reply.contents) == 1
    assert reply.text == "Done."
|
|
|
|
|
|
# generation config options
|
|
|
|
|
|
async def test_prepare_config_temperature() -> None:
    """The temperature option ends up on GenerateContentConfig.temperature."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    await gemini.get_response(messages=[prompt], options={"temperature": 0.3})

    # Inspect the config keyword argument of the recorded generate_content call.
    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    assert sent_config.temperature == 0.3
|
|
|
|
|
|
async def test_prepare_config_max_tokens() -> None:
    """The max_tokens option is sent as GenerateContentConfig.max_output_tokens."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    await gemini.get_response(messages=[prompt], options={"max_tokens": 512})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    assert sent_config.max_output_tokens == 512
|
|
|
|
|
|
async def test_prepare_config_top_p_and_top_k() -> None:
    """top_p and top_k each land on the GenerateContentConfig field of the same name."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])
    sampling_options = {"top_p": 0.9, "top_k": 40}

    await gemini.get_response(messages=[prompt], options=sampling_options)

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    assert sent_config.top_p == 0.9
    assert sent_config.top_k == 40
|
|
|
|
|
|
async def test_prepare_config_stop_sequences() -> None:
    """The stop option is sent as GenerateContentConfig.stop_sequences."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    await gemini.get_response(messages=[prompt], options={"stop": ["END", "STOP"]})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    assert sent_config.stop_sequences == ["END", "STOP"]
|
|
|
|
|
|
async def test_prepare_config_seed() -> None:
    """The seed option is sent as GenerateContentConfig.seed."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    await gemini.get_response(messages=[prompt], options={"seed": 42})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    assert sent_config.seed == 42
|
|
|
|
|
|
async def test_prepare_config_frequency_and_presence_penalty() -> None:
    """frequency_penalty and presence_penalty land on the same-named GenerateContentConfig fields."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])
    penalty_options = {"frequency_penalty": 0.5, "presence_penalty": 0.2}

    await gemini.get_response(messages=[prompt], options=penalty_options)

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    assert sent_config.frequency_penalty == 0.5
    assert sent_config.presence_penalty == 0.2
|
|
|
|
|
|
async def test_prepare_config_unknown_key_is_forwarded() -> None:
    """An option key in neither _OPTION_EXCLUDE_KEYS nor _OPTION_TRANSLATIONS is passed through unchanged."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    # Patch the config class so the keyword arguments it is constructed with can be inspected.
    with patch("agent_framework_gemini._chat_client.types.GenerateContentConfig") as config_cls:
        config_cls.return_value = MagicMock()
        await gemini.get_response(messages=[prompt], options={"some_future_param": "value"})
        forwarded = config_cls.call_args.kwargs

    assert forwarded.get("some_future_param") == "value"
|
|
|
|
|
|
async def test_prepare_config_consumed_keys_are_excluded() -> None:
    """The model and instructions options never appear among the GenerateContentConfig keyword arguments."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])
    consumed_options = {"model": "gemini-2.5-pro", "instructions": "Be helpful."}

    # Patch the config class so the keyword arguments it is constructed with can be inspected.
    with patch("agent_framework_gemini._chat_client.types.GenerateContentConfig") as config_cls:
        config_cls.return_value = MagicMock()
        await gemini.get_response(messages=[prompt], options=consumed_options)
        config_kwargs = config_cls.call_args.kwargs

    for consumed_key in ("model", "instructions"):
        assert consumed_key not in config_kwargs
|
|
|
|
|
|
# thinking config
|
|
|
|
|
|
async def test_thinking_config_budget() -> None:
    """A thinking_budget in the thinking_config option arrives on GenerateContentConfig.thinking_config."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    thinking: ThinkingConfig = {"thinking_budget": 1024}
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    await gemini.get_response(messages=[prompt], options={"thinking_config": thinking})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    sent_thinking = sent_config.thinking_config
    # The option dict must have been converted into the SDK's own ThinkingConfig type.
    assert isinstance(sent_thinking, types.ThinkingConfig)
    assert sent_thinking.thinking_budget == 1024
|
|
|
|
|
|
async def test_thinking_config_level() -> None:
    """A thinking_level in the thinking_config option arrives on GenerateContentConfig.thinking_config."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    thinking: ThinkingConfig = {"thinking_level": types.ThinkingLevel.HIGH}
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    await gemini.get_response(messages=[prompt], options={"thinking_config": thinking})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    sent_thinking = sent_config.thinking_config
    # The option dict must have been converted into the SDK's own ThinkingConfig type.
    assert isinstance(sent_thinking, types.ThinkingConfig)
    assert sent_thinking.thinking_level == types.ThinkingLevel.HIGH
|
|
|
|
|
|
# structured output
|
|
|
|
|
|
async def test_response_format_sets_json_mime_type() -> None:
    """Supplying response_format makes the config request application/json output."""
    from pydantic import BaseModel

    class Reply(BaseModel):
        text: str

    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="{}")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    await gemini.get_response(messages=[prompt], options={"response_format": Reply})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    assert sent_config.response_mime_type == "application/json"
|
|
|
|
|
|
async def test_response_format_populates_value_on_chat_response() -> None:
    """With a Pydantic response_format, ChatResponse.value holds the model parsed from the response text."""
    from pydantic import BaseModel

    class Reply(BaseModel):
        text: str

    gemini, sdk = _make_gemini_client()
    json_reply = _make_response([_make_part(text='{"text": "hello"}')])
    sdk.aio.models.generate_content = AsyncMock(return_value=json_reply)
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    response = await gemini.get_response(messages=[prompt], options={"response_format": Reply})

    expected_value = Reply(text="hello")
    assert response.value == expected_value
|
|
|
|
|
|
async def test_response_schema_added_to_config() -> None:
    """A response_schema option sets the JSON mime type and is itself placed on the config."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="{}")]))
    raw_schema = {"type": "object", "properties": {"name": {"type": "string"}}}
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    await gemini.get_response(messages=[prompt], options={"response_schema": raw_schema})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    assert sent_config.response_mime_type == "application/json"
    assert sent_config.response_schema == raw_schema
|
|
|
|
|
|
async def test_streaming_response_format_passed_to_build_response_stream() -> None:
    """Verifies that response_format is forwarded to _build_response_stream when streaming.

    Forwarding it is what allows structured output parsing to work on the final assembled response.
    """
    # `patch` is used from module scope here, as in the other tests in this file;
    # only pydantic needs a function-local import.
    from pydantic import BaseModel

    class Reply(BaseModel):
        text: str

    client, mock = _make_gemini_client()
    chunks = [_make_response([_make_part(text='{"text": "hello"}')], finish_reason="STOP")]
    mock.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter(chunks))

    # wraps= keeps the real _build_response_stream running while recording how it was called.
    with patch.object(client, "_build_response_stream", wraps=client._build_response_stream) as spy:
        stream = client.get_response(
            messages=[Message(role="user", contents=[Content.from_text("Hi")])],
            options={"response_format": Reply},
            stream=True,
        )
        # Drain the stream so the whole streaming path has executed.
        async for _ in stream:
            pass

    # Fail with a clear assertion rather than a TypeError from unpacking call_args when it is None.
    spy.assert_called()
    _, kwargs = spy.call_args
    assert kwargs.get("response_format") is Reply
|
|
|
|
|
|
# tool calling
|
|
|
|
|
|
async def test_function_call_in_response_mapped_to_content() -> None:
    """A function_call part in the response becomes a function_call Content with matching name and call ID."""
    gemini, sdk = _make_gemini_client()
    call_part = _make_part(function_call=("call-1", "get_weather", {"city": "Berlin"}))
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([call_part]))
    prompt = Message(role="user", contents=[Content.from_text("Weather?")])

    response = await gemini.get_response(messages=[prompt])

    mapped_call = response.messages[0].contents[0]
    assert mapped_call.type == "function_call"
    assert mapped_call.name == "get_weather"
    assert mapped_call.call_id == "call-1"
|
|
|
|
|
|
async def test_function_call_missing_id_gets_fallback() -> None:
    """A function_call part that arrives without an id still yields a non-empty call_id."""
    gemini, sdk = _make_gemini_client()
    idless_part = _make_part(function_call=(None, "search", {"q": "test"}))  # id is None
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([idless_part]))
    prompt = Message(role="user", contents=[Content.from_text("Search")])

    response = await gemini.get_response(messages=[prompt])

    mapped_call = response.messages[0].contents[0]
    assert mapped_call.call_id is not None
    assert len(mapped_call.call_id) > 0
|
|
|
|
|
|
async def test_function_tool_converted_to_function_declaration() -> None:
    """A FunctionTool passed in tools appears as a FunctionDeclaration in the generation config."""

    def get_weather(city: str) -> str:
        """Get the weather for a city."""
        return "sunny"

    weather_tool = FunctionTool(name="get_weather", func=get_weather)
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")]))
    prompt = Message(role="user", contents=[Content.from_text("Weather?")])

    await gemini.get_response(messages=[prompt], options={"tools": [weather_tool]})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    sent_tools = sent_config.tools
    assert sent_tools is not None
    assert len(sent_tools) == 1
    first_declaration = sent_tools[0].function_declarations[0]
    assert first_declaration.name == "get_weather"
|
|
|
|
|
|
async def test_callable_tool_resolved_via_validate_options() -> None:
    """A plain callable passed in tools must reach the Gemini config as a function declaration.

    The callable is expected to be normalized into a FunctionTool by _validate_options on the way.
    """

    def get_weather(city: str) -> str:
        """Get the weather for a city."""
        return "sunny"

    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Done")]))
    prompt = Message(role="user", contents=[Content.from_text("Weather?")])

    await gemini.get_response(messages=[prompt], options={"tools": [get_weather]})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    sent_tools = sent_config.tools
    assert sent_tools is not None
    first_declaration = sent_tools[0].function_declarations[0]
    assert first_declaration.name == "get_weather"
|
|
|
|
|
|
# _coerce_to_dict
|
|
|
|
|
|
def test_coerce_to_dict_with_dict_input() -> None:
    """A dict input comes back equal to itself."""
    coerced = GeminiChatClient._coerce_to_dict({"key": "value"})
    assert coerced == {"key": "value"}
|
|
|
|
|
|
def test_coerce_to_dict_with_json_string() -> None:
    """A string holding a JSON object is parsed into the corresponding dict."""
    coerced = GeminiChatClient._coerce_to_dict('{"key": "value"}')
    assert coerced == {"key": "value"}
|
|
|
|
|
|
def test_coerce_to_dict_with_plain_string() -> None:
    """A non-JSON string is wrapped under the 'result' key."""
    coerced = GeminiChatClient._coerce_to_dict("some text")
    assert coerced == {"result": "some text"}
|
|
|
|
|
|
def test_coerce_to_dict_with_none() -> None:
    """None is coerced to a dict whose 'result' is the empty string."""
    coerced = GeminiChatClient._coerce_to_dict(None)
    assert coerced == {"result": ""}
|
|
|
|
|
|
def test_coerce_to_dict_with_numeric_value() -> None:
    """A number is wrapped under 'result' in its string form."""
    coerced = GeminiChatClient._coerce_to_dict(42)
    assert coerced == {"result": "42"}
|
|
|
|
|
|
def test_coerce_to_dict_with_json_array_string() -> None:
    """A JSON array string does not decode to a dict, so the original string is wrapped under 'result'."""
    coerced = GeminiChatClient._coerce_to_dict("[1, 2, 3]")
    assert coerced == {"result": "[1, 2, 3]"}
|
|
|
|
|
|
def test_coerce_to_dict_with_json_string_literal() -> None:
    """A JSON string literal does not decode to a dict, so the original string is wrapped under 'result'."""
    coerced = GeminiChatClient._coerce_to_dict('"hello"')
    assert coerced == {"result": '"hello"'}
|
|
|
|
|
|
# tool choice
|
|
|
|
|
|
def _get_function_calling_mode(config: types.GenerateContentConfig) -> str:
    """Return the function-calling mode stored under config.tool_config.function_calling_config."""
    calling_config = config.tool_config.function_calling_config
    return calling_config.mode
|
|
|
|
|
|
def _make_dummy_tool() -> FunctionTool:
    """Build a FunctionTool named "dummy" around a function that returns its integer argument."""

    def dummy(x: int) -> int:
        """Dummy."""
        return x

    dummy_tool = FunctionTool(name="dummy", func=dummy)
    return dummy_tool
|
|
|
|
|
|
async def _get_config_for_tool_choice(tool_choice: str) -> types.GenerateContentConfig:
    """Send one request with the dummy tool and the given tool_choice; return the config passed to generate_content."""
    dummy_tool = _make_dummy_tool()
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])
    request_options = {"tools": [dummy_tool], "tool_choice": tool_choice}

    await gemini.get_response(messages=[prompt], options=request_options)

    recorded_call = sdk.aio.models.generate_content.call_args
    return recorded_call.kwargs["config"]
|
|
|
|
|
|
async def test_tool_choice_auto_maps_to_AUTO() -> None:
    """tool_choice 'auto' results in function-calling mode AUTO."""
    sent_config = await _get_config_for_tool_choice("auto")
    calling_mode = _get_function_calling_mode(sent_config)
    assert calling_mode == "AUTO"
|
|
|
|
|
|
async def test_tool_choice_none_maps_to_NONE() -> None:
    """tool_choice 'none' results in function-calling mode NONE."""
    sent_config = await _get_config_for_tool_choice("none")
    calling_mode = _get_function_calling_mode(sent_config)
    assert calling_mode == "NONE"
|
|
|
|
|
|
async def test_tool_choice_required_maps_to_ANY() -> None:
    """tool_choice 'required' results in function-calling mode ANY."""
    sent_config = await _get_config_for_tool_choice("required")
    calling_mode = _get_function_calling_mode(sent_config)
    assert calling_mode == "ANY"
|
|
|
|
|
|
async def test_tool_choice_required_with_name_sets_allowed_function_names() -> None:
    """required + required_function_name gives mode ANY with that name among allowed_function_names."""
    dummy_tool = _make_dummy_tool()
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])
    named_choice = {"mode": "required", "required_function_name": "dummy"}

    await gemini.get_response(messages=[prompt], options={"tools": [dummy_tool], "tool_choice": named_choice})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    calling_config = sent_config.tool_config.function_calling_config
    assert calling_config.mode == "ANY"
    assert "dummy" in calling_config.allowed_function_names
|
|
|
|
|
|
async def test_unknown_tool_choice_mode_is_ignored() -> None:
    """An unrecognised tool_choice mode leaves the generation config without a tool_config."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    # Make validation hand back a mode the client has no mapping for.
    validator_target = "agent_framework_gemini._chat_client.validate_tool_mode"
    with patch(validator_target, return_value={"mode": "unsupported"}):
        await gemini.get_response(messages=[prompt], options={"tool_choice": "auto"})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    # A missing attribute and an attribute set to None both count as "no tool_config".
    assert getattr(sent_config, "tool_config", None) is None
|
|
|
|
|
|
async def test_tool_choice_auto_with_allowed_tools_uses_VALIDATED() -> None:
    """auto + allowed_tools gives mode VALIDATED with the same names as allowed_function_names."""
    dummy_tool = _make_dummy_tool()
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])
    restricted_choice = {"mode": "auto", "allowed_tools": ["dummy", "other"]}

    await gemini.get_response(messages=[prompt], options={"tools": [dummy_tool], "tool_choice": restricted_choice})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    calling_config = sent_config.tool_config.function_calling_config
    assert calling_config.mode == "VALIDATED"
    assert calling_config.allowed_function_names == ["dummy", "other"]
|
|
|
|
|
|
async def test_tool_choice_auto_with_empty_allowed_tools_uses_VALIDATED() -> None:
    """auto + an empty allowed_tools list gives mode VALIDATED with empty allowed_function_names."""
    dummy_tool = _make_dummy_tool()
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])
    empty_choice = {"mode": "auto", "allowed_tools": []}

    await gemini.get_response(messages=[prompt], options={"tools": [dummy_tool], "tool_choice": empty_choice})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    calling_config = sent_config.tool_config.function_calling_config
    assert calling_config.mode == "VALIDATED"
    assert calling_config.allowed_function_names == []
|
|
|
|
|
|
async def test_tool_choice_required_with_allowed_tools_uses_ANY() -> None:
    """required + allowed_tools gives mode ANY with the same names as allowed_function_names."""
    dummy_tool = _make_dummy_tool()
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])
    restricted_choice = {"mode": "required", "allowed_tools": ["dummy"]}

    await gemini.get_response(messages=[prompt], options={"tools": [dummy_tool], "tool_choice": restricted_choice})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    calling_config = sent_config.tool_config.function_calling_config
    assert calling_config.mode == "ANY"
    assert calling_config.allowed_function_names == ["dummy"]
|
|
|
|
|
|
async def test_tool_choice_required_function_name_takes_precedence_over_allowed_tools() -> None:
    """If required_function_name and allowed_tools are both given, only required_function_name is used."""
    dummy_tool = _make_dummy_tool()
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])
    conflicting_choice = {"mode": "required", "required_function_name": "dummy", "allowed_tools": ["other"]}

    await gemini.get_response(messages=[prompt], options={"tools": [dummy_tool], "tool_choice": conflicting_choice})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    calling_config = sent_config.tool_config.function_calling_config
    assert calling_config.mode == "ANY"
    # "other" from allowed_tools must not appear.
    assert calling_config.allowed_function_names == ["dummy"]
|
|
|
|
|
|
# built-in tool factories
|
|
|
|
|
|
def test_get_web_search_tool_returns_google_search_tool() -> None:
    """get_web_search_tool yields a types.Tool whose google_search field is populated."""
    search_tool = GeminiChatClient.get_web_search_tool()
    assert isinstance(search_tool, types.Tool)
    assert search_tool.google_search is not None
|
|
|
|
|
|
def test_get_web_search_tool_with_params() -> None:
    """Arguments given to get_web_search_tool are readable back from the tool's google_search field."""
    year_start = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
    year_end = datetime.datetime(2024, 12, 31, tzinfo=datetime.timezone.utc)
    time_range = types.Interval(start_time=year_start, end_time=year_end)

    search_tool = GeminiChatClient.get_web_search_tool(
        search_types=types.SearchTypes(web_search=types.WebSearch()),
        blocking_confidence=types.PhishBlockThreshold.BLOCK_LOW_AND_ABOVE,
        exclude_domains=["example.com"],
        time_range_filter=time_range,
    )

    google_search = search_tool.google_search
    assert google_search is not None
    assert google_search.search_types is not None
    assert google_search.blocking_confidence == types.PhishBlockThreshold.BLOCK_LOW_AND_ABOVE
    assert google_search.exclude_domains == ["example.com"]
    assert google_search.time_range_filter == time_range
|
|
|
|
|
|
def test_get_code_interpreter_tool_returns_code_execution_tool() -> None:
    """get_code_interpreter_tool yields a types.Tool whose code_execution field is populated."""
    interpreter_tool = GeminiChatClient.get_code_interpreter_tool()
    assert isinstance(interpreter_tool, types.Tool)
    assert interpreter_tool.code_execution is not None
|
|
|
|
|
|
def test_get_maps_grounding_tool_returns_google_maps_tool() -> None:
    """get_maps_grounding_tool yields a types.Tool whose google_maps field is populated."""
    maps_tool = GeminiChatClient.get_maps_grounding_tool()
    assert isinstance(maps_tool, types.Tool)
    assert maps_tool.google_maps is not None
|
|
|
|
|
|
def test_get_maps_grounding_tool_with_params() -> None:
    """Arguments given to get_maps_grounding_tool are readable back from the tool's google_maps field."""
    auth = types.AuthConfig(api_key="test-key")

    maps_tool = GeminiChatClient.get_maps_grounding_tool(enable_widget=True, auth_config=auth)

    google_maps = maps_tool.google_maps
    assert google_maps is not None
    assert google_maps.enable_widget is True
    assert google_maps.auth_config == auth
|
|
|
|
|
|
def test_get_file_search_tool_returns_file_search_tool() -> None:
    """get_file_search_tool yields a types.Tool whose file_search field carries the given store names."""
    store_names = ["stores/my-store"]

    search_tool = GeminiChatClient.get_file_search_tool(file_search_store_names=store_names)

    assert isinstance(search_tool, types.Tool)
    assert search_tool.file_search is not None
    assert search_tool.file_search.file_search_store_names == ["stores/my-store"]
|
|
|
|
|
|
def test_get_file_search_tool_with_params() -> None:
    """top_k and metadata_filter given to get_file_search_tool are readable back from file_search."""
    search_tool = GeminiChatClient.get_file_search_tool(
        file_search_store_names=["stores/my-store"], top_k=5, metadata_filter="type='pdf'"
    )

    file_search = search_tool.file_search
    assert file_search is not None
    assert file_search.top_k == 5
    assert file_search.metadata_filter == "type='pdf'"
|
|
|
|
|
|
def test_get_mcp_tool_returns_mcp_server_tool() -> None:
    """get_mcp_tool yields a types.Tool holding exactly one McpServer with the given name and URL."""
    mcp_tool = GeminiChatClient.get_mcp_tool(name="my-mcp", url="https://mcp.example.com/sse")

    assert isinstance(mcp_tool, types.Tool)
    servers = mcp_tool.mcp_servers
    assert servers is not None
    assert len(servers) == 1

    only_server = servers[0]
    assert only_server.name == "my-mcp"
    transport = only_server.streamable_http_transport
    assert transport is not None
    assert transport.url == "https://mcp.example.com/sse"
|
|
|
|
|
|
def test_get_mcp_tool_forwards_transport_kwargs() -> None:
    """Extra keyword arguments to get_mcp_tool show up on the server's streamable_http_transport."""
    auth_headers = {"Authorization": "Bearer token"}

    mcp_tool = GeminiChatClient.get_mcp_tool(
        name="secure-mcp", url="https://mcp.example.com/sse", headers=auth_headers
    )

    first_server = mcp_tool.mcp_servers[0]  # type: ignore[index]
    assert first_server.streamable_http_transport.headers == {"Authorization": "Bearer token"}
|
|
|
|
|
|
async def test_types_tool_passed_in_tools_list_is_forwarded() -> None:
    """A types.Tool supplied in the tools option is present in the config passed to generate_content."""
    gemini, sdk = _make_gemini_client()
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Result")]))
    web_search = GeminiChatClient.get_web_search_tool()
    prompt = Message(role="user", contents=[Content.from_text("Search")])

    await gemini.get_response(messages=[prompt], options={"tools": [web_search]})

    sent_config: types.GenerateContentConfig = sdk.aio.models.generate_content.call_args.kwargs["config"]
    sent_tools = sent_config.tools
    assert sent_tools is not None
    assert any(entry.google_search for entry in sent_tools)
|
|
|
|
|
|
async def test_function_response_part_in_response_mapped_to_content() -> None:
    """A function_response part appearing in a model response becomes a function_result Content."""
    gemini, sdk = _make_gemini_client()
    # Only function_response is populated; text, thought and function_call are explicitly unset.
    echoed_part = MagicMock(
        text=None,
        thought=False,
        function_call=None,
        function_response=MagicMock(id="call-99", response={"result": "done"}),
    )
    sdk.aio.models.generate_content = AsyncMock(return_value=_make_response([echoed_part]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    response = await gemini.get_response(messages=[prompt])

    first_content = response.messages[0].contents[0]
    assert first_content.type == "function_result"
|
|
|
|
|
|
# streaming
|
|
|
|
|
|
async def test_streaming_yields_text_chunks() -> None:
    """The text of the streamed updates, concatenated, contains every chunk's text."""
    gemini, sdk = _make_gemini_client()
    opening_chunk = _make_response(
        [_make_part(text="Hello ")], finish_reason=None, prompt_tokens=None, output_tokens=None
    )
    closing_chunk = _make_response([_make_part(text="world!")], finish_reason="STOP")
    sdk.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter([opening_chunk, closing_chunk]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    stream = gemini.get_response(messages=[prompt], stream=True)

    pieces = []
    async for update in stream:
        # An update without text contributes an empty string.
        pieces.append(update.text or "")
    combined = "".join(pieces)

    assert "Hello" in combined
    assert "world" in combined
|
|
|
|
|
|
async def test_streaming_function_call_emitted_immediately() -> None:
    """A function call arriving in a streamed chunk is present in the emitted update contents."""
    gemini, sdk = _make_gemini_client()
    call_chunk = _make_response(
        [_make_part(function_call=("call-1", "search", {"q": "test"}))],
        finish_reason=None,
        prompt_tokens=None,
        output_tokens=None,
    )
    text_chunk = _make_response([_make_part(text="Done")], finish_reason="STOP")
    sdk.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter([call_chunk, text_chunk]))
    prompt = Message(role="user", contents=[Content.from_text("Search")])

    stream = gemini.get_response(messages=[prompt], stream=True)

    # Flatten the contents of every update and keep only the function calls.
    emitted_calls = [
        content async for update in stream for content in update.contents if content.type == "function_call"
    ]

    assert len(emitted_calls) == 1
    assert emitted_calls[0].name == "search"
|
|
|
|
|
|
async def test_streaming_finish_reason_only_on_last_chunk() -> None:
    """The first update has no finish_reason; the last update reports "stop"."""
    gemini, sdk = _make_gemini_client()
    opening_chunk = _make_response(
        [_make_part(text="Hello ")], finish_reason=None, prompt_tokens=None, output_tokens=None
    )
    closing_chunk = _make_response([_make_part(text="world!")], finish_reason="STOP")
    sdk.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter([opening_chunk, closing_chunk]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    stream = gemini.get_response(messages=[prompt], stream=True)
    received = [update async for update in stream]

    first_update = received[0]
    last_update = received[-1]
    assert first_update.finish_reason is None
    assert last_update.finish_reason == "stop"
|
|
|
|
|
|
async def test_streaming_usage_only_on_final_chunk() -> None:
    """The first update carries no usage content, while the last update carries some."""
    gemini, sdk = _make_gemini_client()
    opening_chunk = _make_response(
        [_make_part(text="Hello ")], finish_reason=None, prompt_tokens=None, output_tokens=None
    )
    closing_chunk = _make_response(
        [_make_part(text="world!")], finish_reason="STOP", prompt_tokens=10, output_tokens=5
    )
    sdk.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter([opening_chunk, closing_chunk]))
    prompt = Message(role="user", contents=[Content.from_text("Hi")])

    stream = gemini.get_response(messages=[prompt], stream=True)
    received = [update async for update in stream]

    # "no usage content" expressed as: every content has some other type.
    assert all(content.type != "usage" for content in received[0].contents)
    assert any(content.type == "usage" for content in received[-1].contents)
|
|
|
|
|
|
async def test_streaming_get_final_response() -> None:
    """Once the stream is drained, get_final_response() returns the fully assembled ChatResponse."""
    gemini, sdk_mock = _make_gemini_client()
    text_only_chunk = _make_response(
        [_make_part(text="Hello ")], finish_reason=None, prompt_tokens=None, output_tokens=None
    )
    terminal_chunk = _make_response(
        [_make_part(text="world!")], finish_reason="STOP", prompt_tokens=10, output_tokens=5
    )
    sdk_mock.aio.models.generate_content_stream = AsyncMock(
        return_value=_async_iter([text_only_chunk, terminal_chunk])
    )

    response_stream = gemini.get_response(
        messages=[Message(role="user", contents=[Content.from_text("Hi")])],
        stream=True,
    )

    # Exhaust the stream first; the individual updates are not inspected here.
    async for _update in response_stream:
        continue

    final = await response_stream.get_final_response()

    # Text of both chunks is concatenated; finish reason and usage match the terminal chunk.
    assert final.messages[0].text == "Hello world!"
    assert final.finish_reason == "stop"
    usage = final.usage_details
    assert usage is not None
    assert usage["input_token_count"] == 10
    assert usage["output_token_count"] == 5
|
|
|
|
|
|
# The Gemini API returns a list of candidates, each representing a possible response from the model.
|
|
# In practice only one candidate is returned, but the list can be empty or None if the request
|
|
# was blocked by safety filters or the API returned an unexpected response.
|
|
|
|
|
|
@pytest.mark.parametrize("candidates", [None, []])
async def test_empty_candidates_returns_empty_message(candidates: list | None) -> None:
    """A response without candidates must not raise; it surfaces as an empty assistant message."""
    gemini, sdk_mock = _make_gemini_client()
    api_response = _make_response([])
    # Override whatever the helper built with the parametrized "no candidates" value.
    api_response.candidates = candidates
    sdk_mock.aio.models.generate_content = AsyncMock(return_value=api_response)

    prompt = Message(role="user", contents=[Content.from_text("Hi")])
    result = await gemini.get_response(messages=[prompt])

    reply = result.messages[0]
    assert reply.role == "assistant"
    assert reply.contents == []
    assert result.finish_reason is None
|
|
|
|
|
|
@pytest.mark.parametrize("candidates", [None, []])
async def test_empty_candidates_in_stream_does_not_raise(candidates: list | None) -> None:
    """A streamed chunk without candidates must not raise; it produces a single empty update."""
    gemini, sdk_mock = _make_gemini_client()
    empty_chunk = _make_response([], finish_reason=None, prompt_tokens=None, output_tokens=None)
    # Override whatever the helper built with the parametrized "no candidates" value.
    empty_chunk.candidates = candidates
    sdk_mock.aio.models.generate_content_stream = AsyncMock(return_value=_async_iter([empty_chunk]))

    response_stream = gemini.get_response(
        messages=[Message(role="user", contents=[Content.from_text("Hi")])],
        stream=True,
    )
    received = []
    async for item in response_stream:
        received.append(item)

    assert len(received) == 1
    only_update = received[0]
    assert only_update.contents == []
    assert only_update.finish_reason is None
|
|
|
|
|
|
# service_url
|
|
|
|
|
|
def test_service_url() -> None:
    """service_url() reports the Gemini API base URL."""
    gemini, _sdk_mock = _make_gemini_client()
    reported_url = gemini.service_url()
    assert reported_url == "https://generativelanguage.googleapis.com"
|
|
|
|
|
|
def test_service_url_falls_back_when_sdk_base_url_is_unavailable() -> None:
    """Uses the known service URL when the SDK client does not expose a base URL.

    Covers both flavours: a mocked SDK client with ``vertexai`` off and one with it on.
    """

    def _client_with_vertexai(flag: bool) -> GeminiChatClient:
        # Mocked SDK client on which only the vertexai switch is set explicitly.
        sdk_client = MagicMock()
        sdk_client._api_client.vertexai = flag
        return GeminiChatClient(client=sdk_client, model="gemini-2.5-flash")

    gemini_client = _client_with_vertexai(False)
    vertex_client = _client_with_vertexai(True)

    assert gemini_client.service_url() == "https://generativelanguage.googleapis.com"
    assert vertex_client.service_url() == "https://aiplatform.googleapis.com"
|
|
|
|
|
|
# integration tests
|
|
|
|
|
|
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_credentials
async def test_integration_basic_chat() -> None:
    """A simple request/response round-trip comes back with a non-empty text reply."""
    gemini = GeminiChatClient(model=_TEST_MODEL)
    prompt = Message(role="user", contents=[Content.from_text("Reply with the single word: hello")])

    reply = await gemini.get_response(messages=[prompt])

    assert reply.messages
    assert reply.messages[0].text
|
|
|
|
|
|
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_credentials
async def test_integration_streaming() -> None:
    """Streaming yields at least one update and the concatenated text is non-empty."""
    gemini = GeminiChatClient(model=_TEST_MODEL)
    response_stream = gemini.get_response(
        messages=[Message(role="user", contents=[Content.from_text("Count from 1 to 5.")])],
        stream=True,
    )

    received = []
    async for item in response_stream:
        received.append(item)
    assert len(received) > 0

    # Updates without text contribute an empty string to the joined result.
    pieces = []
    for item in received:
        pieces.append(item.text or "")
    combined_text = "".join(pieces)
    assert combined_text
|
|
|
|
|
|
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_credentials
async def test_integration_structured_output() -> None:
    """Passing a Pydantic model as response_format exposes a parsed instance via response.value."""

    class Answer(BaseModel):
        answer: str

    gemini = GeminiChatClient(model=_TEST_MODEL)
    question = Message(role="user", contents=[Content.from_text("What is the capital of Germany?")])

    response = await gemini.get_response(
        messages=[question],
        options={"response_format": Answer},
    )

    assert response.value is not None
    assert isinstance(response.value, Answer)
    assert response.value.answer
|
|
|
|
|
|
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_credentials
async def test_integration_tool_calling() -> None:
    """With tool_choice "required", the reply contains a call to the registered tool."""

    def get_temperature(city: str) -> str:
        """Return the current temperature for a city."""
        return f"22°C in {city}"

    temperature_tool = FunctionTool(name="get_temperature", func=get_temperature)
    gemini = GeminiChatClient(model=_TEST_MODEL)
    question = Message(role="user", contents=[Content.from_text("What is the temperature in Berlin?")])

    response = await gemini.get_response(
        messages=[question],
        options={"tools": [temperature_tool], "tool_choice": "required"},
    )

    # Keep only the function-call items of the first reply message.
    requested_calls = []
    for item in response.messages[0].contents:
        if item.type == "function_call":
            requested_calls.append(item)

    assert len(requested_calls) >= 1
    assert requested_calls[0].name == "get_temperature"
|
|
|
|
|
|
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_credentials
async def test_integration_thinking_config() -> None:
    """A request carrying a thinking budget is accepted and answered with non-empty text."""
    chat_options: GeminiChatOptions = {"thinking_config": ThinkingConfig(thinking_budget=512)}
    gemini = GeminiChatClient(model=_TEST_MODEL)
    question = Message(role="user", contents=[Content.from_text("What is 17 * 34?")])

    reply = await gemini.get_response(messages=[question], options=chat_options)

    assert reply.messages
    assert reply.messages[0].text
|
|
|
|
|
|
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_credentials
async def test_integration_google_search_grounding() -> None:
    """With the web search tool supplied, a current-events question gets a non-empty text reply."""
    gemini = GeminiChatClient(model=_TEST_MODEL)
    question = Message(
        role="user",
        contents=[Content.from_text("What is the latest stable version of Python?")],
    )

    reply = await gemini.get_response(
        messages=[question],
        options={"tools": [GeminiChatClient.get_web_search_tool()]},
    )

    assert reply.messages
    assert reply.messages[0].text
|
|
|
|
|
|
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_credentials
async def test_integration_google_maps_grounding() -> None:
    """With the maps grounding tool supplied, a location-based question gets a non-empty text reply."""
    gemini = GeminiChatClient(model=_TEST_MODEL)
    question = Message(
        role="user",
        contents=[Content.from_text("What are some highly rated restaurants in Karlsruhe city center?")],
    )

    reply = await gemini.get_response(
        messages=[question],
        options={"tools": [GeminiChatClient.get_maps_grounding_tool()]},
    )

    assert reply.messages
    assert reply.messages[0].text
|
|
|
|
|
|
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_credentials
async def test_integration_code_execution() -> None:
    """With the code interpreter tool supplied, a computation request gets a non-empty text reply."""
    gemini = GeminiChatClient(model=_TEST_MODEL)
    request = Message(
        role="user",
        contents=[Content.from_text("Compute the sum of the first 100 natural numbers using code.")],
    )

    reply = await gemini.get_response(
        messages=[request],
        options={"tools": [GeminiChatClient.get_code_interpreter_tool()]},
    )

    assert reply.messages
    assert reply.messages[0].text
|