Merge branch 'main' into feature/python-foundry-hosted-agent-vnext

This commit is contained in:
Tao Chen
2026-04-16 13:55:04 -07:00
Unverified
107 changed files with 8359 additions and 144 deletions
@@ -374,6 +374,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
contents=contents,
role="assistant" if item.role == A2ARole.agent else "user",
response_id=str(getattr(item, "message_id", uuid.uuid4())),
additional_properties={"a2a_metadata": item.metadata} if item.metadata else None,
raw_representation=item,
)
all_updates.append(update)
@@ -452,13 +453,24 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
role=message.role,
response_id=task.id,
message_id=getattr(message.raw_representation, "artifact_id", None),
additional_properties={"a2a_metadata": merged}
if (merged := {**message.additional_properties, **(task.metadata or {})})
else None,
raw_representation=task,
)
for message in task_messages
]
if task.artifacts is not None:
return []
return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)]
return [
AgentResponseUpdate(
contents=[],
role="assistant",
response_id=task.id,
additional_properties={"a2a_metadata": task.metadata} if task.metadata else None,
raw_representation=task,
)
]
if background and status.state in IN_PROGRESS_TASK_STATES:
token = self._build_continuation_token(task)
@@ -468,6 +480,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
role="assistant",
response_id=task.id,
continuation_token=token,
additional_properties={"a2a_metadata": task.metadata} if task.metadata else None,
raw_representation=task,
)
]
@@ -488,6 +501,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
contents=contents,
role="assistant" if status.message.role == A2ARole.agent else "user",
response_id=task.id,
additional_properties={"a2a_metadata": task.metadata} if task.metadata else None,
raw_representation=task,
)
]
@@ -502,12 +516,17 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
contents = self._parse_contents_from_a2a(update_event.artifact.parts)
if not contents:
return []
merged_metadata = {
**(update_event.artifact.metadata or {}),
**(update_event.metadata or {}),
} or None
return [
AgentResponseUpdate(
contents=contents,
role="assistant",
response_id=update_event.task_id,
message_id=update_event.artifact.artifact_id,
additional_properties={"a2a_metadata": merged_metadata} if merged_metadata else None,
raw_representation=update_event,
)
]
@@ -523,11 +542,16 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
if not contents:
return []
merged_metadata = {
**(message.metadata or {}),
**(update_event.metadata or {}),
} or None
return [
AgentResponseUpdate(
contents=contents,
role="assistant" if message.role == A2ARole.agent else "user",
response_id=update_event.task_id,
additional_properties={"a2a_metadata": merged_metadata} if merged_metadata else None,
raw_representation=update_event,
)
]
@@ -642,9 +666,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
case _:
raise ValueError(f"Unknown content type: {content.type}")
# Exclude framework-internal keys (e.g. attribution) from wire metadata
internal_keys = {"_attribution", "context_id"}
metadata = {k: v for k, v in message.additional_properties.items() if k not in internal_keys} or None
metadata = message.additional_properties.get("a2a_metadata")
return A2AMessage(
role=A2ARole("user"),
@@ -718,6 +740,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
Message(
role="assistant" if history_item.role == A2ARole.agent else "user",
contents=contents,
additional_properties=history_item.metadata,
raw_representation=history_item,
)
)
@@ -730,5 +753,6 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
return Message(
role="assistant",
contents=contents,
additional_properties=artifact.metadata,
raw_representation=artifact,
)
+208 -1
View File
@@ -530,7 +530,7 @@ def test_prepare_message_for_a2a_forwards_context_id() -> None:
message = Message(
role="user",
contents=[Content.from_text(text="Continue the task")],
additional_properties={"context_id": "ctx-123", "trace_id": "trace-456"},
additional_properties={"context_id": "ctx-123", "a2a_metadata": {"trace_id": "trace-456"}},
)
result = agent._prepare_message_for_a2a(message)
@@ -1385,3 +1385,210 @@ async def test_streaming_terminal_task_only_emits_unstreamed_artifacts(
# endregion
# region Metadata propagation tests
async def test_message_metadata_propagated(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
"""A2AMessage.metadata should appear on response.additional_properties."""
msg = A2AMessage(
message_id="msg-meta",
role=A2ARole.agent,
parts=[Part(root=TextPart(text="hi"))],
metadata={"source": "server", "trace_id": "abc"},
)
mock_a2a_client.responses.append(msg)
response = await a2a_agent.run("hello")
assert response.additional_properties["a2a_metadata"]["source"] == "server"
assert response.additional_properties["a2a_metadata"]["trace_id"] == "abc"
async def test_artifact_metadata_propagated(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
"""Artifact.metadata should appear on response.additional_properties."""
task = Task(
id="task-art-meta",
context_id="ctx",
status=TaskStatus(state=TaskState.completed),
artifacts=[
Artifact(
artifact_id="a1",
parts=[Part(root=TextPart(text="result"))],
metadata={"artifact_key": "artifact_value"},
),
],
)
mock_a2a_client.responses.append((task, None))
response = await a2a_agent.run("go")
assert response.additional_properties["a2a_metadata"]["artifact_key"] == "artifact_value"
async def test_task_metadata_propagated_to_response(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
"""Task.metadata should appear on response.additional_properties for terminal tasks."""
task = Task(
id="task-meta",
context_id="ctx",
status=TaskStatus(state=TaskState.completed),
artifacts=[
Artifact(artifact_id="a1", parts=[Part(root=TextPart(text="done"))]),
],
metadata={"task_key": "task_value"},
)
mock_a2a_client.responses.append((task, None))
response = await a2a_agent.run("go")
assert response.additional_properties["a2a_metadata"]["task_key"] == "task_value"
async def test_task_artifact_update_event_metadata_merged(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
"""TaskArtifactUpdateEvent and Artifact metadata should both appear on the streaming update."""
artifact_event = TaskArtifactUpdateEvent(
task_id="task-ae",
context_id="ctx",
artifact=Artifact(
artifact_id="a1",
parts=[Part(root=TextPart(text="chunk"))],
metadata={"from_artifact": True},
),
metadata={"from_event": True},
)
working_task = Task(
id="task-ae",
context_id="ctx",
status=TaskStatus(state=TaskState.working),
)
terminal_task = Task(
id="task-ae",
context_id="ctx",
status=TaskStatus(state=TaskState.completed),
artifacts=[
Artifact(artifact_id="a1", parts=[Part(root=TextPart(text="chunk"))]),
],
)
terminal_event = TaskStatusUpdateEvent(
task_id="task-ae",
context_id="ctx",
status=TaskStatus(state=TaskState.completed),
final=True,
)
mock_a2a_client.responses.extend([
(working_task, artifact_event),
(terminal_task, terminal_event),
])
stream = a2a_agent.run("hello", stream=True)
updates: list[AgentResponseUpdate] = []
async for update in stream:
updates.append(update)
artifact_update = updates[0]
assert artifact_update.additional_properties["a2a_metadata"]["from_artifact"] is True
assert artifact_update.additional_properties["a2a_metadata"]["from_event"] is True
async def test_task_status_update_event_metadata_merged(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
"""TaskStatusUpdateEvent and its message metadata should both appear on the streaming update."""
status_event = TaskStatusUpdateEvent(
task_id="task-se",
context_id="ctx",
status=TaskStatus(
state=TaskState.working,
message=A2AMessage(
message_id="m1",
role=A2ARole.agent,
parts=[Part(root=TextPart(text="working..."))],
metadata={"msg_key": "msg_val"},
),
),
final=False,
metadata={"event_key": "event_val"},
)
working_task = Task(
id="task-se",
context_id="ctx",
status=TaskStatus(state=TaskState.working),
)
terminal_task = Task(
id="task-se",
context_id="ctx",
status=TaskStatus(state=TaskState.completed),
artifacts=[
Artifact(artifact_id="a1", parts=[Part(root=TextPart(text="done"))]),
],
)
terminal_event = TaskStatusUpdateEvent(
task_id="task-se",
context_id="ctx",
status=TaskStatus(state=TaskState.completed),
final=True,
)
mock_a2a_client.responses.extend([
(working_task, status_event),
(terminal_task, terminal_event),
])
stream = a2a_agent.run("hello", stream=True)
updates: list[AgentResponseUpdate] = []
async for update in stream:
updates.append(update)
status_update = updates[0]
assert status_update.additional_properties["a2a_metadata"]["msg_key"] == "msg_val"
assert status_update.additional_properties["a2a_metadata"]["event_key"] == "event_val"
async def test_history_message_metadata_propagated(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
"""Metadata on a history Message should appear on response.additional_properties."""
task = Task(
id="task-hist",
context_id="ctx",
status=TaskStatus(state=TaskState.completed),
history=[
A2AMessage(
message_id="h1",
role=A2ARole.agent,
parts=[Part(root=TextPart(text="reply"))],
metadata={"history_key": "history_value"},
),
],
)
mock_a2a_client.responses.append((task, None))
response = await a2a_agent.run("go")
assert response.additional_properties["a2a_metadata"]["history_key"] == "history_value"
async def test_continuation_token_update_carries_task_metadata(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""In-progress tasks with background=True should propagate task metadata."""
task = Task(
id="task-cont",
context_id="ctx",
status=TaskStatus(state=TaskState.working),
metadata={"bg_key": "bg_value"},
)
mock_a2a_client.responses.append((task, None))
response = await a2a_agent.run("go", background=True)
assert response.continuation_token is not None
assert response.additional_properties["a2a_metadata"]["bg_key"] == "bg_value"
async def test_none_metadata_leaves_additional_properties_empty(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""When A2A types have no metadata, additional_properties should remain empty/default."""
msg = A2AMessage(
message_id="msg-none",
role=A2ARole.agent,
parts=[Part(root=TextPart(text="no meta"))],
)
mock_a2a_client.responses.append(msg)
response = await a2a_agent.run("hello")
assert not response.additional_properties
# endregion
@@ -906,6 +906,9 @@ def _tools_to_dict( # pyright: ignore[reportUnusedFunction]
if isinstance(tool_item, FunctionTool):
results.append(tool_item.to_json_schema_spec())
continue
if isinstance(tool_item, BaseModel):
results.append(tool_item.model_dump(exclude_none=True))
continue
if isinstance(tool_item, SerializationMixin):
results.append(tool_item.to_dict())
continue
@@ -1879,6 +1879,12 @@ def _process_update(response: ChatResponse | AgentResponse, update: ChatResponse
response.finish_reason = update.finish_reason
if update.model is not None:
response.model = update.model
if (
isinstance(response, AgentResponse)
and isinstance(update, AgentResponseUpdate)
and update.finish_reason is not None
):
response.finish_reason = update.finish_reason
response.continuation_token = update.continuation_token
@@ -2435,6 +2441,7 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
response_id: str | None = None,
agent_id: str | None = None,
created_at: CreatedAtT | None = None,
finish_reason: FinishReasonLiteral | FinishReason | None = None,
usage_details: UsageDetails | None = None,
value: ResponseModelT | None = None,
response_format: StructuredResponseFormat = None,
@@ -2450,6 +2457,9 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
agent_id: The identifier of the agent that produced this response. Useful in multi-agent
scenarios to track which agent generated the response.
created_at: A timestamp for the chat response.
finish_reason: The reason the model stopped generating. Common values include
``"stop"`` (natural completion), ``"length"`` (token limit), and
``"tool_calls"`` (the model invoked a tool).
usage_details: The usage details for the chat response.
value: The structured output of the agent run response, if applicable.
response_format: Optional response format for the agent response.
@@ -2476,6 +2486,7 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
self.response_id = response_id
self.agent_id = agent_id
self.created_at = created_at
self.finish_reason = finish_reason
self.usage_details = usage_details
self._value: ResponseModelT | None = value
self._response_format: type[BaseModel] | Mapping[str, Any] | None = response_format
@@ -2688,6 +2699,7 @@ class AgentResponseUpdate(SerializationMixin):
response_id: str | None = None,
message_id: str | None = None,
created_at: CreatedAtT | None = None,
finish_reason: FinishReasonLiteral | FinishReason | None = None,
continuation_token: ContinuationToken | None = None,
additional_properties: dict[str, Any] | None = None,
raw_representation: Any | None = None,
@@ -2703,6 +2715,9 @@ class AgentResponseUpdate(SerializationMixin):
response_id: Optional ID of the response of which this update is a part.
message_id: Optional ID of the message of which this update is a part.
created_at: Optional timestamp for the chat response update.
finish_reason: The reason the model stopped generating. Common values include
``"stop"`` (natural completion), ``"length"`` (token limit), and
``"tool_calls"`` (the model invoked a tool).
continuation_token: Optional token for resuming a long-running background operation.
When present, indicates the operation is still in progress.
additional_properties: Optional additional properties associated with the chat response update.
@@ -2729,6 +2744,7 @@ class AgentResponseUpdate(SerializationMixin):
self.response_id = response_id
self.message_id = message_id
self.created_at = created_at
self.finish_reason = finish_reason
self.continuation_token = continuation_token
self.additional_properties = _restore_compaction_annotation_in_additional_properties(
additional_properties,
@@ -2761,6 +2777,7 @@ def map_chat_to_agent_update(update: ChatResponseUpdate, agent_name: str | None)
response_id=update.response_id,
message_id=update.message_id,
created_at=update.created_at,
finish_reason=update.finish_reason, # type: ignore[arg-type]
continuation_token=update.continuation_token,
additional_properties=update.additional_properties,
raw_representation=update,
@@ -59,6 +59,62 @@ class AgentExecutorResponse:
agent_response: AgentResponse
full_conversation: list[Message]
def with_text(self, text: str) -> "AgentExecutorResponse":
"""Create a new AgentExecutorResponse with replaced text, preserving the conversation history.
Use this in custom executors that transform agent output text (e.g. upper-casing, summarising)
when you need downstream AgentExecutors to still have access to the full prior conversation.
Without this helper, sending a plain ``str`` from a custom executor breaks the context chain:
the downstream ``AgentExecutor.from_str`` handler only adds that one string to its cache and
loses all prior messages. By using ``with_text`` the response type stays
``AgentExecutorResponse``, so ``AgentExecutor.from_response`` is invoked instead and the full
conversation is preserved.
Args:
text: The replacement assistant message text.
Returns:
A new ``AgentExecutorResponse`` whose ``agent_response`` contains a single assistant
message with ``text``, and whose ``full_conversation`` is the prior conversation
(everything before the original agent turn) followed by the new assistant message.
Example:
.. code-block:: python
from agent_framework import AgentExecutorResponse, WorkflowContext, executor
@executor(
id="upper_case_executor",
input=AgentExecutorResponse,
output=AgentExecutorResponse,
workflow_output=str,
)
async def upper_case(
response: AgentExecutorResponse,
ctx: WorkflowContext[AgentExecutorResponse, str],
) -> None:
upper_text = response.agent_response.text.upper()
await ctx.send_message(response.with_text(upper_text))
await ctx.yield_output(upper_text)
"""
new_message = Message("assistant", [text])
new_agent_response = AgentResponse(messages=[new_message])
# Strip off the original agent turn and replace with the new text.
n_agent_messages = len(self.agent_response.messages)
prior_messages = (
self.full_conversation[:-n_agent_messages] if n_agent_messages else list(self.full_conversation)
)
new_full_conversation = [*prior_messages, new_message]
return AgentExecutorResponse(
executor_id=self.executor_id,
agent_response=new_agent_response,
full_conversation=new_full_conversation,
)
class AgentExecutor(Executor):
"""built-in executor that wraps an agent for handling messages.
@@ -183,7 +239,25 @@ class AgentExecutor(Executor):
"""Accept a raw user prompt string and run the agent.
The new string input will be added to the cache which is used as the conversation context for the agent run.
Warning:
If the upstream executor received an ``AgentExecutorResponse`` but emits a plain
``str``, this handler will be invoked instead of ``from_response``. This resets
the conversation context because only the new string is added to the cache and
all prior messages from the upstream agent are lost.
To preserve the full conversation when transforming agent output in a custom
executor, use ``AgentExecutorResponse.with_text(...)`` so that the message type
stays ``AgentExecutorResponse`` and ``from_response`` is called instead.
"""
if not self._cache and ctx.source_executor_ids != ["Workflow"]:
logger.warning(
"AgentExecutor '%s': from_str handler invoked with an empty cache. "
"If you are chaining from an AgentExecutor, the upstream custom executor may be "
"emitting a plain str instead of using AgentExecutorResponse.with_text(...), "
"which causes the full conversation context to be lost.",
self.id,
)
self._cache.extend(normalize_messages_input(text))
await self._run_agent_and_emit(ctx)
@@ -268,6 +268,19 @@ def executor(
forward references. When provided, takes precedence over introspection from the
``WorkflowContext`` second generic parameter (W_OutT).
Warning:
When placing a custom ``@executor`` **between** two ``AgentExecutor`` nodes, be
careful about the output type. If the custom executor receives an
``AgentExecutorResponse`` but emits a plain ``str``, the downstream
``AgentExecutor.from_str`` handler is invoked instead of ``from_response``.
This resets the conversation context because only the new string is added to
the cache and all prior messages from the upstream agent are lost.
To preserve the full conversation, use
``AgentExecutorResponse.with_text(new_text)`` to create a new response that
keeps the prior history, and set ``output=AgentExecutorResponse`` on the
decorator.
Returns:
A FunctionExecutor instance that can be wired into a Workflow.
@@ -16,12 +16,26 @@ from agent_framework._middleware import FunctionInvocationContext
from agent_framework._tools import (
_parse_annotation,
_parse_inputs,
_tools_to_dict,
)
from agent_framework.observability import OtelAttr
# region FunctionTool and tool decorator tests
def test_tools_to_dict_supports_pydantic_tool_models() -> None:
"""Pydantic-based tool specs are serialized without logging parse warnings."""
class ProviderTool(BaseModel):
kind: str
enabled: bool = True
note: str | None = None
result = _tools_to_dict([ProviderTool(kind="google_search")])
assert result == [{"kind": "google_search", "enabled": True}]
def test_tool_decorator():
"""Test the tool decorator."""
@@ -40,8 +40,10 @@ from agent_framework._types import (
_get_data_bytes_as_str,
_parse_content_list,
_parse_structured_response_value,
_process_update,
_validate_uri,
add_usage_details,
map_chat_to_agent_update,
validate_tool_mode,
)
from agent_framework.exceptions import AdditionItemMismatch, ContentError
@@ -4179,3 +4181,101 @@ def test_prepend_instructions_custom_role():
# endregion
# region finish_reason
def test_agent_response_init_with_finish_reason() -> None:
"""Test that AgentResponse correctly initializes and stores finish_reason."""
response = AgentResponse(
messages=[Message("assistant", [Content.from_text("test")])],
finish_reason="stop",
)
assert response.finish_reason == "stop"
def test_agent_response_update_init_with_finish_reason() -> None:
"""Test that AgentResponseUpdate correctly initializes and stores finish_reason."""
update = AgentResponseUpdate(
contents=[Content.from_text("test")],
role="assistant",
finish_reason="stop",
)
assert update.finish_reason == "stop"
def test_map_chat_to_agent_update_forwards_finish_reason() -> None:
"""Test that mapping a ChatResponseUpdate with finish_reason forwards it."""
chat_update = ChatResponseUpdate(
contents=[Content.from_text("test")],
finish_reason="length",
)
agent_update = map_chat_to_agent_update(chat_update, agent_name="test_agent")
assert agent_update.finish_reason == "length"
assert agent_update.author_name == "test_agent"
def test_process_update_propagates_finish_reason_to_agent_response() -> None:
"""Test that _process_update correctly updates an AgentResponse from an AgentResponseUpdate."""
response = AgentResponse(messages=[Message("assistant", [Content.from_text("test")])])
update = AgentResponseUpdate(
contents=[Content.from_text("more text")],
role="assistant",
finish_reason="stop",
)
# Process the update
_process_update(response, update)
assert response.finish_reason == "stop"
def test_process_update_does_not_overwrite_with_none() -> None:
"""Test that _process_update does not overwrite an existing finish_reason with None."""
response = AgentResponse(
messages=[Message("assistant", [Content.from_text("test")])],
finish_reason="length",
)
update = AgentResponseUpdate(
contents=[Content.from_text("more text")],
role="assistant",
finish_reason=None,
)
# Process the update
_process_update(response, update)
assert response.finish_reason == "length"
def test_agent_response_serialization_includes_finish_reason() -> None:
"""Test that AgentResponse serializes correctly, including finish_reason."""
response = AgentResponse(
messages=[Message("assistant", [Content.from_text("test")])],
response_id="test_123",
finish_reason="stop",
)
# Serialize using the framework's API and verify finish_reason is included.
data = response.to_dict()
assert "finish_reason" in data
assert data["finish_reason"] == "stop"
def test_agent_response_update_serialization_includes_finish_reason() -> None:
"""Test that AgentResponseUpdate serializes correctly, including finish_reason."""
update = AgentResponseUpdate(
contents=[Content.from_text("test")],
role="assistant",
response_id="test_456",
finish_reason="tool_calls",
)
data = update.to_dict()
assert "finish_reason" in data
assert data["finish_reason"] == "tool_calls"
# endregion
@@ -23,6 +23,7 @@ from agent_framework import (
WorkflowBuilder,
WorkflowContext,
WorkflowRunState,
executor,
handler,
)
from agent_framework.orchestrations import SequentialBuilder
@@ -478,3 +479,90 @@ async def test_from_response_preserves_service_session_id() -> None:
assert result.get_outputs() is not None
assert spy_agent._captured_service_session_id == "resp_PREVIOUS_RUN" # pyright: ignore[reportPrivateUsage]
@executor(
id="upper_case_executor",
input=AgentExecutorResponse,
output=AgentExecutorResponse,
workflow_output=str,
)
async def _upper_case_executor(
response: AgentExecutorResponse,
ctx: WorkflowContext[AgentExecutorResponse, str],
) -> None:
upper_text = response.agent_response.text.upper()
await ctx.send_message(response.with_text(upper_text))
await ctx.yield_output(upper_text)
async def test_with_text_preserves_full_conversation_through_custom_executor() -> None:
"""Custom executor using with_text must preserve the full conversation chain."""
# Mirrors the reproduction from issue #5246:
# agent1 ("User likes sky red") -> agent2 ("User likes sky blue") -> upper_case -> agent3 ("User likes sky green")
agent1 = AgentExecutor(
_SimpleAgent(id="agent1", name="ContextAgent1", reply_text="User likes sky red"), id="agent1"
)
agent2 = AgentExecutor(
_SimpleAgent(id="agent2", name="ContextAgent2", reply_text="User likes sky blue"), id="agent2"
)
agent3 = AgentExecutor(
_SimpleAgent(id="agent3", name="ContextAgent3", reply_text="User likes sky green"), id="agent3"
)
capturer = _CaptureFullConversation(id="capture")
wf = (
WorkflowBuilder(start_executor=agent1, output_executors=[capturer])
.add_chain([agent1, agent2, _upper_case_executor, agent3, capturer])
.build()
)
result = await wf.run("")
payload = next(o for o in result.get_outputs() if isinstance(o, dict))
# The final agent must see the full conversation: user, agent1, UPPER(agent2), agent3
assert payload["roles"] == ["user", "assistant", "assistant", "assistant"]
assert payload["texts"][1] == "User likes sky red"
assert payload["texts"][2] == "USER LIKES SKY BLUE"
assert payload["texts"][3] == "User likes sky green"
async def test_with_text_does_not_mutate_original() -> None:
"""with_text returns a new instance; the original must be unmodified."""
original = AgentExecutorResponse(
executor_id="test_exec",
agent_response=AgentResponse(messages=[Message("assistant", ["original reply"])]),
full_conversation=[Message("user", ["prompt"]), Message("assistant", ["original reply"])],
)
new = original.with_text("transformed reply")
assert new is not original
assert new.agent_response.text == "transformed reply"
assert new.full_conversation[-1].text == "transformed reply"
assert new.full_conversation[-1].role == "assistant"
# Original unchanged
assert original.agent_response.text == "original reply"
assert original.full_conversation[-1].text == "original reply"
async def test_with_text_strips_multi_message_agent_turn() -> None:
"""When the agent turn has multiple messages (tool calls), with_text strips all of them."""
tool_call = Message("assistant", ["<tool_call>"])
tool_result = Message("tool", ["<result>"])
final_reply = Message("assistant", ["actual answer"])
user_msg = Message("user", ["question"])
original = AgentExecutorResponse(
executor_id="exec",
agent_response=AgentResponse(messages=[tool_call, tool_result, final_reply]),
full_conversation=[user_msg, tool_call, tool_result, final_reply],
)
new = original.with_text("summarised answer")
# Only the pre-agent-turn messages should remain, plus the replacement
assert len(new.full_conversation) == 2
assert new.full_conversation[0].text == "question"
assert new.full_conversation[1].text == "summarised answer"
assert new.agent_response.text == "summarised answer"
+2 -1
View File
@@ -1,6 +1,6 @@
# Gemini Package (agent-framework-gemini)
Integration with Google's Gemini API via the `google-genai` SDK.
Integration with Google's Gemini Developer API and Vertex AI via the `google-genai` SDK.
## Core Classes
@@ -8,6 +8,7 @@ Integration with Google's Gemini API via the `google-genai` SDK.
- **`GeminiChatClient`** - Full-featured chat client with function invocation, middleware, and telemetry
- **`GeminiChatOptions`** - Options TypedDict for Gemini-specific parameters
- **`GeminiSettings`** - Settings loaded from environment variables
- **`GoogleGeminiSettings`** - SDK-standard `GOOGLE_*` settings loaded from environment variables
- **`ThinkingConfig`** - Configuration for extended thinking
## Gemini-specific Options
+19 -2
View File
@@ -12,11 +12,28 @@ The Gemini integration enables Microsoft Agent Framework applications to call Go
## Authentication
Obtain an API key from [Google AI Studio](https://aistudio.google.com/apikey) and set it via environment variable:
The connector supports both `google-genai` authentication modes.
### Gemini Developer API
Obtain an API key from [Google AI Studio](https://aistudio.google.com/apikey) and set either the package-prefixed or SDK-standard environment variable:
```bash
export GEMINI_API_KEY="your-api-key"
export GEMINI_MODEL="gemini-2.5-flash"
# or: export GOOGLE_API_KEY="your-api-key"
export GEMINI_MODEL="gemini-2.5-flash-lite"
# or: export GOOGLE_MODEL="gemini-2.5-flash-lite"
```
### Vertex AI
Set the standard Vertex AI environment variables used by `google-genai`:
```bash
export GOOGLE_GENAI_USE_VERTEXAI=true
export GOOGLE_CLOUD_PROJECT="your-project-id"
export GOOGLE_CLOUD_LOCATION="global"
export GOOGLE_MODEL="gemini-2.5-flash-lite"
```
## Examples
@@ -2,7 +2,14 @@
import importlib.metadata
from ._chat_client import GeminiChatClient, GeminiChatOptions, GeminiSettings, RawGeminiChatClient, ThinkingConfig
from ._chat_client import (
GeminiChatClient,
GeminiChatOptions,
GeminiSettings,
GoogleGeminiSettings,
RawGeminiChatClient,
ThinkingConfig,
)
try:
__version__ = importlib.metadata.version(__name__)
@@ -13,6 +20,7 @@ __all__ = [
"GeminiChatClient",
"GeminiChatOptions",
"GeminiSettings",
"GoogleGeminiSettings",
"RawGeminiChatClient",
"ThinkingConfig",
"__version__",
@@ -30,6 +30,7 @@ from agent_framework import (
from agent_framework._settings import SecretString, load_settings
from agent_framework.observability import ChatTelemetryLayer
from google import genai
from google.auth.credentials import Credentials
from google.genai import types
from pydantic import BaseModel
@@ -54,6 +55,7 @@ __all__ = [
"GeminiChatClient",
"GeminiChatOptions",
"GeminiSettings",
"GoogleGeminiSettings",
"RawGeminiChatClient",
"ThinkingConfig",
]
@@ -161,10 +163,74 @@ class GeminiSettings(TypedDict, total=False):
model: str | None
class GoogleGeminiSettings(TypedDict, total=False):
"""Google SDK configuration settings loaded from ``GOOGLE_*`` environment variables."""
api_key: SecretString | None
model: str | None
genai_use_vertexai: bool | None
cloud_project: str | None
cloud_location: str | None
# endregion
_GEMINI_SERVICE_URL = "https://generativelanguage.googleapis.com"
_GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com"
_VERTEX_AI_BASE_URL = "https://aiplatform.googleapis.com"
def _resolve_vertexai_mode(client: genai.Client, *, fallback: bool | None = None) -> bool:
"""Resolve whether a client targets Vertex AI, preferring the instantiated SDK client state."""
api_client = getattr(client, "_api_client", None)
vertexai = getattr(api_client, "vertexai", None)
if isinstance(vertexai, bool):
return vertexai
return bool(fallback)
def _resolve_service_url(client: genai.Client, *, vertexai: bool) -> str:
"""Resolve the base service URL from the instantiated SDK client, with a stable fallback."""
api_client = getattr(client, "_api_client", None)
http_options = getattr(api_client, "_http_options", None)
base_url = getattr(http_options, "base_url", None)
if isinstance(base_url, str) and base_url:
return base_url.rstrip("/")
return _VERTEX_AI_BASE_URL if vertexai else _GEMINI_API_BASE_URL
def _validate_client_auth_configuration(
*,
vertexai: bool | None,
api_key: SecretString | None,
project: str | None,
location: str | None,
credentials: Credentials | None,
) -> None:
"""Validate supported auth combinations before instantiating the SDK client."""
if vertexai is not True:
if api_key is None:
raise ValueError(
"Gemini client requires an API key when Vertex AI is not enabled. "
"Set GOOGLE_API_KEY or GEMINI_API_KEY, or pass api_key explicitly."
)
return
if api_key is not None or credentials is not None or (project and location):
return
if project or location:
raise ValueError(
"Gemini client requires both GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION "
"when Vertex AI is enabled without an API key."
)
raise ValueError(
"Gemini client requires Vertex AI credentials or configuration when Vertex AI is enabled. "
"Provide GOOGLE_API_KEY for Vertex AI express mode, pass credentials, or set "
"GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION."
)
# Keys mapping to a different GenerateContentConfig field name
_OPTION_TRANSLATIONS: dict[str, str] = {
@@ -210,7 +276,7 @@ class RawGeminiChatClient(
BaseChatClient[GeminiChatOptionsT],
Generic[GeminiChatOptionsT],
):
"""A raw Gemini chat client for the Google Gemini API without function invocation, middleware or telemetry.
"""A raw Gemini chat client for Gemini Developer API or Vertex AI.
Use this when you want full control over the request pipeline. For instance, to opt out of
telemetry, use custom middleware, or compose your own layers. If you want the full-featured
@@ -224,6 +290,10 @@ class RawGeminiChatClient(
*,
api_key: str | None = None,
model: str | None = None,
vertexai: bool | None = None,
project: str | None = None,
location: str | None = None,
credentials: Credentials | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
client: genai.Client | None = None,
@@ -232,11 +302,21 @@ class RawGeminiChatClient(
"""Create a raw Gemini chat client.
Args:
api_key: Google AI Studio API key. Falls back to ``GEMINI_API_KEY`` environment variable.
model: Default model identifier. Falls back to ``GEMINI_MODEL`` environment variable.
api_key: Gemini Developer API key. Falls back to environment settings, preferring
``GOOGLE_API_KEY`` over ``GEMINI_API_KEY``.
model: Default model identifier. Falls back to environment settings, preferring
``GOOGLE_MODEL`` over ``GEMINI_MODEL``.
vertexai: Whether to use Vertex AI endpoints. Falls back to environment settings,
using ``GOOGLE_GENAI_USE_VERTEXAI`` when not passed explicitly.
project: Google Cloud project ID for Vertex AI. Falls back to environment settings,
using ``GOOGLE_CLOUD_PROJECT`` when not passed explicitly.
location: Vertex AI location. Falls back to environment settings, preferring
using ``GOOGLE_CLOUD_LOCATION`` when not passed explicitly.
credentials: Google Cloud credentials for Vertex AI. When omitted, the SDK can use
Application Default Credentials.
env_file_path: Path to a ``.env`` file for credential loading.
env_file_encoding: Encoding for the ``.env`` file.
client: Pre-built ``genai.Client`` instance. When provided, ``api_key`` is not required.
client: Pre-built ``genai.Client`` instance. When provided, connector auth settings are not required.
additional_properties: Extra properties stored on the client instance.
"""
settings = load_settings(
@@ -247,21 +327,58 @@ class RawGeminiChatClient(
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
google_settings = load_settings(
GoogleGeminiSettings,
env_prefix="GOOGLE_",
api_key=api_key,
model=model,
genai_use_vertexai=vertexai,
cloud_project=project,
cloud_location=location,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
configured_vertexai = google_settings.get("genai_use_vertexai")
if client:
self._genai_client = client
else:
resolved_key = settings.get("api_key")
if not resolved_key:
raise ValueError(
"Gemini API key is required. Set via api_key parameter or GEMINI_API_KEY environment variable."
)
self._genai_client = genai.Client(
api_key=resolved_key.get_secret_value(),
http_options={"headers": {"x-goog-api-client": AGENT_FRAMEWORK_USER_AGENT}},
resolved_key = google_settings.get("api_key") or settings.get("api_key")
resolved_project = google_settings.get("cloud_project")
resolved_location = google_settings.get("cloud_location")
_validate_client_auth_configuration(
vertexai=configured_vertexai,
api_key=resolved_key,
project=resolved_project,
location=resolved_location,
credentials=credentials,
)
self.model = settings.get("model")
client_kwargs: dict[str, Any] = {
"http_options": {"headers": {"x-goog-api-client": AGENT_FRAMEWORK_USER_AGENT}},
}
if configured_vertexai is not None:
client_kwargs["vertexai"] = configured_vertexai
if resolved_key is not None and (
configured_vertexai is not True
or (credentials is None and not (resolved_project and resolved_location))
):
client_kwargs["api_key"] = resolved_key.get_secret_value()
if configured_vertexai is True and resolved_project:
client_kwargs["project"] = resolved_project
if configured_vertexai is True and resolved_location:
client_kwargs["location"] = resolved_location
if configured_vertexai is True and credentials is not None:
client_kwargs["credentials"] = credentials
self._genai_client = genai.Client(**client_kwargs)
self._vertexai = _resolve_vertexai_mode(self._genai_client, fallback=configured_vertexai)
self._service_url = _resolve_service_url(self._genai_client, vertexai=self._vertexai)
self.model = google_settings.get("model") or settings.get("model")
super().__init__(additional_properties=additional_properties)
@@ -414,12 +531,12 @@ class RawGeminiChatClient(
@override
def service_url(self) -> str:
"""Return the base URL of the Gemini API service.
"""Return the base URL of the configured Gemini or Vertex AI service.
Returns:
The Gemini API base URL.
The resolved service base URL.
"""
return _GEMINI_SERVICE_URL
return self._service_url
# region Request preparation
@@ -528,15 +645,16 @@ class RawGeminiChatClient(
call_id = content.call_id or self._generate_tool_call_id()
if content.name:
call_id_to_name[call_id] = content.name
parts.append(
types.Part(
function_call=types.FunctionCall(
id=call_id,
name=content.name or "",
args=content.parse_arguments() or {},
)
)
function_call = types.FunctionCall(
id=call_id,
name=content.name or "",
args=content.parse_arguments() or {},
)
raw_part = content.raw_representation
if isinstance(raw_part, types.Part) and raw_part.function_call is not None:
parts.append(raw_part.model_copy(update={"function_call": function_call}, deep=True))
else:
parts.append(types.Part(function_call=function_call))
case _:
logger.debug("Skipping unsupported content type for Gemini: %s", content.type)
return parts
@@ -889,7 +1007,7 @@ class GeminiChatClient(
RawGeminiChatClient[GeminiChatOptionsT],
Generic[GeminiChatOptionsT],
):
"""Gemini chat client for the Google Gemini API with function invocation, middleware, and telemetry.
"""Gemini chat client for Gemini Developer API or Vertex AI with function invocation, middleware, and telemetry.
This is the recommended client for most use cases. It builds on ``RawGeminiChatClient``
and adds:
@@ -908,6 +1026,10 @@ class GeminiChatClient(
*,
api_key: str | None = None,
model: str | None = None,
vertexai: bool | None = None,
project: str | None = None,
location: str | None = None,
credentials: Credentials | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
client: genai.Client | None = None,
@@ -918,11 +1040,18 @@ class GeminiChatClient(
"""Create a Gemini chat client.
Args:
api_key: The Google AI Studio API key. Falls back to ``GEMINI_API_KEY`` environment variable.
model: Default model identifier. Falls back to ``GEMINI_MODEL`` environment variable.
api_key: Gemini Developer API key. Falls back to environment settings, preferring
``GOOGLE_API_KEY`` over ``GEMINI_API_KEY``.
model: Default model identifier. Falls back to environment settings, preferring
``GOOGLE_MODEL`` over ``GEMINI_MODEL``.
vertexai: Whether to use Vertex AI endpoints. Falls back to ``GOOGLE_GENAI_USE_VERTEXAI``.
project: Google Cloud project ID for Vertex AI. Falls back to ``GOOGLE_CLOUD_PROJECT``.
location: Vertex AI location. Falls back to ``GOOGLE_CLOUD_LOCATION``.
credentials: Google Cloud credentials for Vertex AI. When omitted, the SDK can use
Application Default Credentials.
env_file_path: Path to a ``.env`` file for credential loading.
env_file_encoding: Encoding for the ``.env`` file.
client: Pre-built ``genai.Client`` instance. When provided, ``api_key`` is not required.
client: Pre-built ``genai.Client`` instance. When provided, connector auth settings are not required.
additional_properties: Extra properties stored on the client instance.
middleware: Optional middleware chain applied to every call.
function_invocation_configuration: Optional configuration for the function invocation loop.
@@ -930,6 +1059,10 @@ class GeminiChatClient(
super().__init__(
api_key=api_key,
model=model,
vertexai=vertexai,
project=project,
location=location,
credentials=credentials,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
client=client,
+4 -2
View File
@@ -14,5 +14,7 @@ This folder contains examples demonstrating how to use Google Gemini models with
## Environment Variables
- `GEMINI_API_KEY`: Your Google AI Studio API key (get one from [Google AI Studio](https://aistudio.google.com/apikey))
- `GEMINI_MODEL`: The Gemini model to use (e.g., `gemini-2.5-flash`, `gemini-2.5-pro`)
- `GOOGLE_MODEL` or `GEMINI_MODEL`: The Gemini model to use (for example,
`gemini-2.5-flash-lite` or `gemini-2.5-pro`)
- For Gemini Developer API: `GEMINI_API_KEY` or `GOOGLE_API_KEY`
- For Vertex AI: `GOOGLE_GENAI_USE_VERTEXAI=true`, `GOOGLE_CLOUD_PROJECT`, and `GOOGLE_CLOUD_LOCATION`
@@ -0,0 +1 @@
# Copyright (c) Microsoft. All rights reserved.
@@ -4,9 +4,9 @@
Allows the model to reason through complex problems before responding.
Requires the following environment variables to be set:
- GEMINI_API_KEY
- GEMINI_MODEL
Requires ``GOOGLE_MODEL`` or ``GEMINI_MODEL`` and either Gemini Developer API credentials
(``GEMINI_API_KEY`` or ``GOOGLE_API_KEY``) or Vertex AI settings
(``GOOGLE_GENAI_USE_VERTEXAI``, ``GOOGLE_CLOUD_PROJECT``, and ``GOOGLE_CLOUD_LOCATION``).
"""
import asyncio
@@ -23,10 +23,12 @@ async def main() -> None:
"""Example of extended thinking with a Python version comparison question."""
print("=== Extended thinking ===")
# 1. Configure Gemini extended thinking for a reasoning-heavy request.
options: GeminiChatOptions = {
"thinking_config": ThinkingConfig(thinking_budget=2048),
}
# 2. Create the agent with the Gemini chat client and default thinking options.
agent = Agent(
client=GeminiChatClient(),
name="PythonAgent",
@@ -34,6 +36,7 @@ async def main() -> None:
default_options=options,
)
# 3. Stream the answer so you can see the final response as it arrives.
query = "What new language features were introduced in Python between 3.10 and 3.14?"
print(f"User: {query}")
print("Agent: ", end="", flush=True)
@@ -45,3 +48,12 @@ async def main() -> None:
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output:
=== Extended thinking ===
User: What new language features were introduced in Python between 3.10 and 3.14?
Agent: Python 3.11 introduced exception groups and TaskGroup.
Python 3.12 added PEP 695 type parameter syntax.
Python 3.13-3.14 continued improving typing, performance, and developer ergonomics.
"""
+18 -3
View File
@@ -4,9 +4,9 @@
Covers both non-streaming and streaming responses.
Requires the following environment variables to be set:
- GEMINI_API_KEY
- GEMINI_MODEL
Requires ``GOOGLE_MODEL`` or ``GEMINI_MODEL`` and either Gemini Developer API credentials
(``GEMINI_API_KEY`` or ``GOOGLE_API_KEY``) or Vertex AI settings
(``GOOGLE_GENAI_USE_VERTEXAI``, ``GOOGLE_CLOUD_PROJECT``, and ``GOOGLE_CLOUD_LOCATION``).
"""
import asyncio
@@ -35,6 +35,7 @@ async def non_streaming_example() -> None:
"""Runs the agent and waits for the complete response before printing it."""
print("=== Non-streaming ===")
# 1. Create the agent with the Gemini chat client and local weather tool.
agent = Agent(
client=GeminiChatClient(),
name="WeatherAgent",
@@ -42,6 +43,7 @@ async def non_streaming_example() -> None:
tools=[get_weather],
)
# 2. Ask the agent for a single weather lookup and print the final response.
query = "What's the weather like in Karlsruhe, Germany?"
print(f"User: {query}")
result = await agent.run(query)
@@ -52,6 +54,7 @@ async def streaming_example() -> None:
"""Runs the agent and prints each chunk as it is received."""
print("=== Streaming ===")
# 1. Create the same agent configuration for a streaming tool-call example.
agent = Agent(
client=GeminiChatClient(),
name="WeatherAgent",
@@ -59,6 +62,7 @@ async def streaming_example() -> None:
tools=[get_weather],
)
# 2. Ask a multi-location question and stream the model output as it arrives.
query = "What's the weather like in Portland and in Paris?"
print(f"User: {query}")
print("Agent: ", end="", flush=True)
@@ -76,3 +80,14 @@ async def main() -> None:
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output:
=== Non-streaming ===
User: What's the weather like in Karlsruhe, Germany?
Result: The weather in Karlsruhe, Germany is currently sunny with a high of 16°C.
=== Streaming ===
User: What's the weather like in Portland and in Paris?
Agent: In Portland, it is currently rainy with a high of 11°C. In Paris, it is cloudy with a high of 27°C.
"""
@@ -4,9 +4,9 @@
Allows the model to write and run code in a sandboxed environment to answer questions.
Requires the following environment variables to be set:
- GEMINI_API_KEY
- GEMINI_MODEL
Requires ``GOOGLE_MODEL`` or ``GEMINI_MODEL`` and either Gemini Developer API credentials
(``GEMINI_API_KEY`` or ``GOOGLE_API_KEY``) or Vertex AI settings
(``GOOGLE_GENAI_USE_VERTEXAI``, ``GOOGLE_CLOUD_PROJECT``, and ``GOOGLE_CLOUD_LOCATION``).
"""
import asyncio
@@ -23,6 +23,7 @@ async def main() -> None:
"""Run the code execution example."""
print("=== Code execution ===")
# 1. Create the agent with Gemini and the built-in code execution tool.
agent = Agent(
client=GeminiChatClient(),
name="CodeAgent",
@@ -30,6 +31,7 @@ async def main() -> None:
tools=[GeminiChatClient.get_code_interpreter_tool()],
)
# 2. Ask for a computed answer and stream the generated code and final result.
query = "What are the first 20 prime numbers? Compute them in code."
print(f"User: {query}")
print("Agent: ", end="", flush=True)
@@ -41,3 +43,10 @@ async def main() -> None:
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output:
=== Code execution ===
User: What are the first 20 prime numbers? Compute them in code.
Agent: The first 20 prime numbers are 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, and 71.
"""
@@ -4,9 +4,9 @@
Allows Gemini to retrieve location and mapping information before responding.
Requires the following environment variables to be set:
- GEMINI_API_KEY
- GEMINI_MODEL
Requires ``GOOGLE_MODEL`` or ``GEMINI_MODEL`` and either Gemini Developer API credentials
(``GEMINI_API_KEY`` or ``GOOGLE_API_KEY``) or Vertex AI settings
(``GOOGLE_GENAI_USE_VERTEXAI``, ``GOOGLE_CLOUD_PROJECT``, and ``GOOGLE_CLOUD_LOCATION``).
"""
import asyncio
@@ -23,6 +23,7 @@ async def main() -> None:
"""Run the Google Maps grounding example."""
print("=== Google Maps grounding ===")
# 1. Create the agent with Gemini and the built-in Google Maps grounding tool.
agent = Agent(
client=GeminiChatClient(),
name="MapsAgent",
@@ -30,6 +31,7 @@ async def main() -> None:
tools=[GeminiChatClient.get_maps_grounding_tool()],
)
# 2. Ask a location-aware question and stream the grounded answer.
query = "What are some highly rated restaurants in the city center of Karlsruhe, Germany?"
print(f"User: {query}")
print("Agent: ", end="", flush=True)
@@ -41,3 +43,11 @@ async def main() -> None:
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output:
=== Google Maps grounding ===
User: What are some highly rated restaurants in the city center of Karlsruhe, Germany?
Agent: Here are several highly rated restaurants near Karlsruhe city center,
along with their cuisine styles and approximate walking distance.
"""
@@ -4,9 +4,9 @@
Allows Gemini to retrieve up-to-date information from the web before responding.
Requires the following environment variables to be set:
- GEMINI_API_KEY
- GEMINI_MODEL
Requires ``GOOGLE_MODEL`` or ``GEMINI_MODEL`` and either Gemini Developer API credentials
(``GEMINI_API_KEY`` or ``GOOGLE_API_KEY``) or Vertex AI settings
(``GOOGLE_GENAI_USE_VERTEXAI``, ``GOOGLE_CLOUD_PROJECT``, and ``GOOGLE_CLOUD_LOCATION``).
"""
import asyncio
@@ -23,6 +23,7 @@ async def main() -> None:
"""Run the Google Search grounding example."""
print("=== Google Search grounding ===")
# 1. Create the agent with Gemini and the built-in Google Search grounding tool.
agent = Agent(
client=GeminiChatClient(),
name="SearchAgent",
@@ -30,6 +31,7 @@ async def main() -> None:
tools=[GeminiChatClient.get_web_search_tool()],
)
# 2. Ask a current-events style question and stream the grounded answer.
query = "What is the latest stable release of the .NET SDK?"
print(f"User: {query}")
print("Agent: ", end="", flush=True)
@@ -41,3 +43,10 @@ async def main() -> None:
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output:
=== Google Search grounding ===
User: What is the latest stable release of the .NET SDK?
Agent: As of April 14, 2026, the latest stable release of the .NET SDK is .NET 10.0 (SDK 10.0.201).
"""
@@ -15,12 +15,28 @@ from pydantic import BaseModel
from agent_framework_gemini import GeminiChatClient, GeminiChatOptions, ThinkingConfig
skip_if_no_api_key = pytest.mark.skipif(
not os.getenv("GEMINI_API_KEY"),
reason="GEMINI_API_KEY not set; skipping integration tests.",
def _has_gemini_integration_credentials() -> bool:
"""Return whether integration credentials for either Gemini API or Vertex AI appear to be configured."""
if os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"):
return True
if os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() in {"true", "1", "yes", "on"}:
return bool(
os.getenv("GOOGLE_CLOUD_PROJECT")
or os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
or os.getenv("GOOGLE_API_KEY")
)
return False
skip_if_no_credentials = pytest.mark.skipif(
not _has_gemini_integration_credentials(),
reason="Gemini Developer API or Vertex AI credentials not set; skipping integration tests.",
)
_TEST_MODEL = "gemini-2.5-flash"
_TEST_MODEL = os.getenv("GOOGLE_MODEL") or os.getenv("GEMINI_MODEL", "gemini-2.5-flash-lite")
# stub helpers
@@ -89,6 +105,7 @@ def _make_response(
candidate.finish_reason = None
response.candidates = [candidate]
response.finish_reason = finish_reason
response.model_version = model_version
if prompt_tokens is not None or output_tokens is not None:
@@ -115,6 +132,8 @@ def _make_gemini_client(
) -> tuple[GeminiChatClient, MagicMock]:
"""Return a (GeminiChatClient, mock_genai_client) pair."""
mock = mock_client or MagicMock()
mock._api_client.vertexai = False
mock._api_client._http_options.base_url = "https://generativelanguage.googleapis.com/"
client = GeminiChatClient(client=mock, model=model)
return client, mock
@@ -135,12 +154,134 @@ def test_client_created_from_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
assert client.model == "gemini-2.5-flash"
def test_missing_api_key_raises_when_no_client_injected(monkeypatch: pytest.MonkeyPatch) -> None:
"""Raises ValueError at construction when neither an API key nor a pre-built client is available."""
def test_client_created_from_google_api_key_env(monkeypatch: pytest.MonkeyPatch) -> None:
"""Initialises successfully when the SDK-standard Google API key environment variable is set."""
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
monkeypatch.delenv("GEMINI_MODEL", raising=False)
monkeypatch.delenv("GOOGLE_GENAI_USE_VERTEXAI", raising=False)
monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
monkeypatch.setenv("GOOGLE_API_KEY", "test-key-123")
monkeypatch.setenv("GOOGLE_MODEL", "gemini-2.5-flash-lite")
with pytest.raises(ValueError, match="GEMINI_API_KEY"):
mock_client = MagicMock()
mock_client._api_client.vertexai = False
mock_client._api_client._http_options.base_url = "https://generativelanguage.googleapis.com/"
with patch("agent_framework_gemini._chat_client.genai.Client") as client_factory:
client_factory.return_value = mock_client
client = GeminiChatClient()
assert client_factory.call_args.kwargs["api_key"] == "test-key-123"
assert "vertexai" not in client_factory.call_args.kwargs
assert client.model == "gemini-2.5-flash-lite"
assert client.service_url() == "https://generativelanguage.googleapis.com"
def test_client_created_from_vertex_ai_env(monkeypatch: pytest.MonkeyPatch) -> None:
"""Initialises a Vertex AI client when the SDK-standard Vertex AI environment variables are set."""
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project")
monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "global")
mock_client = MagicMock()
mock_client._api_client.vertexai = True
mock_client._api_client._http_options.base_url = "https://aiplatform.googleapis.com/"
with patch("agent_framework_gemini._chat_client.genai.Client", return_value=mock_client) as client_factory:
client = GeminiChatClient()
assert client_factory.call_args.kwargs["vertexai"] is True
assert client_factory.call_args.kwargs["project"] == "test-project"
assert client_factory.call_args.kwargs["location"] == "global"
assert "api_key" not in client_factory.call_args.kwargs
assert client.service_url() == "https://aiplatform.googleapis.com"
def test_google_settings_take_precedence_over_gemini_aliases(monkeypatch: pytest.MonkeyPatch) -> None:
"""Prefers SDK-standard ``GOOGLE_*`` settings when both env families are present."""
monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
monkeypatch.setenv("GEMINI_MODEL", "gemini-model")
monkeypatch.setenv("GOOGLE_API_KEY", "google-key")
monkeypatch.setenv("GOOGLE_MODEL", "google-model")
monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "google-project")
monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "global")
mock_client = MagicMock()
mock_client._api_client.vertexai = True
mock_client._api_client._http_options.base_url = "https://aiplatform.googleapis.com/"
with patch("agent_framework_gemini._chat_client.genai.Client", return_value=mock_client) as client_factory:
client = GeminiChatClient()
assert client_factory.call_args.kwargs["vertexai"] is True
assert client_factory.call_args.kwargs["project"] == "google-project"
assert client_factory.call_args.kwargs["location"] == "global"
assert "api_key" not in client_factory.call_args.kwargs
assert client.model == "google-model"
assert client.service_url() == "https://aiplatform.googleapis.com"
def test_missing_api_key_raises_when_no_client_injected(monkeypatch: pytest.MonkeyPatch) -> None:
"""Raises ValueError at construction when neither Gemini API nor Vertex AI settings are available."""
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
monkeypatch.delenv("GEMINI_MODEL", raising=False)
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
monkeypatch.delenv("GOOGLE_GENAI_USE_VERTEXAI", raising=False)
monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
with pytest.raises(ValueError, match="requires an API key when Vertex AI is not enabled"):
GeminiChatClient(model="gemini-2.5-flash")
def test_vertex_ai_express_mode_uses_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
"""Passes the API key in Vertex AI express mode when no project/location pair is configured."""
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
monkeypatch.delenv("GEMINI_MODEL", raising=False)
monkeypatch.setenv("GOOGLE_API_KEY", "test-key-123")
monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
mock_client = MagicMock()
mock_client._api_client.vertexai = True
mock_client._api_client._http_options.base_url = "https://aiplatform.googleapis.com/"
with patch("agent_framework_gemini._chat_client.genai.Client", return_value=mock_client) as client_factory:
client = GeminiChatClient(model="gemini-2.5-flash-lite")
assert client_factory.call_args.kwargs["vertexai"] is True
assert client_factory.call_args.kwargs["api_key"] == "test-key-123"
assert "project" not in client_factory.call_args.kwargs
assert "location" not in client_factory.call_args.kwargs
assert client.service_url() == "https://aiplatform.googleapis.com"
def test_vertex_ai_requires_configuration(monkeypatch: pytest.MonkeyPatch) -> None:
"""Raises a deterministic error when Vertex AI is enabled without any auth configuration."""
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
with pytest.raises(ValueError, match="requires Vertex AI credentials or configuration"):
GeminiChatClient(model="gemini-2.5-flash")
def test_vertex_ai_requires_project_and_location_together(monkeypatch: pytest.MonkeyPatch) -> None:
"""Raises a deterministic error when only one Vertex AI location setting is present."""
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project")
monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
with pytest.raises(ValueError, match="requires both GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION"):
GeminiChatClient(model="gemini-2.5-flash")
@@ -495,6 +636,30 @@ async def test_thinking_parts_are_silently_skipped() -> None:
assert response.messages[0].text == "The answer is 42."
def test_function_call_part_preserves_thought_signature_from_raw_part() -> None:
"""Reuses the original Gemini Part so tool loops retain thought_signature metadata."""
client, _ = _make_gemini_client()
raw_part = types.Part(
function_call=types.FunctionCall(id="call-1", name="get_weather", args={"location": "Paris"}),
thought_signature=b"sig-123",
)
content = Content.from_function_call(
call_id="call-1",
name="get_weather",
arguments={"location": "Paris"},
raw_representation=raw_part,
)
parts = client._convert_message_contents([content], {})
assert len(parts) == 1
assert parts[0].thought_signature == b"sig-123"
assert parts[0].function_call is not None
assert parts[0].function_call.id == "call-1"
assert parts[0].function_call.name == "get_weather"
assert parts[0].function_call.args == {"location": "Paris"}
# code execution parts
@@ -1283,12 +1448,26 @@ def test_service_url() -> None:
assert client.service_url() == "https://generativelanguage.googleapis.com"
def test_service_url_falls_back_when_sdk_base_url_is_unavailable() -> None:
"""Falls back to the known service URL when the SDK client does not expose a base URL."""
gemini_sdk_client = MagicMock()
gemini_sdk_client._api_client.vertexai = False
gemini_client = GeminiChatClient(client=gemini_sdk_client, model="gemini-2.5-flash")
vertex_sdk_client = MagicMock()
vertex_sdk_client._api_client.vertexai = True
vertex_client = GeminiChatClient(client=vertex_sdk_client, model="gemini-2.5-flash")
assert gemini_client.service_url() == "https://generativelanguage.googleapis.com"
assert vertex_client.service_url() == "https://aiplatform.googleapis.com"
# integration tests
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_api_key
@skip_if_no_credentials
async def test_integration_basic_chat() -> None:
"""Basic request/response round-trip returns a non-empty text reply."""
client = GeminiChatClient(model=_TEST_MODEL)
@@ -1302,7 +1481,7 @@ async def test_integration_basic_chat() -> None:
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_api_key
@skip_if_no_credentials
async def test_integration_streaming() -> None:
"""Streaming yields multiple chunks that together form a non-empty response."""
client = GeminiChatClient(model=_TEST_MODEL)
@@ -1319,7 +1498,7 @@ async def test_integration_streaming() -> None:
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_api_key
@skip_if_no_credentials
async def test_integration_structured_output() -> None:
"""Structured output with a Pydantic response_format returns a parsed value via response.value."""
@@ -1340,7 +1519,7 @@ async def test_integration_structured_output() -> None:
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_api_key
@skip_if_no_credentials
async def test_integration_tool_calling() -> None:
"""Model invokes the registered tool when asked a question that requires it."""
@@ -1363,7 +1542,7 @@ async def test_integration_tool_calling() -> None:
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_api_key
@skip_if_no_credentials
async def test_integration_thinking_config() -> None:
"""Model accepts a thinking budget and returns a non-empty text reply."""
options: GeminiChatOptions = {"thinking_config": ThinkingConfig(thinking_budget=512)}
@@ -1380,7 +1559,7 @@ async def test_integration_thinking_config() -> None:
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_api_key
@skip_if_no_credentials
async def test_integration_google_search_grounding() -> None:
"""Google Search grounding returns a non-empty response for a current-events question."""
client = GeminiChatClient(model=_TEST_MODEL)
@@ -1396,7 +1575,7 @@ async def test_integration_google_search_grounding() -> None:
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_api_key
@skip_if_no_credentials
async def test_integration_google_maps_grounding() -> None:
"""Google Maps grounding returns a non-empty response for a location-based question."""
client = GeminiChatClient(model=_TEST_MODEL)
@@ -1417,7 +1596,7 @@ async def test_integration_google_maps_grounding() -> None:
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_no_api_key
@skip_if_no_credentials
async def test_integration_code_execution() -> None:
"""Code execution tool produces a non-empty response for a computation request."""
client = GeminiChatClient(model=_TEST_MODEL)
@@ -2474,6 +2474,29 @@ class RawOpenAIChatClient( # type: ignore[misc]
raw_representation=event,
)
)
elif ann_type == "url_citation":
ann_url = _get_ann_value("url")
if ann_url:
ann_start = _get_ann_value("start_index")
ann_end = _get_ann_value("end_index")
annotation_obj = Annotation(
type="citation",
title=_get_ann_value("title") or "",
url=str(ann_url),
additional_properties={"annotation_index": event.annotation_index},
raw_representation=annotation,
)
if ann_start is not None and ann_end is not None:
annotation_obj["annotated_regions"] = [
TextSpanRegion(
type="text_span",
start_index=ann_start,
end_index=ann_end,
)
]
contents.append(
Content.from_text(text="", annotations=[annotation_obj], raw_representation=event)
)
else:
logger.debug("Unparsed annotation type in streaming: %s", ann_type)
case "response.output_item.done":
@@ -2570,8 +2570,65 @@ def test_streaming_annotation_added_with_container_file_citation() -> None:
assert content.additional_properties.get("end_index") == 50
def test_streaming_annotation_added_with_unknown_type() -> None:
"""Test streaming annotation added event with unknown type is ignored."""
def test_streaming_annotation_added_with_url_citation() -> None:
"""Test streaming annotation added event with url_citation type produces citation annotation."""
client = OpenAIChatClient(model="test-model", api_key="test-key")
chat_options = ChatOptions()
function_call_ids: dict[int, tuple[str, str]] = {}
mock_event = MagicMock()
mock_event.type = "response.output_text.annotation.added"
mock_event.annotation_index = 0
mock_event.annotation = {
"type": "url_citation",
"url": "https://example.sharepoint.com/sites/my-site/doc.pdf",
"title": "doc.pdf",
"start_index": 100,
"end_index": 112,
}
response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids)
assert len(response.contents) == 1
content = response.contents[0]
assert content.type == "text"
assert content.annotations is not None
assert len(content.annotations) == 1
annotation = content.annotations[0]
assert annotation["type"] == "citation"
assert annotation["title"] == "doc.pdf"
assert annotation["url"] == "https://example.sharepoint.com/sites/my-site/doc.pdf"
assert annotation["additional_properties"]["annotation_index"] == 0
assert annotation["raw_representation"] == mock_event.annotation
assert annotation["annotated_regions"] is not None
assert len(annotation["annotated_regions"]) == 1
region = annotation["annotated_regions"][0]
assert region["type"] == "text_span"
assert region["start_index"] == 100
assert region["end_index"] == 112
def test_streaming_annotation_added_with_url_citation_no_url() -> None:
"""Test streaming annotation added event with url_citation but missing url is ignored."""
client = OpenAIChatClient(model="test-model", api_key="test-key")
chat_options = ChatOptions()
function_call_ids: dict[int, tuple[str, str]] = {}
mock_event = MagicMock()
mock_event.type = "response.output_text.annotation.added"
mock_event.annotation_index = 0
mock_event.annotation = {
"type": "url_citation",
"title": "doc.pdf",
}
response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids)
assert len(response.contents) == 0
def test_streaming_annotation_added_with_url_citation_no_indices() -> None:
"""Test streaming annotation with url_citation that has url but no start_index/end_index."""
client = OpenAIChatClient(model="test-model", api_key="test-key")
chat_options = ChatOptions()
function_call_ids: dict[int, tuple[str, str]] = {}
@@ -2582,11 +2639,36 @@ def test_streaming_annotation_added_with_unknown_type() -> None:
mock_event.annotation = {
"type": "url_citation",
"url": "https://example.com",
"title": "Example",
}
response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids)
assert len(response.contents) == 1
annotation = response.contents[0].annotations[0]
assert annotation["type"] == "citation"
assert annotation["title"] == "Example"
assert annotation["url"] == "https://example.com"
assert annotation["additional_properties"]["annotation_index"] == 0
assert "annotated_regions" not in annotation
def test_streaming_annotation_added_with_unknown_type() -> None:
"""Test streaming annotation added event with unknown type is ignored."""
client = OpenAIChatClient(model="test-model", api_key="test-key")
chat_options = ChatOptions()
function_call_ids: dict[int, tuple[str, str]] = {}
mock_event = MagicMock()
mock_event.type = "response.output_text.annotation.added"
mock_event.annotation_index = 0
mock_event.annotation = {
"type": "some_future_annotation_type",
"data": "test",
}
response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids)
# url_citation should not produce HostedFileContent
assert len(response.contents) == 0