mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Merge branch 'main' into feature/python-foundry-hosted-agent-vnext
This commit is contained in:
@@ -374,6 +374,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
contents=contents,
|
||||
role="assistant" if item.role == A2ARole.agent else "user",
|
||||
response_id=str(getattr(item, "message_id", uuid.uuid4())),
|
||||
additional_properties={"a2a_metadata": item.metadata} if item.metadata else None,
|
||||
raw_representation=item,
|
||||
)
|
||||
all_updates.append(update)
|
||||
@@ -452,13 +453,24 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
role=message.role,
|
||||
response_id=task.id,
|
||||
message_id=getattr(message.raw_representation, "artifact_id", None),
|
||||
additional_properties={"a2a_metadata": merged}
|
||||
if (merged := {**message.additional_properties, **(task.metadata or {})})
|
||||
else None,
|
||||
raw_representation=task,
|
||||
)
|
||||
for message in task_messages
|
||||
]
|
||||
if task.artifacts is not None:
|
||||
return []
|
||||
return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)]
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=[],
|
||||
role="assistant",
|
||||
response_id=task.id,
|
||||
additional_properties={"a2a_metadata": task.metadata} if task.metadata else None,
|
||||
raw_representation=task,
|
||||
)
|
||||
]
|
||||
|
||||
if background and status.state in IN_PROGRESS_TASK_STATES:
|
||||
token = self._build_continuation_token(task)
|
||||
@@ -468,6 +480,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
role="assistant",
|
||||
response_id=task.id,
|
||||
continuation_token=token,
|
||||
additional_properties={"a2a_metadata": task.metadata} if task.metadata else None,
|
||||
raw_representation=task,
|
||||
)
|
||||
]
|
||||
@@ -488,6 +501,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
contents=contents,
|
||||
role="assistant" if status.message.role == A2ARole.agent else "user",
|
||||
response_id=task.id,
|
||||
additional_properties={"a2a_metadata": task.metadata} if task.metadata else None,
|
||||
raw_representation=task,
|
||||
)
|
||||
]
|
||||
@@ -502,12 +516,17 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
contents = self._parse_contents_from_a2a(update_event.artifact.parts)
|
||||
if not contents:
|
||||
return []
|
||||
merged_metadata = {
|
||||
**(update_event.artifact.metadata or {}),
|
||||
**(update_event.metadata or {}),
|
||||
} or None
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=contents,
|
||||
role="assistant",
|
||||
response_id=update_event.task_id,
|
||||
message_id=update_event.artifact.artifact_id,
|
||||
additional_properties={"a2a_metadata": merged_metadata} if merged_metadata else None,
|
||||
raw_representation=update_event,
|
||||
)
|
||||
]
|
||||
@@ -523,11 +542,16 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
if not contents:
|
||||
return []
|
||||
|
||||
merged_metadata = {
|
||||
**(message.metadata or {}),
|
||||
**(update_event.metadata or {}),
|
||||
} or None
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=contents,
|
||||
role="assistant" if message.role == A2ARole.agent else "user",
|
||||
response_id=update_event.task_id,
|
||||
additional_properties={"a2a_metadata": merged_metadata} if merged_metadata else None,
|
||||
raw_representation=update_event,
|
||||
)
|
||||
]
|
||||
@@ -642,9 +666,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
case _:
|
||||
raise ValueError(f"Unknown content type: {content.type}")
|
||||
|
||||
# Exclude framework-internal keys (e.g. attribution) from wire metadata
|
||||
internal_keys = {"_attribution", "context_id"}
|
||||
metadata = {k: v for k, v in message.additional_properties.items() if k not in internal_keys} or None
|
||||
metadata = message.additional_properties.get("a2a_metadata")
|
||||
|
||||
return A2AMessage(
|
||||
role=A2ARole("user"),
|
||||
@@ -718,6 +740,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
Message(
|
||||
role="assistant" if history_item.role == A2ARole.agent else "user",
|
||||
contents=contents,
|
||||
additional_properties=history_item.metadata,
|
||||
raw_representation=history_item,
|
||||
)
|
||||
)
|
||||
@@ -730,5 +753,6 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
return Message(
|
||||
role="assistant",
|
||||
contents=contents,
|
||||
additional_properties=artifact.metadata,
|
||||
raw_representation=artifact,
|
||||
)
|
||||
|
||||
@@ -530,7 +530,7 @@ def test_prepare_message_for_a2a_forwards_context_id() -> None:
|
||||
message = Message(
|
||||
role="user",
|
||||
contents=[Content.from_text(text="Continue the task")],
|
||||
additional_properties={"context_id": "ctx-123", "trace_id": "trace-456"},
|
||||
additional_properties={"context_id": "ctx-123", "a2a_metadata": {"trace_id": "trace-456"}},
|
||||
)
|
||||
|
||||
result = agent._prepare_message_for_a2a(message)
|
||||
@@ -1385,3 +1385,210 @@ async def test_streaming_terminal_task_only_emits_unstreamed_artifacts(
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Metadata propagation tests
|
||||
|
||||
|
||||
async def test_message_metadata_propagated(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
    """Metadata set on an ``A2AMessage`` reply is exposed under ``additional_properties["a2a_metadata"]``."""
    reply = A2AMessage(
        message_id="msg-meta",
        role=A2ARole.agent,
        parts=[Part(root=TextPart(text="hi"))],
        metadata={"source": "server", "trace_id": "abc"},
    )
    mock_a2a_client.responses.append(reply)

    result = await a2a_agent.run("hello")

    # Both keys of the wire metadata must be reachable through the namespaced entry.
    propagated = result.additional_properties["a2a_metadata"]
    assert propagated["source"] == "server"
    assert propagated["trace_id"] == "abc"
|
||||
|
||||
|
||||
async def test_artifact_metadata_propagated(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
    """Metadata attached to a task artifact is exposed under ``additional_properties["a2a_metadata"]``."""
    annotated_artifact = Artifact(
        artifact_id="a1",
        parts=[Part(root=TextPart(text="result"))],
        metadata={"artifact_key": "artifact_value"},
    )
    completed_task = Task(
        id="task-art-meta",
        context_id="ctx",
        status=TaskStatus(state=TaskState.completed),
        artifacts=[annotated_artifact],
    )
    # The mock client yields (task, update_event) tuples; no update event accompanies this task.
    mock_a2a_client.responses.append((completed_task, None))

    result = await a2a_agent.run("go")

    assert result.additional_properties["a2a_metadata"]["artifact_key"] == "artifact_value"
|
||||
|
||||
|
||||
async def test_task_metadata_propagated_to_response(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
    """Metadata set on a completed ``Task`` is exposed under ``additional_properties["a2a_metadata"]``."""
    # The artifact itself carries no metadata, so the asserted key can only come from the task.
    plain_artifact = Artifact(artifact_id="a1", parts=[Part(root=TextPart(text="done"))])
    completed_task = Task(
        id="task-meta",
        context_id="ctx",
        status=TaskStatus(state=TaskState.completed),
        artifacts=[plain_artifact],
        metadata={"task_key": "task_value"},
    )
    mock_a2a_client.responses.append((completed_task, None))

    result = await a2a_agent.run("go")

    assert result.additional_properties["a2a_metadata"]["task_key"] == "task_value"
|
||||
|
||||
|
||||
async def test_task_artifact_update_event_metadata_merged(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
    """A streamed artifact update exposes metadata from both the event and the artifact it wraps."""
    streamed_artifact = Artifact(
        artifact_id="a1",
        parts=[Part(root=TextPart(text="chunk"))],
        metadata={"from_artifact": True},
    )
    artifact_event = TaskArtifactUpdateEvent(
        task_id="task-ae",
        context_id="ctx",
        artifact=streamed_artifact,
        metadata={"from_event": True},
    )
    in_flight_task = Task(
        id="task-ae",
        context_id="ctx",
        status=TaskStatus(state=TaskState.working),
    )
    finished_task = Task(
        id="task-ae",
        context_id="ctx",
        status=TaskStatus(state=TaskState.completed),
        artifacts=[Artifact(artifact_id="a1", parts=[Part(root=TextPart(text="chunk"))])],
    )
    completion_event = TaskStatusUpdateEvent(
        task_id="task-ae",
        context_id="ctx",
        status=TaskStatus(state=TaskState.completed),
        final=True,
    )
    # Queue the artifact chunk first and the terminal status second.
    mock_a2a_client.responses.append((in_flight_task, artifact_event))
    mock_a2a_client.responses.append((finished_task, completion_event))

    updates: list[AgentResponseUpdate] = [update async for update in a2a_agent.run("hello", stream=True)]

    # The first emitted update is expected to be the artifact chunk.
    merged = updates[0].additional_properties["a2a_metadata"]
    assert merged["from_artifact"] is True
    assert merged["from_event"] is True
|
||||
|
||||
|
||||
async def test_task_status_update_event_metadata_merged(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
    """A streamed status update exposes metadata from both the event and its status message."""
    progress_message = A2AMessage(
        message_id="m1",
        role=A2ARole.agent,
        parts=[Part(root=TextPart(text="working..."))],
        metadata={"msg_key": "msg_val"},
    )
    progress_event = TaskStatusUpdateEvent(
        task_id="task-se",
        context_id="ctx",
        status=TaskStatus(state=TaskState.working, message=progress_message),
        final=False,
        metadata={"event_key": "event_val"},
    )
    in_flight_task = Task(
        id="task-se",
        context_id="ctx",
        status=TaskStatus(state=TaskState.working),
    )
    finished_task = Task(
        id="task-se",
        context_id="ctx",
        status=TaskStatus(state=TaskState.completed),
        artifacts=[Artifact(artifact_id="a1", parts=[Part(root=TextPart(text="done"))])],
    )
    completion_event = TaskStatusUpdateEvent(
        task_id="task-se",
        context_id="ctx",
        status=TaskStatus(state=TaskState.completed),
        final=True,
    )
    # Queue the in-progress status first and the terminal status second.
    mock_a2a_client.responses.append((in_flight_task, progress_event))
    mock_a2a_client.responses.append((finished_task, completion_event))

    updates: list[AgentResponseUpdate] = [update async for update in a2a_agent.run("hello", stream=True)]

    # The first emitted update is expected to be the in-progress status message.
    merged = updates[0].additional_properties["a2a_metadata"]
    assert merged["msg_key"] == "msg_val"
    assert merged["event_key"] == "event_val"
|
||||
|
||||
|
||||
async def test_history_message_metadata_propagated(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
    """Metadata on a message in ``Task.history`` is exposed under ``additional_properties["a2a_metadata"]``."""
    history_entry = A2AMessage(
        message_id="h1",
        role=A2ARole.agent,
        parts=[Part(root=TextPart(text="reply"))],
        metadata={"history_key": "history_value"},
    )
    # The task has no artifacts; its only content is the history entry.
    completed_task = Task(
        id="task-hist",
        context_id="ctx",
        status=TaskStatus(state=TaskState.completed),
        history=[history_entry],
    )
    mock_a2a_client.responses.append((completed_task, None))

    result = await a2a_agent.run("go")

    assert result.additional_properties["a2a_metadata"]["history_key"] == "history_value"
|
||||
|
||||
|
||||
async def test_continuation_token_update_carries_task_metadata(
    a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
    """A background run on a still-working task returns a continuation token plus the task metadata."""
    working_task = Task(
        id="task-cont",
        context_id="ctx",
        status=TaskStatus(state=TaskState.working),
        metadata={"bg_key": "bg_value"},
    )
    mock_a2a_client.responses.append((working_task, None))

    result = await a2a_agent.run("go", background=True)

    # The token signals the task is unfinished; metadata must still ride along.
    assert result.continuation_token is not None
    assert result.additional_properties["a2a_metadata"]["bg_key"] == "bg_value"
|
||||
|
||||
|
||||
async def test_none_metadata_leaves_additional_properties_empty(
    a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
    """A reply without metadata must leave ``additional_properties`` falsy (empty or unset)."""
    bare_reply = A2AMessage(
        message_id="msg-none",
        role=A2ARole.agent,
        parts=[Part(root=TextPart(text="no meta"))],
    )
    mock_a2a_client.responses.append(bare_reply)

    result = await a2a_agent.run("hello")

    assert not result.additional_properties
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -906,6 +906,9 @@ def _tools_to_dict( # pyright: ignore[reportUnusedFunction]
|
||||
if isinstance(tool_item, FunctionTool):
|
||||
results.append(tool_item.to_json_schema_spec())
|
||||
continue
|
||||
if isinstance(tool_item, BaseModel):
|
||||
results.append(tool_item.model_dump(exclude_none=True))
|
||||
continue
|
||||
if isinstance(tool_item, SerializationMixin):
|
||||
results.append(tool_item.to_dict())
|
||||
continue
|
||||
|
||||
@@ -1879,6 +1879,12 @@ def _process_update(response: ChatResponse | AgentResponse, update: ChatResponse
|
||||
response.finish_reason = update.finish_reason
|
||||
if update.model is not None:
|
||||
response.model = update.model
|
||||
if (
|
||||
isinstance(response, AgentResponse)
|
||||
and isinstance(update, AgentResponseUpdate)
|
||||
and update.finish_reason is not None
|
||||
):
|
||||
response.finish_reason = update.finish_reason
|
||||
response.continuation_token = update.continuation_token
|
||||
|
||||
|
||||
@@ -2435,6 +2441,7 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
response_id: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
created_at: CreatedAtT | None = None,
|
||||
finish_reason: FinishReasonLiteral | FinishReason | None = None,
|
||||
usage_details: UsageDetails | None = None,
|
||||
value: ResponseModelT | None = None,
|
||||
response_format: StructuredResponseFormat = None,
|
||||
@@ -2450,6 +2457,9 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
agent_id: The identifier of the agent that produced this response. Useful in multi-agent
|
||||
scenarios to track which agent generated the response.
|
||||
created_at: A timestamp for the chat response.
|
||||
finish_reason: The reason the model stopped generating. Common values include
|
||||
``"stop"`` (natural completion), ``"length"`` (token limit), and
|
||||
``"tool_calls"`` (the model invoked a tool).
|
||||
usage_details: The usage details for the chat response.
|
||||
value: The structured output of the agent run response, if applicable.
|
||||
response_format: Optional response format for the agent response.
|
||||
@@ -2476,6 +2486,7 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
self.response_id = response_id
|
||||
self.agent_id = agent_id
|
||||
self.created_at = created_at
|
||||
self.finish_reason = finish_reason
|
||||
self.usage_details = usage_details
|
||||
self._value: ResponseModelT | None = value
|
||||
self._response_format: type[BaseModel] | Mapping[str, Any] | None = response_format
|
||||
@@ -2688,6 +2699,7 @@ class AgentResponseUpdate(SerializationMixin):
|
||||
response_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
created_at: CreatedAtT | None = None,
|
||||
finish_reason: FinishReasonLiteral | FinishReason | None = None,
|
||||
continuation_token: ContinuationToken | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
raw_representation: Any | None = None,
|
||||
@@ -2703,6 +2715,9 @@ class AgentResponseUpdate(SerializationMixin):
|
||||
response_id: Optional ID of the response of which this update is a part.
|
||||
message_id: Optional ID of the message of which this update is a part.
|
||||
created_at: Optional timestamp for the chat response update.
|
||||
finish_reason: The reason the model stopped generating. Common values include
|
||||
``"stop"`` (natural completion), ``"length"`` (token limit), and
|
||||
``"tool_calls"`` (the model invoked a tool).
|
||||
continuation_token: Optional token for resuming a long-running background operation.
|
||||
When present, indicates the operation is still in progress.
|
||||
additional_properties: Optional additional properties associated with the chat response update.
|
||||
@@ -2729,6 +2744,7 @@ class AgentResponseUpdate(SerializationMixin):
|
||||
self.response_id = response_id
|
||||
self.message_id = message_id
|
||||
self.created_at = created_at
|
||||
self.finish_reason = finish_reason
|
||||
self.continuation_token = continuation_token
|
||||
self.additional_properties = _restore_compaction_annotation_in_additional_properties(
|
||||
additional_properties,
|
||||
@@ -2761,6 +2777,7 @@ def map_chat_to_agent_update(update: ChatResponseUpdate, agent_name: str | None)
|
||||
response_id=update.response_id,
|
||||
message_id=update.message_id,
|
||||
created_at=update.created_at,
|
||||
finish_reason=update.finish_reason, # type: ignore[arg-type]
|
||||
continuation_token=update.continuation_token,
|
||||
additional_properties=update.additional_properties,
|
||||
raw_representation=update,
|
||||
|
||||
@@ -59,6 +59,62 @@ class AgentExecutorResponse:
|
||||
agent_response: AgentResponse
|
||||
full_conversation: list[Message]
|
||||
|
||||
def with_text(self, text: str) -> "AgentExecutorResponse":
    """Return a copy of this response whose agent turn is replaced by a single assistant message.

    Intended for custom executors that rewrite an agent's output (upper-casing,
    summarising, ...) while still handing an ``AgentExecutorResponse`` downstream,
    so the conversation that preceded the agent turn travels along with the new text.

    Args:
        text: The replacement assistant message text.

    Returns:
        A new ``AgentExecutorResponse`` with the same ``executor_id``. Its
        ``agent_response`` holds exactly one assistant message containing ``text``;
        its ``full_conversation`` is the original conversation minus the trailing
        ``len(self.agent_response.messages)`` entries, followed by that new message.
        The instance it is called on is not modified.

    Example:
        .. code-block:: python

            from agent_framework import AgentExecutorResponse, WorkflowContext, executor


            @executor(
                id="upper_case_executor",
                input=AgentExecutorResponse,
                output=AgentExecutorResponse,
                workflow_output=str,
            )
            async def upper_case(
                response: AgentExecutorResponse,
                ctx: WorkflowContext[AgentExecutorResponse, str],
            ) -> None:
                upper_text = response.agent_response.text.upper()
                await ctx.send_message(response.with_text(upper_text))
                await ctx.yield_output(upper_text)
    """
    replacement = Message("assistant", [text])

    # The original agent turn is assumed to occupy the tail of full_conversation,
    # one entry per message in agent_response.
    turn_length = len(self.agent_response.messages)
    if turn_length:
        history = self.full_conversation[:-turn_length]
    else:
        # A ``[:-0]`` slice would be empty, so copy the whole conversation instead.
        history = list(self.full_conversation)

    return AgentExecutorResponse(
        executor_id=self.executor_id,
        agent_response=AgentResponse(messages=[replacement]),
        full_conversation=[*history, replacement],
    )
|
||||
|
||||
|
||||
class AgentExecutor(Executor):
|
||||
"""built-in executor that wraps an agent for handling messages.
|
||||
@@ -183,7 +239,25 @@ class AgentExecutor(Executor):
|
||||
"""Accept a raw user prompt string and run the agent.
|
||||
|
||||
The new string input will be added to the cache which is used as the conversation context for the agent run.
|
||||
|
||||
Warning:
|
||||
If the upstream executor received an ``AgentExecutorResponse`` but emits a plain
|
||||
``str``, this handler will be invoked instead of ``from_response``. This resets
|
||||
the conversation context because only the new string is added to the cache and
|
||||
all prior messages from the upstream agent are lost.
|
||||
|
||||
To preserve the full conversation when transforming agent output in a custom
|
||||
executor, use ``AgentExecutorResponse.with_text(...)`` so that the message type
|
||||
stays ``AgentExecutorResponse`` and ``from_response`` is called instead.
|
||||
"""
|
||||
if not self._cache and ctx.source_executor_ids != ["Workflow"]:
|
||||
logger.warning(
|
||||
"AgentExecutor '%s': from_str handler invoked with an empty cache. "
|
||||
"If you are chaining from an AgentExecutor, the upstream custom executor may be "
|
||||
"emitting a plain str instead of using AgentExecutorResponse.with_text(...), "
|
||||
"which causes the full conversation context to be lost.",
|
||||
self.id,
|
||||
)
|
||||
self._cache.extend(normalize_messages_input(text))
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
|
||||
@@ -268,6 +268,19 @@ def executor(
|
||||
forward references. When provided, takes precedence over introspection from the
|
||||
``WorkflowContext`` second generic parameter (W_OutT).
|
||||
|
||||
Warning:
|
||||
When placing a custom ``@executor`` **between** two ``AgentExecutor`` nodes, be
|
||||
careful about the output type. If the custom executor receives an
|
||||
``AgentExecutorResponse`` but emits a plain ``str``, the downstream
|
||||
``AgentExecutor.from_str`` handler is invoked instead of ``from_response``.
|
||||
This resets the conversation context because only the new string is added to
|
||||
the cache and all prior messages from the upstream agent are lost.
|
||||
|
||||
To preserve the full conversation, use
|
||||
``AgentExecutorResponse.with_text(new_text)`` to create a new response that
|
||||
keeps the prior history, and set ``output=AgentExecutorResponse`` on the
|
||||
decorator.
|
||||
|
||||
Returns:
|
||||
A FunctionExecutor instance that can be wired into a Workflow.
|
||||
|
||||
|
||||
@@ -16,12 +16,26 @@ from agent_framework._middleware import FunctionInvocationContext
|
||||
from agent_framework._tools import (
|
||||
_parse_annotation,
|
||||
_parse_inputs,
|
||||
_tools_to_dict,
|
||||
)
|
||||
from agent_framework.observability import OtelAttr
|
||||
|
||||
# region FunctionTool and tool decorator tests
|
||||
|
||||
|
||||
def test_tools_to_dict_supports_pydantic_tool_models() -> None:
    """A pydantic model passed as a tool is serialized to a plain dict without its ``None`` fields."""

    class HostedToolSpec(BaseModel):
        kind: str
        enabled: bool = True
        note: str | None = None

    serialized = _tools_to_dict([HostedToolSpec(kind="google_search")])

    # ``note`` was left at None and must not appear; the defaulted ``enabled`` must.
    expected = [{"kind": "google_search", "enabled": True}]
    assert serialized == expected
|
||||
|
||||
|
||||
def test_tool_decorator():
|
||||
"""Test the tool decorator."""
|
||||
|
||||
|
||||
@@ -40,8 +40,10 @@ from agent_framework._types import (
|
||||
_get_data_bytes_as_str,
|
||||
_parse_content_list,
|
||||
_parse_structured_response_value,
|
||||
_process_update,
|
||||
_validate_uri,
|
||||
add_usage_details,
|
||||
map_chat_to_agent_update,
|
||||
validate_tool_mode,
|
||||
)
|
||||
from agent_framework.exceptions import AdditionItemMismatch, ContentError
|
||||
@@ -4179,3 +4181,101 @@ def test_prepend_instructions_custom_role():
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region finish_reason
|
||||
|
||||
|
||||
def test_agent_response_init_with_finish_reason() -> None:
    """``AgentResponse`` keeps the ``finish_reason`` it was constructed with."""
    assistant_message = Message("assistant", [Content.from_text("test")])

    agent_response = AgentResponse(messages=[assistant_message], finish_reason="stop")

    assert agent_response.finish_reason == "stop"
|
||||
|
||||
|
||||
def test_agent_response_update_init_with_finish_reason() -> None:
    """``AgentResponseUpdate`` keeps the ``finish_reason`` it was constructed with."""
    text_content = Content.from_text("test")

    agent_update = AgentResponseUpdate(contents=[text_content], role="assistant", finish_reason="stop")

    assert agent_update.finish_reason == "stop"
|
||||
|
||||
|
||||
def test_map_chat_to_agent_update_forwards_finish_reason() -> None:
    """``map_chat_to_agent_update`` copies ``finish_reason`` and records the agent name as author."""
    source_update = ChatResponseUpdate(
        contents=[Content.from_text("test")],
        finish_reason="length",
    )

    mapped = map_chat_to_agent_update(source_update, agent_name="test_agent")

    assert mapped.finish_reason == "length"
    assert mapped.author_name == "test_agent"
|
||||
|
||||
|
||||
def test_process_update_propagates_finish_reason_to_agent_response() -> None:
    """``_process_update`` copies a non-None ``finish_reason`` from the update onto the response."""
    # The response starts without a finish_reason.
    accumulated = AgentResponse(messages=[Message("assistant", [Content.from_text("test")])])
    incoming = AgentResponseUpdate(
        contents=[Content.from_text("more text")],
        role="assistant",
        finish_reason="stop",
    )

    _process_update(accumulated, incoming)

    assert accumulated.finish_reason == "stop"
|
||||
|
||||
|
||||
def test_process_update_does_not_overwrite_with_none() -> None:
    """``_process_update`` leaves an existing ``finish_reason`` alone when the update carries ``None``."""
    accumulated = AgentResponse(
        messages=[Message("assistant", [Content.from_text("test")])],
        finish_reason="length",
    )
    incoming = AgentResponseUpdate(
        contents=[Content.from_text("more text")],
        role="assistant",
        finish_reason=None,
    )

    _process_update(accumulated, incoming)

    # The previously recorded reason must survive the reason-less update.
    assert accumulated.finish_reason == "length"
|
||||
|
||||
|
||||
def test_agent_response_serialization_includes_finish_reason() -> None:
    """``AgentResponse.to_dict()`` emits the ``finish_reason`` key with the stored value."""
    assistant_message = Message("assistant", [Content.from_text("test")])
    agent_response = AgentResponse(
        messages=[assistant_message],
        response_id="test_123",
        finish_reason="stop",
    )

    serialized = agent_response.to_dict()

    # Check presence separately so a missing key fails with a clearer message.
    assert "finish_reason" in serialized
    assert serialized["finish_reason"] == "stop"
|
||||
|
||||
|
||||
def test_agent_response_update_serialization_includes_finish_reason() -> None:
    """``AgentResponseUpdate.to_dict()`` emits the ``finish_reason`` key with the stored value."""
    agent_update = AgentResponseUpdate(
        contents=[Content.from_text("test")],
        role="assistant",
        response_id="test_456",
        finish_reason="tool_calls",
    )

    serialized = agent_update.to_dict()

    # Check presence separately so a missing key fails with a clearer message.
    assert "finish_reason" in serialized
    assert serialized["finish_reason"] == "tool_calls"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -23,6 +23,7 @@ from agent_framework import (
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
WorkflowRunState,
|
||||
executor,
|
||||
handler,
|
||||
)
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
@@ -478,3 +479,90 @@ async def test_from_response_preserves_service_session_id() -> None:
|
||||
assert result.get_outputs() is not None
|
||||
|
||||
assert spy_agent._captured_service_session_id == "resp_PREVIOUS_RUN" # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@executor(
    id="upper_case_executor",
    input=AgentExecutorResponse,
    output=AgentExecutorResponse,
    workflow_output=str,
)
async def _upper_case_executor(
    response: AgentExecutorResponse,
    ctx: WorkflowContext[AgentExecutorResponse, str],
) -> None:
    # Upper-case the agent's reply and forward it as an AgentExecutorResponse
    # (via with_text) rather than a bare str, then also yield the text as
    # workflow output.
    shouted = response.agent_response.text.upper()
    forwarded = response.with_text(shouted)
    await ctx.send_message(forwarded)
    await ctx.yield_output(shouted)
|
||||
|
||||
|
||||
async def test_with_text_preserves_full_conversation_through_custom_executor() -> None:
    """A custom executor that forwards ``with_text`` output keeps the whole conversation intact."""
    # Mirrors the reproduction from issue #5246:
    # agent1 ("User likes sky red") -> agent2 ("User likes sky blue") -> upper_case -> agent3 ("User likes sky green")
    agent1, agent2, agent3 = (
        AgentExecutor(_SimpleAgent(id=agent_id, name=agent_name, reply_text=reply), id=agent_id)
        for agent_id, agent_name, reply in (
            ("agent1", "ContextAgent1", "User likes sky red"),
            ("agent2", "ContextAgent2", "User likes sky blue"),
            ("agent3", "ContextAgent3", "User likes sky green"),
        )
    )
    capturer = _CaptureFullConversation(id="capture")

    builder = WorkflowBuilder(start_executor=agent1, output_executors=[capturer])
    workflow = builder.add_chain([agent1, agent2, _upper_case_executor, agent3, capturer]).build()

    run_result = await workflow.run("")
    captured = next(output for output in run_result.get_outputs() if isinstance(output, dict))

    # The final agent must see the full conversation: user, agent1, UPPER(agent2), agent3
    assert captured["roles"] == ["user", "assistant", "assistant", "assistant"]
    texts = captured["texts"]
    assert texts[1] == "User likes sky red"
    assert texts[2] == "USER LIKES SKY BLUE"
    assert texts[3] == "User likes sky green"
|
||||
|
||||
|
||||
async def test_with_text_does_not_mutate_original() -> None:
    """``with_text`` builds a separate instance and leaves the source response untouched."""
    source = AgentExecutorResponse(
        executor_id="test_exec",
        agent_response=AgentResponse(messages=[Message("assistant", ["original reply"])]),
        full_conversation=[Message("user", ["prompt"]), Message("assistant", ["original reply"])],
    )

    rewritten = source.with_text("transformed reply")

    # The copy carries the new text as its final assistant message...
    assert rewritten is not source
    assert rewritten.agent_response.text == "transformed reply"
    last_message = rewritten.full_conversation[-1]
    assert last_message.text == "transformed reply"
    assert last_message.role == "assistant"
    # ...while the source still reports the original reply.
    assert source.agent_response.text == "original reply"
    assert source.full_conversation[-1].text == "original reply"
|
||||
|
||||
|
||||
async def test_with_text_strips_multi_message_agent_turn() -> None:
    """``with_text`` removes every message of a multi-message agent turn, not just the last one."""
    question = Message("user", ["question"])
    # A three-message agent turn: tool call, tool result, final answer.
    agent_turn = [
        Message("assistant", ["<tool_call>"]),
        Message("tool", ["<result>"]),
        Message("assistant", ["actual answer"]),
    ]
    source = AgentExecutorResponse(
        executor_id="exec",
        agent_response=AgentResponse(messages=list(agent_turn)),
        full_conversation=[question, *agent_turn],
    )

    rewritten = source.with_text("summarised answer")

    # Only the pre-agent-turn messages should remain, plus the replacement
    conversation = rewritten.full_conversation
    assert len(conversation) == 2
    assert conversation[0].text == "question"
    assert conversation[1].text == "summarised answer"
    assert rewritten.agent_response.text == "summarised answer"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Gemini Package (agent-framework-gemini)
|
||||
|
||||
Integration with Google's Gemini API via the `google-genai` SDK.
|
||||
Integration with Google's Gemini Developer API and Vertex AI via the `google-genai` SDK.
|
||||
|
||||
## Core Classes
|
||||
|
||||
@@ -8,6 +8,7 @@ Integration with Google's Gemini API via the `google-genai` SDK.
|
||||
- **`GeminiChatClient`** - Full-featured chat client with function invocation, middleware, and telemetry
|
||||
- **`GeminiChatOptions`** - Options TypedDict for Gemini-specific parameters
|
||||
- **`GeminiSettings`** - Settings loaded from environment variables
|
||||
- **`GoogleGeminiSettings`** - SDK-standard `GOOGLE_*` settings loaded from environment variables
|
||||
- **`ThinkingConfig`** - Configuration for extended thinking
|
||||
|
||||
## Gemini-specific Options
|
||||
|
||||
@@ -12,11 +12,28 @@ The Gemini integration enables Microsoft Agent Framework applications to call Go
|
||||
|
||||
## Authentication
|
||||
|
||||
Obtain an API key from [Google AI Studio](https://aistudio.google.com/apikey) and set it via environment variable:
|
||||
The connector supports both `google-genai` authentication modes.
|
||||
|
||||
### Gemini Developer API
|
||||
|
||||
Obtain an API key from [Google AI Studio](https://aistudio.google.com/apikey) and set either the package-prefixed or SDK-standard environment variable:
|
||||
|
||||
```bash
|
||||
export GEMINI_API_KEY="your-api-key"
|
||||
export GEMINI_MODEL="gemini-2.5-flash"
|
||||
# or: export GOOGLE_API_KEY="your-api-key"
|
||||
export GEMINI_MODEL="gemini-2.5-flash-lite"
|
||||
# or: export GOOGLE_MODEL="gemini-2.5-flash-lite"
|
||||
```
|
||||
|
||||
### Vertex AI
|
||||
|
||||
Set the standard Vertex AI environment variables used by `google-genai`:
|
||||
|
||||
```bash
|
||||
export GOOGLE_GENAI_USE_VERTEXAI=true
|
||||
export GOOGLE_CLOUD_PROJECT="your-project-id"
|
||||
export GOOGLE_CLOUD_LOCATION="global"
|
||||
export GOOGLE_MODEL="gemini-2.5-flash-lite"
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
@@ -2,7 +2,14 @@
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from ._chat_client import GeminiChatClient, GeminiChatOptions, GeminiSettings, RawGeminiChatClient, ThinkingConfig
|
||||
from ._chat_client import (
|
||||
GeminiChatClient,
|
||||
GeminiChatOptions,
|
||||
GeminiSettings,
|
||||
GoogleGeminiSettings,
|
||||
RawGeminiChatClient,
|
||||
ThinkingConfig,
|
||||
)
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
@@ -13,6 +20,7 @@ __all__ = [
|
||||
"GeminiChatClient",
|
||||
"GeminiChatOptions",
|
||||
"GeminiSettings",
|
||||
"GoogleGeminiSettings",
|
||||
"RawGeminiChatClient",
|
||||
"ThinkingConfig",
|
||||
"__version__",
|
||||
|
||||
@@ -30,6 +30,7 @@ from agent_framework import (
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from google import genai
|
||||
from google.auth.credentials import Credentials
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -54,6 +55,7 @@ __all__ = [
|
||||
"GeminiChatClient",
|
||||
"GeminiChatOptions",
|
||||
"GeminiSettings",
|
||||
"GoogleGeminiSettings",
|
||||
"RawGeminiChatClient",
|
||||
"ThinkingConfig",
|
||||
]
|
||||
@@ -161,10 +163,74 @@ class GeminiSettings(TypedDict, total=False):
|
||||
model: str | None
|
||||
|
||||
|
||||
class GoogleGeminiSettings(TypedDict, total=False):
    """Google SDK configuration settings loaded from ``GOOGLE_*`` environment variables.

    Every key is optional (``total=False``) and may also be ``None``.
    """

    # API key (``GOOGLE_API_KEY``), wrapped so the secret is not exposed accidentally.
    api_key: SecretString | None
    # Default model identifier (``GOOGLE_MODEL``).
    model: str | None
    # Whether to target Vertex AI (``GOOGLE_GENAI_USE_VERTEXAI``).
    genai_use_vertexai: bool | None
    # Google Cloud project ID (``GOOGLE_CLOUD_PROJECT``).
    cloud_project: str | None
    # Google Cloud / Vertex AI location (``GOOGLE_CLOUD_LOCATION``).
    cloud_location: str | None
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
_GEMINI_SERVICE_URL = "https://generativelanguage.googleapis.com"
|
||||
_GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com"
|
||||
_VERTEX_AI_BASE_URL = "https://aiplatform.googleapis.com"
|
||||
|
||||
|
||||
def _resolve_vertexai_mode(client: genai.Client, *, fallback: bool | None = None) -> bool:
|
||||
"""Resolve whether a client targets Vertex AI, preferring the instantiated SDK client state."""
|
||||
api_client = getattr(client, "_api_client", None)
|
||||
vertexai = getattr(api_client, "vertexai", None)
|
||||
if isinstance(vertexai, bool):
|
||||
return vertexai
|
||||
return bool(fallback)
|
||||
|
||||
|
||||
def _resolve_service_url(client: genai.Client, *, vertexai: bool) -> str:
|
||||
"""Resolve the base service URL from the instantiated SDK client, with a stable fallback."""
|
||||
api_client = getattr(client, "_api_client", None)
|
||||
http_options = getattr(api_client, "_http_options", None)
|
||||
base_url = getattr(http_options, "base_url", None)
|
||||
if isinstance(base_url, str) and base_url:
|
||||
return base_url.rstrip("/")
|
||||
return _VERTEX_AI_BASE_URL if vertexai else _GEMINI_API_BASE_URL
|
||||
|
||||
|
||||
def _validate_client_auth_configuration(
|
||||
*,
|
||||
vertexai: bool | None,
|
||||
api_key: SecretString | None,
|
||||
project: str | None,
|
||||
location: str | None,
|
||||
credentials: Credentials | None,
|
||||
) -> None:
|
||||
"""Validate supported auth combinations before instantiating the SDK client."""
|
||||
if vertexai is not True:
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"Gemini client requires an API key when Vertex AI is not enabled. "
|
||||
"Set GOOGLE_API_KEY or GEMINI_API_KEY, or pass api_key explicitly."
|
||||
)
|
||||
return
|
||||
|
||||
if api_key is not None or credentials is not None or (project and location):
|
||||
return
|
||||
|
||||
if project or location:
|
||||
raise ValueError(
|
||||
"Gemini client requires both GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION "
|
||||
"when Vertex AI is enabled without an API key."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Gemini client requires Vertex AI credentials or configuration when Vertex AI is enabled. "
|
||||
"Provide GOOGLE_API_KEY for Vertex AI express mode, pass credentials, or set "
|
||||
"GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION."
|
||||
)
|
||||
|
||||
|
||||
# Keys mapping to a different GenerateContentConfig field name
|
||||
_OPTION_TRANSLATIONS: dict[str, str] = {
|
||||
@@ -210,7 +276,7 @@ class RawGeminiChatClient(
|
||||
BaseChatClient[GeminiChatOptionsT],
|
||||
Generic[GeminiChatOptionsT],
|
||||
):
|
||||
"""A raw Gemini chat client for the Google Gemini API without function invocation, middleware or telemetry.
|
||||
"""A raw Gemini chat client for Gemini Developer API or Vertex AI.
|
||||
|
||||
Use this when you want full control over the request pipeline. For instance, to opt out of
|
||||
telemetry, use custom middleware, or compose your own layers. If you want the full-featured
|
||||
@@ -224,6 +290,10 @@ class RawGeminiChatClient(
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
vertexai: bool | None = None,
|
||||
project: str | None = None,
|
||||
location: str | None = None,
|
||||
credentials: Credentials | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
client: genai.Client | None = None,
|
||||
@@ -232,11 +302,21 @@ class RawGeminiChatClient(
|
||||
"""Create a raw Gemini chat client.
|
||||
|
||||
Args:
|
||||
api_key: Google AI Studio API key. Falls back to ``GEMINI_API_KEY`` environment variable.
|
||||
model: Default model identifier. Falls back to ``GEMINI_MODEL`` environment variable.
|
||||
api_key: Gemini Developer API key. Falls back to environment settings, preferring
|
||||
``GOOGLE_API_KEY`` over ``GEMINI_API_KEY``.
|
||||
model: Default model identifier. Falls back to environment settings, preferring
|
||||
``GOOGLE_MODEL`` over ``GEMINI_MODEL``.
|
||||
vertexai: Whether to use Vertex AI endpoints. Falls back to environment settings,
|
||||
using ``GOOGLE_GENAI_USE_VERTEXAI`` when not passed explicitly.
|
||||
project: Google Cloud project ID for Vertex AI. Falls back to environment settings,
|
||||
using ``GOOGLE_CLOUD_PROJECT`` when not passed explicitly.
|
||||
location: Vertex AI location. Falls back to environment settings,
|
||||
using ``GOOGLE_CLOUD_LOCATION`` when not passed explicitly.
|
||||
credentials: Google Cloud credentials for Vertex AI. When omitted, the SDK can use
|
||||
Application Default Credentials.
|
||||
env_file_path: Path to a ``.env`` file for credential loading.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
client: Pre-built ``genai.Client`` instance. When provided, ``api_key`` is not required.
|
||||
client: Pre-built ``genai.Client`` instance. When provided, connector auth settings are not required.
|
||||
additional_properties: Extra properties stored on the client instance.
|
||||
"""
|
||||
settings = load_settings(
|
||||
@@ -247,21 +327,58 @@ class RawGeminiChatClient(
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
google_settings = load_settings(
|
||||
GoogleGeminiSettings,
|
||||
env_prefix="GOOGLE_",
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
genai_use_vertexai=vertexai,
|
||||
cloud_project=project,
|
||||
cloud_location=location,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
configured_vertexai = google_settings.get("genai_use_vertexai")
|
||||
if client:
|
||||
self._genai_client = client
|
||||
else:
|
||||
resolved_key = settings.get("api_key")
|
||||
if not resolved_key:
|
||||
raise ValueError(
|
||||
"Gemini API key is required. Set via api_key parameter or GEMINI_API_KEY environment variable."
|
||||
)
|
||||
self._genai_client = genai.Client(
|
||||
api_key=resolved_key.get_secret_value(),
|
||||
http_options={"headers": {"x-goog-api-client": AGENT_FRAMEWORK_USER_AGENT}},
|
||||
resolved_key = google_settings.get("api_key") or settings.get("api_key")
|
||||
resolved_project = google_settings.get("cloud_project")
|
||||
resolved_location = google_settings.get("cloud_location")
|
||||
_validate_client_auth_configuration(
|
||||
vertexai=configured_vertexai,
|
||||
api_key=resolved_key,
|
||||
project=resolved_project,
|
||||
location=resolved_location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
self.model = settings.get("model")
|
||||
client_kwargs: dict[str, Any] = {
|
||||
"http_options": {"headers": {"x-goog-api-client": AGENT_FRAMEWORK_USER_AGENT}},
|
||||
}
|
||||
if configured_vertexai is not None:
|
||||
client_kwargs["vertexai"] = configured_vertexai
|
||||
|
||||
if resolved_key is not None and (
|
||||
configured_vertexai is not True
|
||||
or (credentials is None and not (resolved_project and resolved_location))
|
||||
):
|
||||
client_kwargs["api_key"] = resolved_key.get_secret_value()
|
||||
|
||||
if configured_vertexai is True and resolved_project:
|
||||
client_kwargs["project"] = resolved_project
|
||||
|
||||
if configured_vertexai is True and resolved_location:
|
||||
client_kwargs["location"] = resolved_location
|
||||
if configured_vertexai is True and credentials is not None:
|
||||
client_kwargs["credentials"] = credentials
|
||||
|
||||
self._genai_client = genai.Client(**client_kwargs)
|
||||
|
||||
self._vertexai = _resolve_vertexai_mode(self._genai_client, fallback=configured_vertexai)
|
||||
self._service_url = _resolve_service_url(self._genai_client, vertexai=self._vertexai)
|
||||
self.model = google_settings.get("model") or settings.get("model")
|
||||
|
||||
super().__init__(additional_properties=additional_properties)
|
||||
|
||||
@@ -414,12 +531,12 @@ class RawGeminiChatClient(
|
||||
|
||||
@override
|
||||
def service_url(self) -> str:
|
||||
"""Return the base URL of the Gemini API service.
|
||||
"""Return the base URL of the configured Gemini or Vertex AI service.
|
||||
|
||||
Returns:
|
||||
The Gemini API base URL.
|
||||
The resolved service base URL.
|
||||
"""
|
||||
return _GEMINI_SERVICE_URL
|
||||
return self._service_url
|
||||
|
||||
# region Request preparation
|
||||
|
||||
@@ -528,15 +645,16 @@ class RawGeminiChatClient(
|
||||
call_id = content.call_id or self._generate_tool_call_id()
|
||||
if content.name:
|
||||
call_id_to_name[call_id] = content.name
|
||||
parts.append(
|
||||
types.Part(
|
||||
function_call=types.FunctionCall(
|
||||
id=call_id,
|
||||
name=content.name or "",
|
||||
args=content.parse_arguments() or {},
|
||||
)
|
||||
)
|
||||
function_call = types.FunctionCall(
|
||||
id=call_id,
|
||||
name=content.name or "",
|
||||
args=content.parse_arguments() or {},
|
||||
)
|
||||
raw_part = content.raw_representation
|
||||
if isinstance(raw_part, types.Part) and raw_part.function_call is not None:
|
||||
parts.append(raw_part.model_copy(update={"function_call": function_call}, deep=True))
|
||||
else:
|
||||
parts.append(types.Part(function_call=function_call))
|
||||
case _:
|
||||
logger.debug("Skipping unsupported content type for Gemini: %s", content.type)
|
||||
return parts
|
||||
@@ -889,7 +1007,7 @@ class GeminiChatClient(
|
||||
RawGeminiChatClient[GeminiChatOptionsT],
|
||||
Generic[GeminiChatOptionsT],
|
||||
):
|
||||
"""Gemini chat client for the Google Gemini API with function invocation, middleware, and telemetry.
|
||||
"""Gemini chat client for Gemini Developer API or Vertex AI with function invocation, middleware, and telemetry.
|
||||
|
||||
This is the recommended client for most use cases. It builds on ``RawGeminiChatClient``
|
||||
and adds:
|
||||
@@ -908,6 +1026,10 @@ class GeminiChatClient(
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
vertexai: bool | None = None,
|
||||
project: str | None = None,
|
||||
location: str | None = None,
|
||||
credentials: Credentials | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
client: genai.Client | None = None,
|
||||
@@ -918,11 +1040,18 @@ class GeminiChatClient(
|
||||
"""Create a Gemini chat client.
|
||||
|
||||
Args:
|
||||
api_key: The Google AI Studio API key. Falls back to ``GEMINI_API_KEY`` environment variable.
|
||||
model: Default model identifier. Falls back to ``GEMINI_MODEL`` environment variable.
|
||||
api_key: Gemini Developer API key. Falls back to environment settings, preferring
|
||||
``GOOGLE_API_KEY`` over ``GEMINI_API_KEY``.
|
||||
model: Default model identifier. Falls back to environment settings, preferring
|
||||
``GOOGLE_MODEL`` over ``GEMINI_MODEL``.
|
||||
vertexai: Whether to use Vertex AI endpoints. Falls back to ``GOOGLE_GENAI_USE_VERTEXAI``.
|
||||
project: Google Cloud project ID for Vertex AI. Falls back to ``GOOGLE_CLOUD_PROJECT``.
|
||||
location: Vertex AI location. Falls back to ``GOOGLE_CLOUD_LOCATION``.
|
||||
credentials: Google Cloud credentials for Vertex AI. When omitted, the SDK can use
|
||||
Application Default Credentials.
|
||||
env_file_path: Path to a ``.env`` file for credential loading.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
client: Pre-built ``genai.Client`` instance. When provided, ``api_key`` is not required.
|
||||
client: Pre-built ``genai.Client`` instance. When provided, connector auth settings are not required.
|
||||
additional_properties: Extra properties stored on the client instance.
|
||||
middleware: Optional middleware chain applied to every call.
|
||||
function_invocation_configuration: Optional configuration for the function invocation loop.
|
||||
@@ -930,6 +1059,10 @@ class GeminiChatClient(
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
vertexai=vertexai,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
client=client,
|
||||
|
||||
@@ -14,5 +14,7 @@ This folder contains examples demonstrating how to use Google Gemini models with
|
||||
|
||||
## Environment Variables
|
||||
|
||||
- `GEMINI_API_KEY`: Your Google AI Studio API key (get one from [Google AI Studio](https://aistudio.google.com/apikey))
|
||||
- `GEMINI_MODEL`: The Gemini model to use (e.g., `gemini-2.5-flash`, `gemini-2.5-pro`)
|
||||
- `GOOGLE_MODEL` or `GEMINI_MODEL`: The Gemini model to use (for example,
|
||||
`gemini-2.5-flash-lite` or `gemini-2.5-pro`)
|
||||
- For Gemini Developer API: `GEMINI_API_KEY` or `GOOGLE_API_KEY`
|
||||
- For Vertex AI: `GOOGLE_GENAI_USE_VERTEXAI=true`, `GOOGLE_CLOUD_PROJECT`, and `GOOGLE_CLOUD_LOCATION`
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
|
||||
Allows the model to reason through complex problems before responding.
|
||||
|
||||
Requires the following environment variables to be set:
|
||||
- GEMINI_API_KEY
|
||||
- GEMINI_MODEL
|
||||
Requires ``GOOGLE_MODEL`` or ``GEMINI_MODEL`` and either Gemini Developer API credentials
|
||||
(``GEMINI_API_KEY`` or ``GOOGLE_API_KEY``) or Vertex AI settings
|
||||
(``GOOGLE_GENAI_USE_VERTEXAI``, ``GOOGLE_CLOUD_PROJECT``, and ``GOOGLE_CLOUD_LOCATION``).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -23,10 +23,12 @@ async def main() -> None:
|
||||
"""Example of extended thinking with a Python version comparison question."""
|
||||
print("=== Extended thinking ===")
|
||||
|
||||
# 1. Configure Gemini extended thinking for a reasoning-heavy request.
|
||||
options: GeminiChatOptions = {
|
||||
"thinking_config": ThinkingConfig(thinking_budget=2048),
|
||||
}
|
||||
|
||||
# 2. Create the agent with the Gemini chat client and default thinking options.
|
||||
agent = Agent(
|
||||
client=GeminiChatClient(),
|
||||
name="PythonAgent",
|
||||
@@ -34,6 +36,7 @@ async def main() -> None:
|
||||
default_options=options,
|
||||
)
|
||||
|
||||
# 3. Stream the answer so you can see the final response as it arrives.
|
||||
query = "What new language features were introduced in Python between 3.10 and 3.14?"
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
@@ -45,3 +48,12 @@ async def main() -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== Extended thinking ===
|
||||
User: What new language features were introduced in Python between 3.10 and 3.14?
|
||||
Agent: Python 3.11 introduced exception groups and TaskGroup.
|
||||
Python 3.12 added PEP 695 type parameter syntax.
|
||||
Python 3.13-3.14 continued improving typing, performance, and developer ergonomics.
|
||||
"""
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
|
||||
Covers both non-streaming and streaming responses.
|
||||
|
||||
Requires the following environment variables to be set:
|
||||
- GEMINI_API_KEY
|
||||
- GEMINI_MODEL
|
||||
Requires ``GOOGLE_MODEL`` or ``GEMINI_MODEL`` and either Gemini Developer API credentials
|
||||
(``GEMINI_API_KEY`` or ``GOOGLE_API_KEY``) or Vertex AI settings
|
||||
(``GOOGLE_GENAI_USE_VERTEXAI``, ``GOOGLE_CLOUD_PROJECT``, and ``GOOGLE_CLOUD_LOCATION``).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -35,6 +35,7 @@ async def non_streaming_example() -> None:
|
||||
"""Runs the agent and waits for the complete response before printing it."""
|
||||
print("=== Non-streaming ===")
|
||||
|
||||
# 1. Create the agent with the Gemini chat client and local weather tool.
|
||||
agent = Agent(
|
||||
client=GeminiChatClient(),
|
||||
name="WeatherAgent",
|
||||
@@ -42,6 +43,7 @@ async def non_streaming_example() -> None:
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
# 2. Ask the agent for a single weather lookup and print the final response.
|
||||
query = "What's the weather like in Karlsruhe, Germany?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
@@ -52,6 +54,7 @@ async def streaming_example() -> None:
|
||||
"""Runs the agent and prints each chunk as it is received."""
|
||||
print("=== Streaming ===")
|
||||
|
||||
# 1. Create the same agent configuration for a streaming tool-call example.
|
||||
agent = Agent(
|
||||
client=GeminiChatClient(),
|
||||
name="WeatherAgent",
|
||||
@@ -59,6 +62,7 @@ async def streaming_example() -> None:
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
# 2. Ask a multi-location question and stream the model output as it arrives.
|
||||
query = "What's the weather like in Portland and in Paris?"
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
@@ -76,3 +80,14 @@ async def main() -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== Non-streaming ===
|
||||
User: What's the weather like in Karlsruhe, Germany?
|
||||
Result: The weather in Karlsruhe, Germany is currently sunny with a high of 16°C.
|
||||
|
||||
=== Streaming ===
|
||||
User: What's the weather like in Portland and in Paris?
|
||||
Agent: In Portland, it is currently rainy with a high of 11°C. In Paris, it is cloudy with a high of 27°C.
|
||||
"""
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
|
||||
Allows the model to write and run code in a sandboxed environment to answer questions.
|
||||
|
||||
Requires the following environment variables to be set:
|
||||
- GEMINI_API_KEY
|
||||
- GEMINI_MODEL
|
||||
Requires ``GOOGLE_MODEL`` or ``GEMINI_MODEL`` and either Gemini Developer API credentials
|
||||
(``GEMINI_API_KEY`` or ``GOOGLE_API_KEY``) or Vertex AI settings
|
||||
(``GOOGLE_GENAI_USE_VERTEXAI``, ``GOOGLE_CLOUD_PROJECT``, and ``GOOGLE_CLOUD_LOCATION``).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -23,6 +23,7 @@ async def main() -> None:
|
||||
"""Run the code execution example."""
|
||||
print("=== Code execution ===")
|
||||
|
||||
# 1. Create the agent with Gemini and the built-in code execution tool.
|
||||
agent = Agent(
|
||||
client=GeminiChatClient(),
|
||||
name="CodeAgent",
|
||||
@@ -30,6 +31,7 @@ async def main() -> None:
|
||||
tools=[GeminiChatClient.get_code_interpreter_tool()],
|
||||
)
|
||||
|
||||
# 2. Ask for a computed answer and stream the generated code and final result.
|
||||
query = "What are the first 20 prime numbers? Compute them in code."
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
@@ -41,3 +43,10 @@ async def main() -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== Code execution ===
|
||||
User: What are the first 20 prime numbers? Compute them in code.
|
||||
Agent: The first 20 prime numbers are 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, and 71.
|
||||
"""
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
|
||||
Allows Gemini to retrieve location and mapping information before responding.
|
||||
|
||||
Requires the following environment variables to be set:
|
||||
- GEMINI_API_KEY
|
||||
- GEMINI_MODEL
|
||||
Requires ``GOOGLE_MODEL`` or ``GEMINI_MODEL`` and either Gemini Developer API credentials
|
||||
(``GEMINI_API_KEY`` or ``GOOGLE_API_KEY``) or Vertex AI settings
|
||||
(``GOOGLE_GENAI_USE_VERTEXAI``, ``GOOGLE_CLOUD_PROJECT``, and ``GOOGLE_CLOUD_LOCATION``).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -23,6 +23,7 @@ async def main() -> None:
|
||||
"""Run the Google Maps grounding example."""
|
||||
print("=== Google Maps grounding ===")
|
||||
|
||||
# 1. Create the agent with Gemini and the built-in Google Maps grounding tool.
|
||||
agent = Agent(
|
||||
client=GeminiChatClient(),
|
||||
name="MapsAgent",
|
||||
@@ -30,6 +31,7 @@ async def main() -> None:
|
||||
tools=[GeminiChatClient.get_maps_grounding_tool()],
|
||||
)
|
||||
|
||||
# 2. Ask a location-aware question and stream the grounded answer.
|
||||
query = "What are some highly rated restaurants in the city center of Karlsruhe, Germany?"
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
@@ -41,3 +43,11 @@ async def main() -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== Google Maps grounding ===
|
||||
User: What are some highly rated restaurants in the city center of Karlsruhe, Germany?
|
||||
Agent: Here are several highly rated restaurants near Karlsruhe city center,
|
||||
along with their cuisine styles and approximate walking distance.
|
||||
"""
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
|
||||
Allows Gemini to retrieve up-to-date information from the web before responding.
|
||||
|
||||
Requires the following environment variables to be set:
|
||||
- GEMINI_API_KEY
|
||||
- GEMINI_MODEL
|
||||
Requires ``GOOGLE_MODEL`` or ``GEMINI_MODEL`` and either Gemini Developer API credentials
|
||||
(``GEMINI_API_KEY`` or ``GOOGLE_API_KEY``) or Vertex AI settings
|
||||
(``GOOGLE_GENAI_USE_VERTEXAI``, ``GOOGLE_CLOUD_PROJECT``, and ``GOOGLE_CLOUD_LOCATION``).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -23,6 +23,7 @@ async def main() -> None:
|
||||
"""Run the Google Search grounding example."""
|
||||
print("=== Google Search grounding ===")
|
||||
|
||||
# 1. Create the agent with Gemini and the built-in Google Search grounding tool.
|
||||
agent = Agent(
|
||||
client=GeminiChatClient(),
|
||||
name="SearchAgent",
|
||||
@@ -30,6 +31,7 @@ async def main() -> None:
|
||||
tools=[GeminiChatClient.get_web_search_tool()],
|
||||
)
|
||||
|
||||
# 2. Ask a current-events style question and stream the grounded answer.
|
||||
query = "What is the latest stable release of the .NET SDK?"
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
@@ -41,3 +43,10 @@ async def main() -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== Google Search grounding ===
|
||||
User: What is the latest stable release of the .NET SDK?
|
||||
Agent: As of April 14, 2026, the latest stable release of the .NET SDK is .NET 10.0 (SDK 10.0.201).
|
||||
"""
|
||||
|
||||
@@ -15,12 +15,28 @@ from pydantic import BaseModel
|
||||
|
||||
from agent_framework_gemini import GeminiChatClient, GeminiChatOptions, ThinkingConfig
|
||||
|
||||
skip_if_no_api_key = pytest.mark.skipif(
|
||||
not os.getenv("GEMINI_API_KEY"),
|
||||
reason="GEMINI_API_KEY not set; skipping integration tests.",
|
||||
|
||||
def _has_gemini_integration_credentials() -> bool:
|
||||
"""Return whether integration credentials for either Gemini API or Vertex AI appear to be configured."""
|
||||
if os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"):
|
||||
return True
|
||||
|
||||
if os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() in {"true", "1", "yes", "on"}:
|
||||
return bool(
|
||||
os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
or os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
|
||||
or os.getenv("GOOGLE_API_KEY")
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
skip_if_no_credentials = pytest.mark.skipif(
|
||||
not _has_gemini_integration_credentials(),
|
||||
reason="Gemini Developer API or Vertex AI credentials not set; skipping integration tests.",
|
||||
)
|
||||
|
||||
_TEST_MODEL = "gemini-2.5-flash"
|
||||
_TEST_MODEL = os.getenv("GOOGLE_MODEL") or os.getenv("GEMINI_MODEL", "gemini-2.5-flash-lite")
|
||||
|
||||
# stub helpers
|
||||
|
||||
@@ -89,6 +105,7 @@ def _make_response(
|
||||
candidate.finish_reason = None
|
||||
|
||||
response.candidates = [candidate]
|
||||
response.finish_reason = finish_reason
|
||||
response.model_version = model_version
|
||||
|
||||
if prompt_tokens is not None or output_tokens is not None:
|
||||
@@ -115,6 +132,8 @@ def _make_gemini_client(
|
||||
) -> tuple[GeminiChatClient, MagicMock]:
|
||||
"""Return a (GeminiChatClient, mock_genai_client) pair."""
|
||||
mock = mock_client or MagicMock()
|
||||
mock._api_client.vertexai = False
|
||||
mock._api_client._http_options.base_url = "https://generativelanguage.googleapis.com/"
|
||||
client = GeminiChatClient(client=mock, model=model)
|
||||
return client, mock
|
||||
|
||||
@@ -135,12 +154,134 @@ def test_client_created_from_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert client.model == "gemini-2.5-flash"
|
||||
|
||||
|
||||
def test_missing_api_key_raises_when_no_client_injected(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Raises ValueError at construction when neither an API key nor a pre-built client is available."""
|
||||
def test_client_created_from_google_api_key_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Initialises successfully when the SDK-standard Google API key environment variable is set."""
|
||||
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("GEMINI_MODEL", raising=False)
|
||||
monkeypatch.delenv("GOOGLE_GENAI_USE_VERTEXAI", raising=False)
|
||||
monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False)
|
||||
monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
|
||||
monkeypatch.setenv("GOOGLE_API_KEY", "test-key-123")
|
||||
monkeypatch.setenv("GOOGLE_MODEL", "gemini-2.5-flash-lite")
|
||||
|
||||
with pytest.raises(ValueError, match="GEMINI_API_KEY"):
|
||||
mock_client = MagicMock()
|
||||
mock_client._api_client.vertexai = False
|
||||
mock_client._api_client._http_options.base_url = "https://generativelanguage.googleapis.com/"
|
||||
|
||||
with patch("agent_framework_gemini._chat_client.genai.Client") as client_factory:
|
||||
client_factory.return_value = mock_client
|
||||
client = GeminiChatClient()
|
||||
|
||||
assert client_factory.call_args.kwargs["api_key"] == "test-key-123"
|
||||
assert "vertexai" not in client_factory.call_args.kwargs
|
||||
assert client.model == "gemini-2.5-flash-lite"
|
||||
assert client.service_url() == "https://generativelanguage.googleapis.com"
|
||||
|
||||
|
||||
def test_client_created_from_vertex_ai_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Initialises a Vertex AI client when the SDK-standard Vertex AI environment variables are set.

    No API key is present in the environment, so the SDK client is expected to be built
    from ``vertexai``/``project``/``location`` only.
    """
    monkeypatch.delenv("GEMINI_API_KEY", raising=False)
    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
    monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project")
    monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "global")

    # Stand-in SDK client exposing the private state the connector inspects.
    mock_client = MagicMock()
    mock_client._api_client.vertexai = True
    mock_client._api_client._http_options.base_url = "https://aiplatform.googleapis.com/"

    with patch("agent_framework_gemini._chat_client.genai.Client", return_value=mock_client) as client_factory:
        client = GeminiChatClient()

    assert client_factory.call_args.kwargs["vertexai"] is True
    assert client_factory.call_args.kwargs["project"] == "test-project"
    assert client_factory.call_args.kwargs["location"] == "global"
    assert "api_key" not in client_factory.call_args.kwargs
    # The trailing slash of the mocked base URL is expected to be stripped.
    assert client.service_url() == "https://aiplatform.googleapis.com"
|
||||
|
||||
|
||||
def test_google_settings_take_precedence_over_gemini_aliases(monkeypatch: pytest.MonkeyPatch) -> None:
    """Prefers SDK-standard ``GOOGLE_*`` settings when both env families are present.

    Both ``GEMINI_*`` and ``GOOGLE_*`` values are set, together with a complete Vertex AI
    project/location pair, so the API key is expected to be left out of the SDK client call.
    """
    monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
    monkeypatch.setenv("GEMINI_MODEL", "gemini-model")
    monkeypatch.setenv("GOOGLE_API_KEY", "google-key")
    monkeypatch.setenv("GOOGLE_MODEL", "google-model")
    monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "google-project")
    monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "global")

    # Stand-in SDK client exposing the private state the connector inspects.
    mock_client = MagicMock()
    mock_client._api_client.vertexai = True
    mock_client._api_client._http_options.base_url = "https://aiplatform.googleapis.com/"

    with patch("agent_framework_gemini._chat_client.genai.Client", return_value=mock_client) as client_factory:
        client = GeminiChatClient()

    assert client_factory.call_args.kwargs["vertexai"] is True
    assert client_factory.call_args.kwargs["project"] == "google-project"
    assert client_factory.call_args.kwargs["location"] == "global"
    assert "api_key" not in client_factory.call_args.kwargs
    assert client.model == "google-model"
    assert client.service_url() == "https://aiplatform.googleapis.com"
|
||||
|
||||
|
||||
def test_missing_api_key_raises_when_no_client_injected(monkeypatch: pytest.MonkeyPatch) -> None:
    """Checks that construction fails with ValueError when no Gemini or Vertex AI settings exist.

    Every relevant environment variable is cleared and no SDK client is
    injected, so only the explicit ``model`` argument is supplied.
    """
    cleared_variables = (
        "GEMINI_API_KEY",
        "GEMINI_MODEL",
        "GOOGLE_API_KEY",
        "GOOGLE_GENAI_USE_VERTEXAI",
        "GOOGLE_CLOUD_PROJECT",
        "GOOGLE_CLOUD_LOCATION",
    )
    for name in cleared_variables:
        monkeypatch.delenv(name, raising=False)

    expected_message = "requires an API key when Vertex AI is not enabled"
    with pytest.raises(ValueError, match=expected_message):
        GeminiChatClient(model="gemini-2.5-flash")
|
||||
|
||||
|
||||
def test_vertex_ai_express_mode_uses_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """Checks that Vertex AI without a project/location pair forwards the API key.

    With ``GOOGLE_GENAI_USE_VERTEXAI`` enabled, ``GOOGLE_API_KEY`` set, and no
    project or location configured, the SDK client factory must receive
    ``vertexai=True`` and the key, and neither ``project`` nor ``location``.
    """
    for name in ("GEMINI_API_KEY", "GEMINI_MODEL"):
        monkeypatch.delenv(name, raising=False)
    monkeypatch.setenv("GOOGLE_API_KEY", "test-key-123")
    monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
    for name in ("GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_LOCATION"):
        monkeypatch.delenv(name, raising=False)

    # Stand-in SDK client that identifies itself as a Vertex AI client.
    sdk_client = MagicMock()
    sdk_client._api_client.vertexai = True
    sdk_client._api_client._http_options.base_url = "https://aiplatform.googleapis.com/"

    with patch("agent_framework_gemini._chat_client.genai.Client", return_value=sdk_client) as client_factory:
        client = GeminiChatClient(model="gemini-2.5-flash-lite")

    factory_kwargs = client_factory.call_args.kwargs
    assert factory_kwargs["vertexai"] is True
    assert factory_kwargs["api_key"] == "test-key-123"
    assert "project" not in factory_kwargs
    assert "location" not in factory_kwargs
    assert client.service_url() == "https://aiplatform.googleapis.com"
|
||||
|
||||
|
||||
def test_vertex_ai_requires_configuration(monkeypatch: pytest.MonkeyPatch) -> None:
    """Checks that enabling Vertex AI with no key, project, or location raises ValueError."""
    for name in ("GEMINI_API_KEY", "GOOGLE_API_KEY"):
        monkeypatch.delenv(name, raising=False)
    monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
    for name in ("GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_LOCATION"):
        monkeypatch.delenv(name, raising=False)

    expected_message = "requires Vertex AI credentials or configuration"
    with pytest.raises(ValueError, match=expected_message):
        GeminiChatClient(model="gemini-2.5-flash")
|
||||
|
||||
|
||||
def test_vertex_ai_requires_project_and_location_together(monkeypatch: pytest.MonkeyPatch) -> None:
    """Checks that a project without a location raises ValueError when Vertex AI is enabled."""
    for name in ("GEMINI_API_KEY", "GOOGLE_API_KEY"):
        monkeypatch.delenv(name, raising=False)
    monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "true")
    # Only half of the project/location pair is provided.
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project")
    monkeypatch.delenv("GOOGLE_CLOUD_LOCATION", raising=False)

    expected_message = "requires both GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION"
    with pytest.raises(ValueError, match=expected_message):
        GeminiChatClient(model="gemini-2.5-flash")
|
||||
|
||||
|
||||
@@ -495,6 +636,30 @@ async def test_thinking_parts_are_silently_skipped() -> None:
|
||||
assert response.messages[0].text == "The answer is 42."
|
||||
|
||||
|
||||
def test_function_call_part_preserves_thought_signature_from_raw_part() -> None:
    """Checks that a function-call content keeps the ``thought_signature`` of its raw Gemini part.

    A ``Content`` built from a function call carries the original ``types.Part``
    as its raw representation; converting it back must yield exactly one part
    with the same signature and the same call id, name, and arguments.
    """
    client, _ = _make_gemini_client()

    call_arguments = {"location": "Paris"}
    original_part = types.Part(
        function_call=types.FunctionCall(id="call-1", name="get_weather", args={"location": "Paris"}),
        thought_signature=b"sig-123",
    )
    function_call_content = Content.from_function_call(
        call_id="call-1",
        name="get_weather",
        arguments=call_arguments,
        raw_representation=original_part,
    )

    parts = client._convert_message_contents([function_call_content], {})

    assert len(parts) == 1
    converted = parts[0]
    assert converted.thought_signature == b"sig-123"
    assert converted.function_call is not None
    assert converted.function_call.id == "call-1"
    assert converted.function_call.name == "get_weather"
    assert converted.function_call.args == {"location": "Paris"}
|
||||
|
||||
|
||||
# code execution parts
|
||||
|
||||
|
||||
@@ -1283,12 +1448,26 @@ def test_service_url() -> None:
|
||||
assert client.service_url() == "https://generativelanguage.googleapis.com"
|
||||
|
||||
|
||||
def test_service_url_falls_back_when_sdk_base_url_is_unavailable() -> None:
    """Checks the default service URLs when the injected SDK client has no configured base URL.

    The mocks set only the ``vertexai`` flag; the Gemini API client and the
    Vertex AI client must each report their own known endpoint.
    """

    def build_client(*, vertexai: bool) -> GeminiChatClient:
        # Inject a mock SDK client whose only configured attribute is the flag.
        sdk_client = MagicMock()
        sdk_client._api_client.vertexai = vertexai
        return GeminiChatClient(client=sdk_client, model="gemini-2.5-flash")

    gemini_client = build_client(vertexai=False)
    vertex_client = build_client(vertexai=True)

    assert gemini_client.service_url() == "https://generativelanguage.googleapis.com"
    assert vertex_client.service_url() == "https://aiplatform.googleapis.com"
|
||||
|
||||
|
||||
# integration tests
|
||||
|
||||
|
||||
@pytest.mark.flaky
|
||||
@pytest.mark.integration
|
||||
@skip_if_no_api_key
|
||||
@skip_if_no_credentials
|
||||
async def test_integration_basic_chat() -> None:
|
||||
"""Basic request/response round-trip returns a non-empty text reply."""
|
||||
client = GeminiChatClient(model=_TEST_MODEL)
|
||||
@@ -1302,7 +1481,7 @@ async def test_integration_basic_chat() -> None:
|
||||
|
||||
@pytest.mark.flaky
|
||||
@pytest.mark.integration
|
||||
@skip_if_no_api_key
|
||||
@skip_if_no_credentials
|
||||
async def test_integration_streaming() -> None:
|
||||
"""Streaming yields multiple chunks that together form a non-empty response."""
|
||||
client = GeminiChatClient(model=_TEST_MODEL)
|
||||
@@ -1319,7 +1498,7 @@ async def test_integration_streaming() -> None:
|
||||
|
||||
@pytest.mark.flaky
|
||||
@pytest.mark.integration
|
||||
@skip_if_no_api_key
|
||||
@skip_if_no_credentials
|
||||
async def test_integration_structured_output() -> None:
|
||||
"""Structured output with a Pydantic response_format returns a parsed value via response.value."""
|
||||
|
||||
@@ -1340,7 +1519,7 @@ async def test_integration_structured_output() -> None:
|
||||
|
||||
@pytest.mark.flaky
|
||||
@pytest.mark.integration
|
||||
@skip_if_no_api_key
|
||||
@skip_if_no_credentials
|
||||
async def test_integration_tool_calling() -> None:
|
||||
"""Model invokes the registered tool when asked a question that requires it."""
|
||||
|
||||
@@ -1363,7 +1542,7 @@ async def test_integration_tool_calling() -> None:
|
||||
|
||||
@pytest.mark.flaky
|
||||
@pytest.mark.integration
|
||||
@skip_if_no_api_key
|
||||
@skip_if_no_credentials
|
||||
async def test_integration_thinking_config() -> None:
|
||||
"""Model accepts a thinking budget and returns a non-empty text reply."""
|
||||
options: GeminiChatOptions = {"thinking_config": ThinkingConfig(thinking_budget=512)}
|
||||
@@ -1380,7 +1559,7 @@ async def test_integration_thinking_config() -> None:
|
||||
|
||||
@pytest.mark.flaky
|
||||
@pytest.mark.integration
|
||||
@skip_if_no_api_key
|
||||
@skip_if_no_credentials
|
||||
async def test_integration_google_search_grounding() -> None:
|
||||
"""Google Search grounding returns a non-empty response for a current-events question."""
|
||||
client = GeminiChatClient(model=_TEST_MODEL)
|
||||
@@ -1396,7 +1575,7 @@ async def test_integration_google_search_grounding() -> None:
|
||||
|
||||
@pytest.mark.flaky
|
||||
@pytest.mark.integration
|
||||
@skip_if_no_api_key
|
||||
@skip_if_no_credentials
|
||||
async def test_integration_google_maps_grounding() -> None:
|
||||
"""Google Maps grounding returns a non-empty response for a location-based question."""
|
||||
client = GeminiChatClient(model=_TEST_MODEL)
|
||||
@@ -1417,7 +1596,7 @@ async def test_integration_google_maps_grounding() -> None:
|
||||
|
||||
@pytest.mark.flaky
|
||||
@pytest.mark.integration
|
||||
@skip_if_no_api_key
|
||||
@skip_if_no_credentials
|
||||
async def test_integration_code_execution() -> None:
|
||||
"""Code execution tool produces a non-empty response for a computation request."""
|
||||
client = GeminiChatClient(model=_TEST_MODEL)
|
||||
|
||||
@@ -2474,6 +2474,29 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
raw_representation=event,
|
||||
)
|
||||
)
|
||||
elif ann_type == "url_citation":
|
||||
ann_url = _get_ann_value("url")
|
||||
if ann_url:
|
||||
ann_start = _get_ann_value("start_index")
|
||||
ann_end = _get_ann_value("end_index")
|
||||
annotation_obj = Annotation(
|
||||
type="citation",
|
||||
title=_get_ann_value("title") or "",
|
||||
url=str(ann_url),
|
||||
additional_properties={"annotation_index": event.annotation_index},
|
||||
raw_representation=annotation,
|
||||
)
|
||||
if ann_start is not None and ann_end is not None:
|
||||
annotation_obj["annotated_regions"] = [
|
||||
TextSpanRegion(
|
||||
type="text_span",
|
||||
start_index=ann_start,
|
||||
end_index=ann_end,
|
||||
)
|
||||
]
|
||||
contents.append(
|
||||
Content.from_text(text="", annotations=[annotation_obj], raw_representation=event)
|
||||
)
|
||||
else:
|
||||
logger.debug("Unparsed annotation type in streaming: %s", ann_type)
|
||||
case "response.output_item.done":
|
||||
|
||||
@@ -2570,8 +2570,65 @@ def test_streaming_annotation_added_with_container_file_citation() -> None:
|
||||
assert content.additional_properties.get("end_index") == 50
|
||||
|
||||
|
||||
def test_streaming_annotation_added_with_unknown_type() -> None:
|
||||
"""Test streaming annotation added event with unknown type is ignored."""
|
||||
def test_streaming_annotation_added_with_url_citation() -> None:
    """Checks that a streamed ``url_citation`` annotation becomes a citation annotation.

    The parsed chunk must contain a single text content whose only annotation
    carries the title, URL, annotation index, raw annotation, and one
    ``text_span`` region built from the start/end indices.
    """
    client = OpenAIChatClient(model="test-model", api_key="test-key")
    chat_options = ChatOptions()
    function_call_ids: dict[int, tuple[str, str]] = {}

    citation_url = "https://example.sharepoint.com/sites/my-site/doc.pdf"
    mock_event = MagicMock()
    mock_event.type = "response.output_text.annotation.added"
    mock_event.annotation_index = 0
    mock_event.annotation = {
        "type": "url_citation",
        "url": citation_url,
        "title": "doc.pdf",
        "start_index": 100,
        "end_index": 112,
    }

    response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids)

    # Exactly one text content carrying exactly one annotation.
    assert len(response.contents) == 1
    content = response.contents[0]
    assert content.type == "text"
    assert content.annotations is not None
    assert len(content.annotations) == 1

    annotation = content.annotations[0]
    assert annotation["type"] == "citation"
    assert annotation["title"] == "doc.pdf"
    assert annotation["url"] == citation_url
    assert annotation["additional_properties"]["annotation_index"] == 0
    assert annotation["raw_representation"] == mock_event.annotation

    # The start/end indices are surfaced as a single text-span region.
    regions = annotation["annotated_regions"]
    assert regions is not None
    assert len(regions) == 1
    region = regions[0]
    assert region["type"] == "text_span"
    assert region["start_index"] == 100
    assert region["end_index"] == 112
|
||||
|
||||
|
||||
def test_streaming_annotation_added_with_url_citation_no_url() -> None:
    """Checks that a streamed ``url_citation`` annotation lacking a ``url`` yields no content."""
    client = OpenAIChatClient(model="test-model", api_key="test-key")
    chat_options = ChatOptions()
    function_call_ids: dict[int, tuple[str, str]] = {}

    # Annotation payload deliberately omits the "url" key.
    citation_without_url = {
        "type": "url_citation",
        "title": "doc.pdf",
    }
    mock_event = MagicMock()
    mock_event.type = "response.output_text.annotation.added"
    mock_event.annotation_index = 0
    mock_event.annotation = citation_without_url

    response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids)

    assert len(response.contents) == 0
|
||||
|
||||
|
||||
def test_streaming_annotation_added_with_url_citation_no_indices() -> None:
|
||||
"""Test streaming annotation with url_citation that has url but no start_index/end_index."""
|
||||
client = OpenAIChatClient(model="test-model", api_key="test-key")
|
||||
chat_options = ChatOptions()
|
||||
function_call_ids: dict[int, tuple[str, str]] = {}
|
||||
@@ -2582,11 +2639,36 @@ def test_streaming_annotation_added_with_unknown_type() -> None:
|
||||
mock_event.annotation = {
|
||||
"type": "url_citation",
|
||||
"url": "https://example.com",
|
||||
"title": "Example",
|
||||
}
|
||||
|
||||
response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids)
|
||||
|
||||
assert len(response.contents) == 1
|
||||
annotation = response.contents[0].annotations[0]
|
||||
assert annotation["type"] == "citation"
|
||||
assert annotation["title"] == "Example"
|
||||
assert annotation["url"] == "https://example.com"
|
||||
assert annotation["additional_properties"]["annotation_index"] == 0
|
||||
assert "annotated_regions" not in annotation
|
||||
|
||||
|
||||
def test_streaming_annotation_added_with_unknown_type() -> None:
    """Checks that a streamed annotation of an unrecognised type yields no content."""
    client = OpenAIChatClient(model="test-model", api_key="test-key")
    chat_options = ChatOptions()
    function_call_ids: dict[int, tuple[str, str]] = {}

    unknown_annotation = {
        "type": "some_future_annotation_type",
        "data": "test",
    }
    mock_event = MagicMock()
    mock_event.type = "response.output_text.annotation.added"
    mock_event.annotation_index = 0
    mock_event.annotation = unknown_annotation

    response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids)

    # An annotation type the parser does not recognise must not produce any content.
    assert len(response.contents) == 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user