mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Migrate agent-framework-a2a to a2a-sdk v1.0 (#5752)
* Python: Migrate agent-framework-a2a to a2a-sdk v1.0
Upgrade the a2a-sdk dependency from v0.3.x to v1.0.0 and migrate all
source, tests, samples, and documentation to the v1.0 API.
Key changes:
- Dependency: a2a-sdk>=1.0.0,<2 (was >=0.3.5,<0.3.24)
- Types are now protobuf-based: Part replaces TextPart/FilePart/DataPart
- Enums use SCREAMING_SNAKE_CASE (e.g. TaskState.TASK_STATE_COMPLETED)
- Roles: Role.ROLE_AGENT, Role.ROLE_USER
- Client: SendMessageRequest wrapper, subscribe() replaces resubscribe()
- Server: A2AStarletteApplication replaced by Starlette + route factories
- DefaultRequestHandler now requires agent_card parameter
- TaskUpdater: final parameter removed, add_artifact gains last_chunk
- AgentCard.url removed; use supported_interfaces with AgentInterface
- Stream yields StreamResponse with WhichOneof('payload')
Closes #5661
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Address PR review: validate fallback URL, remove unused task_id vars
- Raise ValueError with clear message when transport negotiation fails
and no fallback URL is available (neither url arg nor supported_interfaces)
- Remove unused task_id local in status_update branch
- Inline artifact_event.task_id directly in artifact_update branch
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
---------
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
e3875f2c91
commit
4ad96b64e7
@@ -23,23 +23,28 @@ response = await a2a_agent.run("Hello!")
|
||||
|
||||
```python
|
||||
from agent_framework.a2a import A2AExecutor
|
||||
from a2a.server.apps import A2AStarletteApplication
|
||||
from a2a.server.request_handlers import DefaultRequestHandler
|
||||
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
|
||||
from a2a.server.tasks import InMemoryTaskStore
|
||||
from starlette.applications import Starlette
|
||||
|
||||
# Create an A2A executor for your agent
|
||||
executor = A2AExecutor(agent=my_agent)
|
||||
|
||||
# Set up the request handler and server application
|
||||
# Set up the request handler (agent_card is required)
|
||||
request_handler = DefaultRequestHandler(
|
||||
agent_executor=executor,
|
||||
task_store=InMemoryTaskStore(),
|
||||
agent_card=my_agent_card,
|
||||
)
|
||||
|
||||
app = A2AStarletteApplication(
|
||||
agent_card=my_agent_card,
|
||||
http_handler=request_handler,
|
||||
).build()
|
||||
# Build a Starlette app with A2A routes
|
||||
app = Starlette(
|
||||
routes=[
|
||||
*create_agent_card_routes(my_agent_card),
|
||||
*create_jsonrpc_routes(request_handler),
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
## Import Path
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from asyncio import CancelledError
|
||||
from collections.abc import Mapping
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from a2a.helpers import new_task_from_user_message
|
||||
from a2a.server.agent_execution import AgentExecutor, RequestContext
|
||||
from a2a.server.events import EventQueue
|
||||
from a2a.server.tasks import TaskUpdater
|
||||
from a2a.types import FilePart, FileWithBytes, FileWithUri, Part, TaskState, TextPart
|
||||
from a2a.utils import new_task
|
||||
from a2a.types import Part, TaskState
|
||||
from agent_framework import (
|
||||
AgentResponseUpdate,
|
||||
AgentSession,
|
||||
@@ -39,21 +40,24 @@ class A2AExecutor(AgentExecutor):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from a2a.server.apps import A2AStarletteApplication
|
||||
from a2a.server.request_handlers import DefaultRequestHandler
|
||||
from a2a.server.routes import create_jsonrpc_routes, create_agent_card_routes
|
||||
from a2a.server.tasks import InMemoryTaskStore
|
||||
from a2a.types import AgentCapabilities, AgentCard
|
||||
from a2a.types import AgentCapabilities, AgentCard, AgentInterface
|
||||
from agent_framework.a2a import A2AExecutor
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
from starlette.applications import Starlette
|
||||
|
||||
public_agent_card = AgentCard(
|
||||
name="Food Agent",
|
||||
description="A simple agent that provides food-related information.",
|
||||
url="http://localhost:9999/",
|
||||
version="1.0.0",
|
||||
defaultInputModes=["text"],
|
||||
defaultOutputModes=["text"],
|
||||
default_input_modes=["text"],
|
||||
default_output_modes=["text"],
|
||||
capabilities=AgentCapabilities(streaming=True),
|
||||
supported_interfaces=[
|
||||
AgentInterface(url="http://localhost:9999/", protocol_binding="JSONRPC"),
|
||||
],
|
||||
skills=[],
|
||||
)
|
||||
|
||||
@@ -68,12 +72,15 @@ class A2AExecutor(AgentExecutor):
|
||||
request_handler = DefaultRequestHandler(
|
||||
agent_executor=A2AExecutor(agent, stream=True, run_kwargs={"client_kwargs": {"max_tokens": 500}}),
|
||||
task_store=InMemoryTaskStore(),
|
||||
agent_card=public_agent_card,
|
||||
)
|
||||
|
||||
server = A2AStarletteApplication(
|
||||
agent_card=public_agent_card,
|
||||
http_handler=request_handler,
|
||||
).build()
|
||||
app = Starlette(
|
||||
routes=[
|
||||
*create_agent_card_routes(public_agent_card),
|
||||
*create_jsonrpc_routes(request_handler),
|
||||
],
|
||||
)
|
||||
|
||||
Args:
|
||||
agent: The AI agent to execute.
|
||||
@@ -143,7 +150,7 @@ class A2AExecutor(AgentExecutor):
|
||||
task = context.current_task
|
||||
|
||||
if not task:
|
||||
task = new_task(context.message)
|
||||
task = new_task_from_user_message(context.message)
|
||||
await event_queue.enqueue_event(task)
|
||||
|
||||
updater = TaskUpdater(event_queue, task.id, context.context_id)
|
||||
@@ -162,13 +169,12 @@ class A2AExecutor(AgentExecutor):
|
||||
# Mark as complete
|
||||
await updater.complete()
|
||||
except CancelledError:
|
||||
await updater.update_status(state=TaskState.canceled, final=True)
|
||||
await updater.update_status(state=TaskState.TASK_STATE_CANCELED)
|
||||
except Exception as e:
|
||||
logger.exception("A2AExecutor encountered an error during execution.", exc_info=e)
|
||||
await updater.update_status(
|
||||
state=TaskState.failed,
|
||||
final=True,
|
||||
message=updater.new_agent_message([Part(root=TextPart(text=str(e)))]),
|
||||
state=TaskState.TASK_STATE_FAILED,
|
||||
message=updater.new_agent_message([Part(text=str(e))]),
|
||||
)
|
||||
|
||||
async def _run_stream(self, query: Any, session: AgentSession, updater: TaskUpdater) -> None:
|
||||
@@ -221,9 +227,9 @@ class A2AExecutor(AgentExecutor):
|
||||
) -> None:
|
||||
# Custom logic to transform item contents
|
||||
if item.role == "assistant" and item.contents:
|
||||
parts = [Part(root=TextPart(text=f"Custom: {item.contents[0].text}"))]
|
||||
parts = [Part(text=f"Custom: {item.contents[0].text}")]
|
||||
await updater.update_status(
|
||||
state=TaskState.working,
|
||||
state=TaskState.TASK_STATE_WORKING,
|
||||
message=updater.new_agent_message(parts=parts),
|
||||
)
|
||||
else:
|
||||
@@ -242,12 +248,12 @@ class A2AExecutor(AgentExecutor):
|
||||
|
||||
for content in contents:
|
||||
if content.type == "text" and content.text:
|
||||
parts.append(Part(root=TextPart(text=content.text)))
|
||||
parts.append(Part(text=content.text))
|
||||
elif content.type == "data" and content.uri:
|
||||
base64_str = get_uri_data(content.uri)
|
||||
parts.append(Part(root=FilePart(file=FileWithBytes(bytes=base64_str, mime_type=content.media_type))))
|
||||
parts.append(Part(raw=base64.b64decode(base64_str), media_type=content.media_type or ""))
|
||||
elif content.type == "uri" and content.uri:
|
||||
parts.append(Part(root=FilePart(file=FileWithUri(uri=content.uri, mime_type=content.media_type))))
|
||||
parts.append(Part(url=content.uri, media_type=content.media_type or ""))
|
||||
else:
|
||||
# Silently skip unsupported content types
|
||||
logger.warning("A2AExecutor does not yet support content type: %s. Omitted.", content.type)
|
||||
@@ -270,6 +276,6 @@ class A2AExecutor(AgentExecutor):
|
||||
else:
|
||||
# For final messages, we send TaskStatusUpdateEvent with 'working' state
|
||||
await updater.update_status(
|
||||
state=TaskState.working,
|
||||
state=TaskState.TASK_STATE_WORKING,
|
||||
message=updater.new_agent_message(parts=parts, metadata=metadata),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
|
||||
from typing import Any, Final, Literal, TypeAlias, overload
|
||||
@@ -14,17 +13,14 @@ from a2a.client.auth.interceptor import AuthInterceptor
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Artifact,
|
||||
FilePart,
|
||||
FileWithBytes,
|
||||
FileWithUri,
|
||||
GetTaskRequest,
|
||||
SendMessageRequest,
|
||||
StreamResponse,
|
||||
SubscribeToTaskRequest,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskIdParams,
|
||||
TaskQueryParams,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
TransportProtocol,
|
||||
)
|
||||
from a2a.types import Message as A2AMessage
|
||||
from a2a.types import Part as A2APart
|
||||
@@ -45,6 +41,7 @@ from agent_framework import (
|
||||
)
|
||||
from agent_framework._types import AgentRunInputs
|
||||
from agent_framework.observability import AgentTelemetryLayer
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
__all__ = ["A2AAgent", "A2AContinuationToken"]
|
||||
|
||||
@@ -61,20 +58,19 @@ class A2AContinuationToken(ContinuationToken):
|
||||
|
||||
|
||||
TERMINAL_TASK_STATES = [
|
||||
TaskState.completed,
|
||||
TaskState.failed,
|
||||
TaskState.canceled,
|
||||
TaskState.rejected,
|
||||
TaskState.TASK_STATE_COMPLETED,
|
||||
TaskState.TASK_STATE_FAILED,
|
||||
TaskState.TASK_STATE_CANCELED,
|
||||
TaskState.TASK_STATE_REJECTED,
|
||||
]
|
||||
IN_PROGRESS_TASK_STATES = [
|
||||
TaskState.submitted,
|
||||
TaskState.working,
|
||||
TaskState.input_required,
|
||||
TaskState.auth_required,
|
||||
TaskState.TASK_STATE_SUBMITTED,
|
||||
TaskState.TASK_STATE_WORKING,
|
||||
TaskState.TASK_STATE_INPUT_REQUIRED,
|
||||
TaskState.TASK_STATE_AUTH_REQUIRED,
|
||||
]
|
||||
|
||||
A2AClientEvent: TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None]
|
||||
A2AStreamItem: TypeAlias = A2AMessage | A2AClientEvent
|
||||
A2AStreamItem: TypeAlias = StreamResponse
|
||||
|
||||
|
||||
class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
@@ -139,7 +135,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
if url is None:
|
||||
raise ValueError("Either agent_card or url must be provided")
|
||||
# Create minimal agent card from URL
|
||||
agent_card = minimal_agent_card(url, [TransportProtocol.jsonrpc])
|
||||
agent_card = minimal_agent_card(url, ["JSONRPC"])
|
||||
|
||||
# Create or use provided httpx client
|
||||
if http_client is None:
|
||||
@@ -151,7 +147,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
# Create A2A client using factory
|
||||
config = ClientConfig(
|
||||
httpx_client=http_client,
|
||||
supported_transports=[TransportProtocol.jsonrpc],
|
||||
supported_protocol_bindings=["JSONRPC"],
|
||||
)
|
||||
factory = ClientFactory(config)
|
||||
interceptors = [auth_interceptor] if auth_interceptor is not None else None
|
||||
@@ -161,7 +157,16 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
self.client = factory.create(agent_card, interceptors=interceptors) # type: ignore
|
||||
except Exception as transport_error:
|
||||
# Transport negotiation failed - fall back to minimal agent card with JSONRPC
|
||||
fallback_card = minimal_agent_card(agent_card.url, [TransportProtocol.jsonrpc])
|
||||
fallback_url = (
|
||||
agent_card.supported_interfaces[0].url if agent_card.supported_interfaces else url
|
||||
)
|
||||
if not fallback_url:
|
||||
raise ValueError(
|
||||
"A2A transport negotiation failed and no fallback URL is available. "
|
||||
"Provide a 'url' argument or ensure 'agent_card.supported_interfaces' "
|
||||
"contains at least one interface with a URL."
|
||||
) from transport_error
|
||||
fallback_card = minimal_agent_card(fallback_url, ["JSONRPC"])
|
||||
try:
|
||||
self.client = factory.create(fallback_card, interceptors=interceptors) # type: ignore
|
||||
except Exception as fallback_error:
|
||||
@@ -280,8 +285,8 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
normalized_messages = normalize_messages(messages)
|
||||
|
||||
if continuation_token is not None:
|
||||
a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe(
|
||||
TaskIdParams(id=continuation_token["task_id"])
|
||||
a2a_stream: AsyncIterable[A2AStreamItem] = self.client.subscribe(
|
||||
SubscribeToTaskRequest(id=continuation_token["task_id"])
|
||||
)
|
||||
else:
|
||||
if not normalized_messages:
|
||||
@@ -290,7 +295,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
normalized_messages[-1],
|
||||
context_id=session.service_session_id if session else None,
|
||||
)
|
||||
a2a_stream = self.client.send_message(a2a_message)
|
||||
a2a_stream = self.client.send_message(SendMessageRequest(message=a2a_message))
|
||||
|
||||
provider_session = session
|
||||
if provider_session is None and self.context_providers:
|
||||
@@ -361,38 +366,54 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
all_updates: list[AgentResponseUpdate] = []
|
||||
streamed_artifact_ids_by_task: dict[str, set[str]] = {}
|
||||
async for item in a2a_stream:
|
||||
if isinstance(item, A2AMessage):
|
||||
payload_type = item.WhichOneof("payload")
|
||||
if payload_type == "message":
|
||||
# Process A2A Message
|
||||
contents = self._parse_contents_from_a2a(item.parts)
|
||||
msg = item.message
|
||||
contents = self._parse_contents_from_a2a(msg.parts)
|
||||
metadata = MessageToDict(msg.metadata) if msg.metadata else None
|
||||
update = AgentResponseUpdate(
|
||||
contents=contents,
|
||||
role="assistant" if item.role == A2ARole.agent else "user",
|
||||
response_id=str(getattr(item, "message_id", uuid.uuid4())),
|
||||
additional_properties={"a2a_metadata": item.metadata} if item.metadata else None,
|
||||
raw_representation=item,
|
||||
role="assistant" if msg.role == A2ARole.ROLE_AGENT else "user",
|
||||
response_id=msg.message_id or str(uuid.uuid4()),
|
||||
additional_properties={"a2a_metadata": metadata} if metadata else None,
|
||||
raw_representation=msg,
|
||||
)
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task):
|
||||
task, update_event = item
|
||||
elif payload_type == "task":
|
||||
task = item.task
|
||||
updates = self._updates_from_task(
|
||||
task,
|
||||
update_event=update_event,
|
||||
background=background,
|
||||
emit_intermediate=emit_intermediate,
|
||||
streamed_artifact_ids=streamed_artifact_ids_by_task.get(task.id),
|
||||
)
|
||||
if isinstance(update_event, TaskArtifactUpdateEvent) and any(
|
||||
update.raw_representation is update_event for update in updates
|
||||
):
|
||||
streamed_artifact_ids_by_task.setdefault(task.id, set()).add(update_event.artifact.artifact_id)
|
||||
if task.status.state in TERMINAL_TASK_STATES:
|
||||
streamed_artifact_ids_by_task.pop(task.id, None)
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
elif payload_type == "status_update":
|
||||
status_event = item.status_update
|
||||
updates = self._updates_from_task_update_event(status_event)
|
||||
if emit_intermediate:
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
elif payload_type == "artifact_update":
|
||||
artifact_event = item.artifact_update
|
||||
updates = self._updates_from_task_update_event(artifact_event)
|
||||
if updates:
|
||||
streamed_artifact_ids_by_task.setdefault(artifact_event.task_id, set()).add(
|
||||
artifact_event.artifact.artifact_id
|
||||
)
|
||||
if emit_intermediate:
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
else:
|
||||
raise NotImplementedError("Only Message and Task responses are supported")
|
||||
raise NotImplementedError(f"Unsupported StreamResponse payload: {payload_type}")
|
||||
|
||||
# Set the response on the context for after_run providers
|
||||
if all_updates:
|
||||
@@ -408,7 +429,6 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
self,
|
||||
task: Task,
|
||||
*,
|
||||
update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None,
|
||||
background: bool = False,
|
||||
emit_intermediate: bool = False,
|
||||
streamed_artifact_ids: set[str] | None = None,
|
||||
@@ -424,17 +444,11 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
completion.
|
||||
"""
|
||||
status = task.status
|
||||
|
||||
if (
|
||||
emit_intermediate
|
||||
and update_event is not None
|
||||
and (event_updates := self._updates_from_task_update_event(update_event))
|
||||
):
|
||||
return event_updates
|
||||
task_metadata = MessageToDict(task.metadata) if task.metadata else None
|
||||
|
||||
if status.state in TERMINAL_TASK_STATES:
|
||||
task_messages = self._parse_messages_from_task(task)
|
||||
if task.artifacts is not None and streamed_artifact_ids:
|
||||
if task.artifacts and streamed_artifact_ids:
|
||||
task_messages = [
|
||||
message
|
||||
for message in task_messages
|
||||
@@ -448,20 +462,20 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
response_id=task.id,
|
||||
message_id=getattr(message.raw_representation, "artifact_id", None),
|
||||
additional_properties={"a2a_metadata": merged}
|
||||
if (merged := {**message.additional_properties, **(task.metadata or {})})
|
||||
if (merged := {**message.additional_properties, **(task_metadata or {})})
|
||||
else None,
|
||||
raw_representation=task,
|
||||
)
|
||||
for message in task_messages
|
||||
]
|
||||
if task.artifacts is not None:
|
||||
if task.artifacts:
|
||||
return []
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=[],
|
||||
role="assistant",
|
||||
response_id=task.id,
|
||||
additional_properties={"a2a_metadata": task.metadata} if task.metadata else None,
|
||||
additional_properties={"a2a_metadata": task_metadata} if task_metadata else None,
|
||||
raw_representation=task,
|
||||
)
|
||||
]
|
||||
@@ -474,18 +488,16 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
role="assistant",
|
||||
response_id=task.id,
|
||||
continuation_token=token,
|
||||
additional_properties={"a2a_metadata": task.metadata} if task.metadata else None,
|
||||
additional_properties={"a2a_metadata": task_metadata} if task_metadata else None,
|
||||
raw_representation=task,
|
||||
)
|
||||
]
|
||||
|
||||
# Surface message content from in-progress status updates (e.g. working state)
|
||||
# Only emitted when the caller opts in (streaming), so non-streaming
|
||||
# consumers keep receiving only terminal task outputs.
|
||||
if (
|
||||
emit_intermediate
|
||||
and status.state in IN_PROGRESS_TASK_STATES
|
||||
and status.message is not None
|
||||
and status.HasField("message")
|
||||
and status.message.parts
|
||||
):
|
||||
contents = self._parse_contents_from_a2a(status.message.parts)
|
||||
@@ -493,9 +505,9 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=contents,
|
||||
role="assistant" if status.message.role == A2ARole.agent else "user",
|
||||
role="assistant" if status.message.role == A2ARole.ROLE_AGENT else "user",
|
||||
response_id=task.id,
|
||||
additional_properties={"a2a_metadata": task.metadata} if task.metadata else None,
|
||||
additional_properties={"a2a_metadata": task_metadata} if task_metadata else None,
|
||||
raw_representation=task,
|
||||
)
|
||||
]
|
||||
@@ -510,10 +522,9 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
contents = self._parse_contents_from_a2a(update_event.artifact.parts)
|
||||
if not contents:
|
||||
return []
|
||||
merged_metadata = {
|
||||
**(update_event.artifact.metadata or {}),
|
||||
**(update_event.metadata or {}),
|
||||
} or None
|
||||
artifact_meta = MessageToDict(update_event.artifact.metadata) if update_event.artifact.metadata else {}
|
||||
event_meta = MessageToDict(update_event.metadata) if update_event.metadata else {}
|
||||
merged_metadata = {**artifact_meta, **event_meta} or None
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=contents,
|
||||
@@ -528,22 +539,21 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
if not isinstance(update_event, TaskStatusUpdateEvent):
|
||||
return []
|
||||
|
||||
message = update_event.status.message
|
||||
if message is None or not message.parts:
|
||||
if not update_event.status.HasField("message") or not update_event.status.message.parts:
|
||||
return []
|
||||
|
||||
message = update_event.status.message
|
||||
contents = self._parse_contents_from_a2a(message.parts)
|
||||
if not contents:
|
||||
return []
|
||||
|
||||
merged_metadata = {
|
||||
**(message.metadata or {}),
|
||||
**(update_event.metadata or {}),
|
||||
} or None
|
||||
msg_meta = MessageToDict(message.metadata) if message.metadata else {}
|
||||
event_meta = MessageToDict(update_event.metadata) if update_event.metadata else {}
|
||||
merged_metadata = {**msg_meta, **event_meta} or None
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=contents,
|
||||
role="assistant" if message.role == A2ARole.agent else "user",
|
||||
role="assistant" if message.role == A2ARole.ROLE_AGENT else "user",
|
||||
response_id=update_event.task_id,
|
||||
additional_properties={"a2a_metadata": merged_metadata} if merged_metadata else None,
|
||||
raw_representation=update_event,
|
||||
@@ -572,7 +582,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
is still in progress, or ``None`` when it has reached a terminal state.
|
||||
"""
|
||||
task_id = continuation_token["task_id"]
|
||||
task = await self.client.get_task(TaskQueryParams(id=task_id))
|
||||
task = await self.client.get_task(GetTaskRequest(id=task_id))
|
||||
updates = self._updates_from_task(task, background=True)
|
||||
if updates:
|
||||
return AgentResponse.from_updates(updates)
|
||||
@@ -607,19 +617,15 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
raise ValueError("Text content requires a non-null text value")
|
||||
parts.append(
|
||||
A2APart(
|
||||
root=TextPart(
|
||||
text=content.text,
|
||||
metadata=content.additional_properties,
|
||||
)
|
||||
text=content.text,
|
||||
metadata=content.additional_properties or {},
|
||||
)
|
||||
)
|
||||
case "error":
|
||||
parts.append(
|
||||
A2APart(
|
||||
root=TextPart(
|
||||
text=content.message or "An error occurred.",
|
||||
metadata=content.additional_properties,
|
||||
)
|
||||
text=content.message or "An error occurred.",
|
||||
metadata=content.additional_properties or {},
|
||||
)
|
||||
)
|
||||
case "uri":
|
||||
@@ -627,27 +633,20 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
raise ValueError("URI content requires a non-null uri value")
|
||||
parts.append(
|
||||
A2APart(
|
||||
root=FilePart(
|
||||
file=FileWithUri(
|
||||
uri=content.uri,
|
||||
mime_type=content.media_type,
|
||||
),
|
||||
metadata=content.additional_properties,
|
||||
)
|
||||
url=content.uri,
|
||||
media_type=content.media_type or "",
|
||||
metadata=content.additional_properties or {},
|
||||
)
|
||||
)
|
||||
case "data":
|
||||
if content.uri is None:
|
||||
raise ValueError("Data content requires a non-null uri value")
|
||||
base64_data = get_uri_data(content.uri)
|
||||
parts.append(
|
||||
A2APart(
|
||||
root=FilePart(
|
||||
file=FileWithBytes(
|
||||
bytes=get_uri_data(content.uri),
|
||||
mime_type=content.media_type,
|
||||
),
|
||||
metadata=content.additional_properties,
|
||||
)
|
||||
raw=base64.b64decode(base64_data),
|
||||
media_type=content.media_type or "",
|
||||
metadata=content.additional_properties or {},
|
||||
)
|
||||
)
|
||||
case "hosted_file":
|
||||
@@ -655,93 +654,91 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
raise ValueError("Hosted file content requires a non-null file_id value")
|
||||
parts.append(
|
||||
A2APart(
|
||||
root=FilePart(
|
||||
file=FileWithUri(
|
||||
uri=content.file_id,
|
||||
mime_type=None, # HostedFileContent doesn't specify media_type
|
||||
),
|
||||
metadata=content.additional_properties,
|
||||
)
|
||||
url=content.file_id,
|
||||
metadata=content.additional_properties or {},
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unknown content type: {content.type}")
|
||||
|
||||
metadata = message.additional_properties.get("a2a_metadata")
|
||||
a2a_metadata = message.additional_properties.get("a2a_metadata")
|
||||
|
||||
return A2AMessage(
|
||||
role=A2ARole("user"),
|
||||
role=A2ARole.ROLE_USER,
|
||||
parts=parts,
|
||||
message_id=message.message_id or uuid.uuid4().hex,
|
||||
context_id=message.additional_properties.get("context_id") or context_id,
|
||||
metadata=metadata,
|
||||
metadata=a2a_metadata or {},
|
||||
)
|
||||
|
||||
def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Content]:
|
||||
"""Parse A2A Parts into Agent Framework Content.
|
||||
|
||||
Transforms A2A protocol Parts into framework-native Content objects,
|
||||
handling text, file (URI/bytes), and data parts with metadata preservation.
|
||||
handling text, url, raw, and data parts with metadata preservation.
|
||||
"""
|
||||
contents: list[Content] = []
|
||||
for part in parts:
|
||||
inner_part = part.root
|
||||
match inner_part.kind:
|
||||
part_metadata = MessageToDict(part.metadata) if part.metadata else None
|
||||
content_type = part.WhichOneof("content")
|
||||
match content_type:
|
||||
case "text":
|
||||
contents.append(
|
||||
Content.from_text(
|
||||
text=inner_part.text,
|
||||
additional_properties=inner_part.metadata,
|
||||
raw_representation=inner_part,
|
||||
text=part.text,
|
||||
additional_properties=part_metadata,
|
||||
raw_representation=part,
|
||||
)
|
||||
)
|
||||
case "file":
|
||||
if isinstance(inner_part.file, FileWithUri):
|
||||
contents.append(
|
||||
Content.from_uri(
|
||||
uri=inner_part.file.uri,
|
||||
media_type=inner_part.file.mime_type or "",
|
||||
additional_properties=inner_part.metadata,
|
||||
raw_representation=inner_part,
|
||||
)
|
||||
case "url":
|
||||
contents.append(
|
||||
Content.from_uri(
|
||||
uri=part.url,
|
||||
media_type=part.media_type or "",
|
||||
additional_properties=part_metadata,
|
||||
raw_representation=part,
|
||||
)
|
||||
elif isinstance(inner_part.file, FileWithBytes):
|
||||
contents.append(
|
||||
Content.from_data(
|
||||
data=base64.b64decode(inner_part.file.bytes),
|
||||
media_type=inner_part.file.mime_type or "",
|
||||
additional_properties=inner_part.metadata,
|
||||
raw_representation=inner_part,
|
||||
)
|
||||
)
|
||||
case "raw":
|
||||
contents.append(
|
||||
Content.from_data(
|
||||
data=part.raw,
|
||||
media_type=part.media_type or "",
|
||||
additional_properties=part_metadata,
|
||||
raw_representation=part,
|
||||
)
|
||||
)
|
||||
case "data":
|
||||
from google.protobuf.json_format import MessageToJson
|
||||
|
||||
contents.append(
|
||||
Content.from_text(
|
||||
text=json.dumps(inner_part.data),
|
||||
additional_properties=inner_part.metadata,
|
||||
raw_representation=inner_part,
|
||||
text=MessageToJson(part.data),
|
||||
additional_properties=part_metadata,
|
||||
raw_representation=part,
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unknown Part kind: {inner_part.kind}")
|
||||
raise ValueError(f"Unknown Part content type: {content_type}")
|
||||
return contents
|
||||
|
||||
def _parse_messages_from_task(self, task: Task) -> list[Message]:
|
||||
"""Parse A2A Task artifacts into Messages with ASSISTANT role."""
|
||||
messages: list[Message] = []
|
||||
|
||||
if task.artifacts is not None:
|
||||
if task.artifacts:
|
||||
for artifact in task.artifacts:
|
||||
messages.append(self._parse_message_from_artifact(artifact))
|
||||
elif task.history is not None and len(task.history) > 0:
|
||||
elif task.history:
|
||||
# Include the last history item as the agent response
|
||||
history_item = task.history[-1]
|
||||
contents = self._parse_contents_from_a2a(history_item.parts)
|
||||
history_metadata = MessageToDict(history_item.metadata) if history_item.metadata else None
|
||||
messages.append(
|
||||
Message(
|
||||
role="assistant" if history_item.role == A2ARole.agent else "user",
|
||||
role="assistant" if history_item.role == A2ARole.ROLE_AGENT else "user",
|
||||
contents=contents,
|
||||
additional_properties=history_item.metadata,
|
||||
additional_properties=history_metadata,
|
||||
raw_representation=history_item,
|
||||
)
|
||||
)
|
||||
@@ -751,9 +748,10 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
def _parse_message_from_artifact(self, artifact: Artifact) -> Message:
|
||||
"""Parse A2A Artifact into Message using part contents."""
|
||||
contents = self._parse_contents_from_a2a(artifact.parts)
|
||||
artifact_metadata = MessageToDict(artifact.metadata) if artifact.metadata else None
|
||||
return Message(
|
||||
role="assistant",
|
||||
contents=contents,
|
||||
additional_properties=artifact.metadata,
|
||||
additional_properties=artifact_metadata,
|
||||
raw_representation=artifact,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"agent-framework-core>=1.3.0,<2",
|
||||
"a2a-sdk>=0.3.5,<0.3.24",
|
||||
"a2a-sdk>=1.0.0,<2",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
||||
@@ -9,16 +9,13 @@ import httpx
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Artifact,
|
||||
DataPart,
|
||||
FilePart,
|
||||
FileWithUri,
|
||||
Part,
|
||||
StreamResponse,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
from a2a.types import Message as A2AMessage
|
||||
from a2a.types import Role as A2ARole
|
||||
@@ -43,59 +40,42 @@ class MockA2AClient:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.call_count: int = 0
|
||||
self.responses: list[Any] = []
|
||||
self.resubscribe_responses: list[Any] = []
|
||||
self.responses: list[StreamResponse] = []
|
||||
self.subscribe_responses: list[StreamResponse] = []
|
||||
self.get_task_response: Task | None = None
|
||||
self.last_message: Any = None
|
||||
|
||||
def add_message_response(self, message_id: str, text: str, role: str = "agent") -> None:
|
||||
"""Add a mock Message response."""
|
||||
|
||||
# Create actual TextPart instance and wrap it in Part
|
||||
text_part = Part(root=TextPart(text=text))
|
||||
|
||||
# Create actual Message instance
|
||||
message = A2AMessage(
|
||||
message_id=message_id, role=A2ARole.agent if role == "agent" else A2ARole.user, parts=[text_part]
|
||||
message_id=message_id,
|
||||
role=A2ARole.ROLE_AGENT if role == "agent" else A2ARole.ROLE_USER,
|
||||
parts=[Part(text=text)],
|
||||
)
|
||||
self.responses.append(message)
|
||||
self.responses.append(StreamResponse(message=message))
|
||||
|
||||
def add_task_response(self, task_id: str, artifacts: list[dict[str, Any]]) -> None:
|
||||
"""Add a mock Task response."""
|
||||
# Create mock artifacts
|
||||
mock_artifacts = []
|
||||
for artifact_data in artifacts:
|
||||
# Create actual TextPart instance and wrap it in Part
|
||||
text_part = Part(root=TextPart(text=artifact_data.get("content", "Test content")))
|
||||
|
||||
artifact = Artifact(
|
||||
artifact_id=artifact_data.get("id", str(uuid4())),
|
||||
name=artifact_data.get("name", "test-artifact"),
|
||||
description=artifact_data.get("description", "Test artifact"),
|
||||
parts=[text_part],
|
||||
parts=[Part(text=artifact_data.get("content", "Test content"))],
|
||||
)
|
||||
mock_artifacts.append(artifact)
|
||||
|
||||
# Create task status
|
||||
status = TaskStatus(state=TaskState.completed, message=None)
|
||||
|
||||
# Create actual Task instance
|
||||
task = Task(
|
||||
id=task_id, context_id="test-context", status=status, artifacts=mock_artifacts if mock_artifacts else None
|
||||
)
|
||||
|
||||
# Mock the ClientEvent tuple format
|
||||
update_event = None # No specific update event for completed tasks
|
||||
client_event = (task, update_event)
|
||||
self.responses.append(client_event)
|
||||
status = TaskStatus(state=TaskState.TASK_STATE_COMPLETED)
|
||||
task = Task(id=task_id, context_id="test-context", status=status, artifacts=mock_artifacts)
|
||||
self.responses.append(StreamResponse(task=task))
|
||||
|
||||
def add_in_progress_task_response(
|
||||
self,
|
||||
task_id: str,
|
||||
context_id: str = "test-context",
|
||||
state: TaskState = TaskState.working,
|
||||
state: TaskState = TaskState.TASK_STATE_WORKING,
|
||||
text: str | None = None,
|
||||
role: A2ARole = A2ARole.agent,
|
||||
role: A2ARole = A2ARole.ROLE_AGENT,
|
||||
) -> None:
|
||||
"""Add a mock in-progress Task response (non-terminal)."""
|
||||
message = None
|
||||
@@ -103,30 +83,28 @@ class MockA2AClient:
|
||||
message = A2AMessage(
|
||||
message_id=str(uuid4()),
|
||||
role=role,
|
||||
parts=[Part(root=TextPart(text=text))],
|
||||
parts=[Part(text=text)],
|
||||
)
|
||||
status = TaskStatus(state=state, message=message)
|
||||
task = Task(id=task_id, context_id=context_id, status=status)
|
||||
client_event = (task, None)
|
||||
self.responses.append(client_event)
|
||||
self.responses.append(StreamResponse(task=task))
|
||||
|
||||
async def send_message(self, message: Any) -> AsyncIterator[Any]:
|
||||
async def send_message(self, request: Any) -> AsyncIterator[StreamResponse]:
|
||||
"""Mock send_message method that yields responses."""
|
||||
self.last_message = message
|
||||
self.last_message = getattr(request, "message", request)
|
||||
self.call_count += 1
|
||||
|
||||
# All queued responses are delivered as a single streaming batch per call.
|
||||
for response in self.responses:
|
||||
yield response
|
||||
self.responses.clear()
|
||||
|
||||
async def resubscribe(self, request: Any) -> AsyncIterator[Any]:
|
||||
"""Mock resubscribe method that yields responses."""
|
||||
async def subscribe(self, request: Any) -> AsyncIterator[StreamResponse]:
|
||||
"""Mock subscribe method that yields responses."""
|
||||
self.call_count += 1
|
||||
|
||||
for response in self.resubscribe_responses:
|
||||
for response in self.subscribe_responses:
|
||||
yield response
|
||||
self.resubscribe_responses.clear()
|
||||
self.subscribe_responses.clear()
|
||||
|
||||
async def get_task(self, request: Any) -> Task:
|
||||
"""Mock get_task method that returns a task."""
|
||||
@@ -282,16 +260,16 @@ async def test_run_with_task_response_no_artifacts(a2a_agent: A2AAgent, mock_a2a
|
||||
|
||||
async def test_run_with_unknown_response_type_raises_error(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test run() method with unknown response type raises NotImplementedError."""
|
||||
mock_a2a_client.responses.append("invalid_response")
|
||||
# An empty StreamResponse has no payload set (WhichOneof returns None)
|
||||
mock_a2a_client.responses.append(StreamResponse())
|
||||
|
||||
with raises(NotImplementedError, match="Only Message and Task responses are supported"):
|
||||
with raises(NotImplementedError, match="Unsupported StreamResponse payload"):
|
||||
await a2a_agent.run("Test message")
|
||||
|
||||
|
||||
def test_parse_messages_from_task_empty_artifacts(a2a_agent: A2AAgent) -> None:
|
||||
"""Test _parse_messages_from_task with task containing no artifacts."""
|
||||
task = MagicMock()
|
||||
task.artifacts = None
|
||||
task = Task(id="test", context_id="test", status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED))
|
||||
|
||||
result = a2a_agent._parse_messages_from_task(task)
|
||||
|
||||
@@ -300,28 +278,14 @@ def test_parse_messages_from_task_empty_artifacts(a2a_agent: A2AAgent) -> None:
|
||||
|
||||
def test_parse_messages_from_task_with_artifacts(a2a_agent: A2AAgent) -> None:
|
||||
"""Test _parse_messages_from_task with task containing artifacts."""
|
||||
task = MagicMock()
|
||||
|
||||
# Create mock artifacts
|
||||
artifact1 = MagicMock()
|
||||
artifact1.artifact_id = "art-1"
|
||||
text_part1 = MagicMock()
|
||||
text_part1.root = MagicMock()
|
||||
text_part1.root.kind = "text"
|
||||
text_part1.root.text = "Content 1"
|
||||
text_part1.root.metadata = None
|
||||
artifact1.parts = [text_part1]
|
||||
|
||||
artifact2 = MagicMock()
|
||||
artifact2.artifact_id = "art-2"
|
||||
text_part2 = MagicMock()
|
||||
text_part2.root = MagicMock()
|
||||
text_part2.root.kind = "text"
|
||||
text_part2.root.text = "Content 2"
|
||||
text_part2.root.metadata = None
|
||||
artifact2.parts = [text_part2]
|
||||
|
||||
task.artifacts = [artifact1, artifact2]
|
||||
artifact1 = Artifact(artifact_id="art-1", parts=[Part(text="Content 1")])
|
||||
artifact2 = Artifact(artifact_id="art-2", parts=[Part(text="Content 2")])
|
||||
task = Task(
|
||||
id="test",
|
||||
context_id="test",
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
artifacts=[artifact1, artifact2],
|
||||
)
|
||||
|
||||
result = a2a_agent._parse_messages_from_task(task)
|
||||
|
||||
@@ -333,16 +297,7 @@ def test_parse_messages_from_task_with_artifacts(a2a_agent: A2AAgent) -> None:
|
||||
|
||||
def test_parse_message_from_artifact(a2a_agent: A2AAgent) -> None:
|
||||
"""Test _parse_message_from_artifact conversion."""
|
||||
artifact = MagicMock()
|
||||
artifact.artifact_id = "test-artifact"
|
||||
|
||||
text_part = MagicMock()
|
||||
text_part.root = MagicMock()
|
||||
text_part.root.kind = "text"
|
||||
text_part.root.text = "Artifact content"
|
||||
text_part.root.metadata = None
|
||||
|
||||
artifact.parts = [text_part]
|
||||
artifact = Artifact(artifact_id="test-artifact", parts=[Part(text="Artifact content")])
|
||||
|
||||
result = a2a_agent._parse_message_from_artifact(artifact)
|
||||
|
||||
@@ -373,7 +328,7 @@ def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None:
|
||||
agent = A2AAgent(name="Test Agent", client=MockA2AClient(), http_client=None)
|
||||
|
||||
# Create A2A parts
|
||||
parts = [Part(root=TextPart(text="First part")), Part(root=TextPart(text="Second part"))]
|
||||
parts = [Part(text="First part"), Part(text="Second part")]
|
||||
|
||||
# Convert to contents
|
||||
contents = agent._parse_contents_from_a2a(parts)
|
||||
@@ -398,7 +353,7 @@ def test_prepare_message_for_a2a_with_error_content(a2a_agent: A2AAgent) -> None
|
||||
|
||||
# Verify conversion
|
||||
assert len(a2a_message.parts) == 1
|
||||
assert a2a_message.parts[0].root.text == "Test error message"
|
||||
assert a2a_message.parts[0].text == "Test error message"
|
||||
|
||||
|
||||
def test_prepare_message_for_a2a_with_uri_content(a2a_agent: A2AAgent) -> None:
|
||||
@@ -413,8 +368,8 @@ def test_prepare_message_for_a2a_with_uri_content(a2a_agent: A2AAgent) -> None:
|
||||
|
||||
# Verify conversion
|
||||
assert len(a2a_message.parts) == 1
|
||||
assert a2a_message.parts[0].root.file.uri == "http://example.com/file.pdf"
|
||||
assert a2a_message.parts[0].root.file.mime_type == "application/pdf"
|
||||
assert a2a_message.parts[0].url == "http://example.com/file.pdf"
|
||||
assert a2a_message.parts[0].media_type == "application/pdf"
|
||||
|
||||
|
||||
def test_prepare_message_for_a2a_with_data_content(a2a_agent: A2AAgent) -> None:
|
||||
@@ -429,8 +384,8 @@ def test_prepare_message_for_a2a_with_data_content(a2a_agent: A2AAgent) -> None:
|
||||
|
||||
# Verify conversion
|
||||
assert len(a2a_message.parts) == 1
|
||||
assert a2a_message.parts[0].root.file.bytes == "SGVsbG8gV29ybGQ="
|
||||
assert a2a_message.parts[0].root.file.mime_type == "text/plain"
|
||||
assert a2a_message.parts[0].raw == b"Hello World"
|
||||
assert a2a_message.parts[0].media_type == "text/plain"
|
||||
|
||||
|
||||
def test_prepare_message_for_a2a_empty_contents_raises_error(a2a_agent: A2AAgent) -> None:
|
||||
@@ -518,10 +473,10 @@ def test_prepare_message_for_a2a_with_multiple_contents() -> None:
|
||||
assert len(result.parts) == 4
|
||||
|
||||
# Check each part type
|
||||
assert result.parts[0].root.kind == "text" # Regular text
|
||||
assert result.parts[1].root.kind == "file" # Binary data
|
||||
assert result.parts[2].root.kind == "file" # URI content
|
||||
assert result.parts[3].root.kind == "text" # JSON text remains as text (no parsing)
|
||||
assert result.parts[0].WhichOneof("content") == "text" # Regular text
|
||||
assert result.parts[1].WhichOneof("content") == "raw" # Binary data
|
||||
assert result.parts[2].WhichOneof("content") == "url" # URI content
|
||||
assert result.parts[3].WhichOneof("content") == "text" # JSON text remains as text (no parsing)
|
||||
|
||||
|
||||
def test_prepare_message_for_a2a_forwards_context_id() -> None:
|
||||
@@ -573,19 +528,29 @@ def test_prepare_message_for_a2a_message_context_id_takes_precedence() -> None:
|
||||
|
||||
|
||||
def test_parse_contents_from_a2a_with_data_part() -> None:
|
||||
"""Test conversion of A2A DataPart."""
|
||||
"""Test conversion of A2A data Part."""
|
||||
from google.protobuf.json_format import ParseDict
|
||||
from google.protobuf.struct_pb2 import Struct, Value
|
||||
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
# Create DataPart
|
||||
data_part = Part(root=DataPart(data={"key": "value", "number": 42}, metadata={"source": "test"}))
|
||||
# Create Part with data (protobuf Value containing a struct)
|
||||
value = ParseDict({"key": "value", "number": 42}, Value())
|
||||
metadata = Struct()
|
||||
metadata.update({"source": "test"})
|
||||
data_part = Part(data=value, metadata=metadata)
|
||||
|
||||
contents = agent._parse_contents_from_a2a([data_part])
|
||||
|
||||
assert len(contents) == 1
|
||||
|
||||
assert contents[0].type == "text"
|
||||
assert contents[0].text == '{"key": "value", "number": 42}'
|
||||
# MessageToJson may format slightly differently — verify the parsed structure
|
||||
import json
|
||||
|
||||
parsed = json.loads(contents[0].text)
|
||||
assert parsed["key"] == "value"
|
||||
assert parsed["number"] == 42
|
||||
assert contents[0].additional_properties == {"source": "test"}
|
||||
|
||||
|
||||
@@ -593,12 +558,11 @@ def test_parse_contents_from_a2a_unknown_part_kind() -> None:
|
||||
"""Test error handling for unknown A2A part kind."""
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
# Create a mock part with unknown kind
|
||||
mock_part = MagicMock()
|
||||
mock_part.root.kind = "unknown_kind"
|
||||
# Create a Part with no content field set (WhichOneof returns None)
|
||||
empty_part = Part()
|
||||
|
||||
with raises(ValueError, match="Unknown Part kind: unknown_kind"):
|
||||
agent._parse_contents_from_a2a([mock_part])
|
||||
with raises(ValueError, match="Unknown Part content type"):
|
||||
agent._parse_contents_from_a2a([empty_part])
|
||||
|
||||
|
||||
def test_prepare_message_for_a2a_with_hosted_file() -> None:
|
||||
@@ -617,14 +581,8 @@ def test_prepare_message_for_a2a_with_hosted_file() -> None:
|
||||
# Verify the conversion
|
||||
assert len(result.parts) == 1
|
||||
part = result.parts[0]
|
||||
assert part.root.kind == "file"
|
||||
|
||||
# Verify it's a FilePart with FileWithUri
|
||||
|
||||
assert isinstance(part.root, FilePart)
|
||||
assert isinstance(part.root.file, FileWithUri)
|
||||
assert part.root.file.uri == "hosted://storage/document.pdf"
|
||||
assert part.root.file.mime_type is None # HostedFileContent doesn't specify media_type
|
||||
assert part.WhichOneof("content") == "url"
|
||||
assert part.url == "hosted://storage/document.pdf"
|
||||
|
||||
|
||||
def test_parse_contents_from_a2a_with_hosted_file_uri() -> None:
|
||||
@@ -632,15 +590,8 @@ def test_parse_contents_from_a2a_with_hosted_file_uri() -> None:
|
||||
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
# Create FilePart with hosted file URI (simulating what A2A would send back)
|
||||
file_part = Part(
|
||||
root=FilePart(
|
||||
file=FileWithUri(
|
||||
uri="hosted://storage/document.pdf",
|
||||
mime_type=None,
|
||||
)
|
||||
)
|
||||
)
|
||||
# Create Part with hosted file URL (simulating what A2A would send back)
|
||||
file_part = Part(url="hosted://storage/document.pdf")
|
||||
|
||||
contents = agent._parse_contents_from_a2a([file_part]) # noqa: SLF001
|
||||
|
||||
@@ -671,9 +622,11 @@ def test_auth_interceptor_parameter() -> None:
|
||||
|
||||
def test_transport_negotiation_both_fail() -> None:
|
||||
"""Test that RuntimeError is raised when both primary and fallback transport negotiation fail."""
|
||||
# Create a mock agent card
|
||||
# Create a mock agent card with supported_interfaces
|
||||
mock_agent_card = MagicMock(spec=AgentCard)
|
||||
mock_agent_card.url = "http://test-agent.example.com"
|
||||
mock_interface = MagicMock()
|
||||
mock_interface.url = "http://test-agent.example.com"
|
||||
mock_agent_card.supported_interfaces = [mock_interface]
|
||||
mock_agent_card.name = "Test Agent"
|
||||
mock_agent_card.description = "A test agent"
|
||||
|
||||
@@ -751,7 +704,7 @@ def test_a2a_agent_initialization_with_timeout_parameter() -> None:
|
||||
|
||||
async def test_working_task_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that a working (non-terminal) task yields an update with a continuation token when background=True."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-wip", context_id="ctx-1", state=TaskState.working)
|
||||
mock_a2a_client.add_in_progress_task_response("task-wip", context_id="ctx-1", state=TaskState.TASK_STATE_WORKING)
|
||||
|
||||
response = await a2a_agent.run("Start long task", background=True)
|
||||
|
||||
@@ -763,7 +716,7 @@ async def test_working_task_emits_continuation_token(a2a_agent: A2AAgent, mock_a
|
||||
|
||||
async def test_submitted_task_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that a submitted task yields a continuation token when background=True."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-sub", state=TaskState.submitted)
|
||||
mock_a2a_client.add_in_progress_task_response("task-sub", state=TaskState.TASK_STATE_SUBMITTED)
|
||||
|
||||
response = await a2a_agent.run("Submit task", background=True)
|
||||
|
||||
@@ -775,7 +728,7 @@ async def test_input_required_task_emits_continuation_token(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Test that an input_required task yields a continuation token when background=True."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-input", state=TaskState.input_required)
|
||||
mock_a2a_client.add_in_progress_task_response("task-input", state=TaskState.TASK_STATE_INPUT_REQUIRED)
|
||||
|
||||
response = await a2a_agent.run("Need input", background=True)
|
||||
|
||||
@@ -785,7 +738,7 @@ async def test_input_required_task_emits_continuation_token(
|
||||
|
||||
async def test_working_task_no_token_without_background(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that background=False (default) does not emit continuation tokens for in-progress tasks."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-fg", context_id="ctx-fg", state=TaskState.working)
|
||||
mock_a2a_client.add_in_progress_task_response("task-fg", context_id="ctx-fg", state=TaskState.TASK_STATE_WORKING)
|
||||
|
||||
response = await a2a_agent.run("Foreground task")
|
||||
|
||||
@@ -805,7 +758,7 @@ async def test_completed_task_has_no_continuation_token(a2a_agent: A2AAgent, moc
|
||||
|
||||
async def test_streaming_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that streaming with background=True yields updates with continuation tokens."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-stream", context_id="ctx-s", state=TaskState.working)
|
||||
mock_a2a_client.add_in_progress_task_response("task-stream", context_id="ctx-s", state=TaskState.TASK_STATE_WORKING)
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in a2a_agent.run("Stream task", stream=True, background=True):
|
||||
@@ -820,14 +773,14 @@ async def test_streaming_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_
|
||||
async def test_resume_via_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that run() with continuation_token uses resubscribe instead of send_message."""
|
||||
# Set up the resubscribe response (completed task)
|
||||
status = TaskStatus(state=TaskState.completed, message=None)
|
||||
status = TaskStatus(state=TaskState.TASK_STATE_COMPLETED, message=None)
|
||||
artifact = Artifact(
|
||||
artifact_id="art-resume",
|
||||
name="result",
|
||||
parts=[Part(root=TextPart(text="Resumed result"))],
|
||||
parts=[Part(text="Resumed result")],
|
||||
)
|
||||
task = Task(id="task-resume", context_id="ctx-r", status=status, artifacts=[artifact])
|
||||
mock_a2a_client.resubscribe_responses.append((task, None))
|
||||
mock_a2a_client.subscribe_responses.append(StreamResponse(task=task))
|
||||
|
||||
token = A2AContinuationToken(task_id="task-resume", context_id="ctx-r")
|
||||
response = await a2a_agent.run(continuation_token=token)
|
||||
@@ -841,17 +794,17 @@ async def test_resume_via_continuation_token(a2a_agent: A2AAgent, mock_a2a_clien
|
||||
async def test_resume_streaming_via_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that streaming run() with continuation_token and background=True uses resubscribe."""
|
||||
# Still working
|
||||
status_wip = TaskStatus(state=TaskState.working, message=None)
|
||||
status_wip = TaskStatus(state=TaskState.TASK_STATE_WORKING, message=None)
|
||||
task_wip = Task(id="task-rs", context_id="ctx-rs", status=status_wip)
|
||||
# Then completed
|
||||
status_done = TaskStatus(state=TaskState.completed, message=None)
|
||||
status_done = TaskStatus(state=TaskState.TASK_STATE_COMPLETED, message=None)
|
||||
artifact = Artifact(
|
||||
artifact_id="art-rs",
|
||||
name="result",
|
||||
parts=[Part(root=TextPart(text="Stream resumed"))],
|
||||
parts=[Part(text="Stream resumed")],
|
||||
)
|
||||
task_done = Task(id="task-rs", context_id="ctx-rs", status=status_done, artifacts=[artifact])
|
||||
mock_a2a_client.resubscribe_responses.extend([(task_wip, None), (task_done, None)])
|
||||
mock_a2a_client.subscribe_responses.extend([StreamResponse(task=task_wip), StreamResponse(task=task_done)])
|
||||
|
||||
token = A2AContinuationToken(task_id="task-rs", context_id="ctx-rs")
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
@@ -868,7 +821,7 @@ async def test_resume_streaming_via_continuation_token(a2a_agent: A2AAgent, mock
|
||||
|
||||
async def test_poll_task_in_progress(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test poll_task returns continuation token when task is still in progress."""
|
||||
status = TaskStatus(state=TaskState.working, message=None)
|
||||
status = TaskStatus(state=TaskState.TASK_STATE_WORKING, message=None)
|
||||
mock_a2a_client.get_task_response = Task(id="task-poll", context_id="ctx-p", status=status)
|
||||
|
||||
token = A2AContinuationToken(task_id="task-poll", context_id="ctx-p")
|
||||
@@ -880,11 +833,11 @@ async def test_poll_task_in_progress(a2a_agent: A2AAgent, mock_a2a_client: MockA
|
||||
|
||||
async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test poll_task returns result with no continuation token when task is complete."""
|
||||
status = TaskStatus(state=TaskState.completed, message=None)
|
||||
status = TaskStatus(state=TaskState.TASK_STATE_COMPLETED, message=None)
|
||||
artifact = Artifact(
|
||||
artifact_id="art-poll",
|
||||
name="result",
|
||||
parts=[Part(root=TextPart(text="Poll result"))],
|
||||
parts=[Part(text="Poll result")],
|
||||
)
|
||||
mock_a2a_client.get_task_response = Task(
|
||||
id="task-poll-done", context_id="ctx-pd", status=status, artifacts=[artifact]
|
||||
@@ -1105,9 +1058,9 @@ async def test_run_with_continuation_token_does_not_require_messages(mock_a2a_cl
|
||||
task = Task(
|
||||
id="task-cont",
|
||||
context_id="ctx-cont",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED, message=None),
|
||||
)
|
||||
mock_a2a_client.resubscribe_responses.append((task, None))
|
||||
mock_a2a_client.subscribe_responses.append(StreamResponse(task=task))
|
||||
|
||||
agent = A2AAgent(
|
||||
name="Test Agent",
|
||||
@@ -1176,8 +1129,10 @@ async def test_streaming_working_update_without_message_is_skipped(
|
||||
|
||||
|
||||
async def test_streaming_working_update_user_role_mapping(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that A2ARole.user in status message maps to role='user'."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-u", context_id="ctx-u", text="User echo", role=A2ARole.user)
|
||||
"""Test that A2ARole.ROLE_USER in status message maps to role='user'."""
|
||||
mock_a2a_client.add_in_progress_task_response(
|
||||
"task-u", context_id="ctx-u", text="User echo", role=A2ARole.ROLE_USER
|
||||
)
|
||||
mock_a2a_client.add_task_response("task-u", [{"id": "art-u", "content": "Done"}])
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
@@ -1224,9 +1179,9 @@ async def test_terminal_no_artifacts_after_working_with_content(
|
||||
"""Test that a terminal task with no artifacts after working-state messages does not re-emit the working content."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-t", context_id="ctx-t", text="Working on it...")
|
||||
# Terminal task with no artifacts and no history
|
||||
status = TaskStatus(state=TaskState.completed, message=None)
|
||||
status = TaskStatus(state=TaskState.TASK_STATE_COMPLETED, message=None)
|
||||
task = Task(id="task-t", context_id="ctx-t", status=status)
|
||||
mock_a2a_client.responses.append((task, None))
|
||||
mock_a2a_client.responses.append(StreamResponse(task=task))
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in a2a_agent.run("Hello", stream=True):
|
||||
@@ -1245,12 +1200,12 @@ async def test_streaming_working_update_with_empty_parts_is_skipped(
|
||||
# Construct a message with an empty parts list (distinct from message=None)
|
||||
message = A2AMessage(
|
||||
message_id=str(uuid4()),
|
||||
role=A2ARole.agent,
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[],
|
||||
)
|
||||
status = TaskStatus(state=TaskState.working, message=message)
|
||||
status = TaskStatus(state=TaskState.TASK_STATE_WORKING, message=message)
|
||||
task = Task(id="task-ep", context_id="ctx-ep", status=status)
|
||||
mock_a2a_client.responses.append((task, None))
|
||||
mock_a2a_client.responses.append(StreamResponse(task=task))
|
||||
mock_a2a_client.add_task_response("task-ep", [{"id": "art-ep", "content": "Result"}])
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
@@ -1265,13 +1220,12 @@ async def test_streaming_artifact_update_event_yields_content(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Test that streaming artifact update events yield incremental content."""
|
||||
task = Task(id="task-art", context_id="ctx-art", status=TaskStatus(state=TaskState.working, message=None))
|
||||
artifact = Artifact(
|
||||
artifact_id="artifact-1",
|
||||
parts=[Part(root=TextPart(text="Hello"))],
|
||||
parts=[Part(text="Hello")],
|
||||
)
|
||||
update_event = TaskArtifactUpdateEvent(task_id="task-art", context_id="ctx-art", artifact=artifact, append=False)
|
||||
mock_a2a_client.responses.append((task, update_event))
|
||||
mock_a2a_client.responses.append(StreamResponse(artifact_update=update_event))
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in a2a_agent.run("Hello", stream=True):
|
||||
@@ -1291,17 +1245,15 @@ async def test_streaming_status_update_event_yields_content(
|
||||
task_id="task-status",
|
||||
context_id="ctx-status",
|
||||
status=TaskStatus(
|
||||
state=TaskState.working,
|
||||
state=TaskState.TASK_STATE_WORKING,
|
||||
message=A2AMessage(
|
||||
message_id=str(uuid4()),
|
||||
role=A2ARole.agent,
|
||||
parts=[Part(root=TextPart(text="Still working"))],
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="Still working")],
|
||||
),
|
||||
),
|
||||
final=False,
|
||||
)
|
||||
task = Task(id="task-status", context_id="ctx-status", status=TaskStatus(state=TaskState.working, message=None))
|
||||
mock_a2a_client.responses.append((task, update_event))
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=update_event))
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in a2a_agent.run("Hello", stream=True):
|
||||
@@ -1317,13 +1269,12 @@ async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Test that streamed artifact chunks are not re-emitted from the final terminal task."""
|
||||
working_task = Task(id="task-art-dup", context_id="ctx-art-dup", status=TaskStatus(state=TaskState.working))
|
||||
first_chunk = TaskArtifactUpdateEvent(
|
||||
task_id="task-art-dup",
|
||||
context_id="ctx-art-dup",
|
||||
artifact=Artifact(
|
||||
artifact_id="artifact-dup",
|
||||
parts=[Part(root=TextPart(text="Hello "))],
|
||||
parts=[Part(text="Hello ")],
|
||||
),
|
||||
append=False,
|
||||
)
|
||||
@@ -1332,32 +1283,26 @@ async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_
|
||||
context_id="ctx-art-dup",
|
||||
artifact=Artifact(
|
||||
artifact_id="artifact-dup",
|
||||
parts=[Part(root=TextPart(text="world"))],
|
||||
parts=[Part(text="world")],
|
||||
),
|
||||
append=True,
|
||||
)
|
||||
terminal_task = Task(
|
||||
id="task-art-dup",
|
||||
context_id="ctx-art-dup",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
artifacts=[
|
||||
Artifact(
|
||||
artifact_id="artifact-dup",
|
||||
parts=[Part(root=TextPart(text="Hello world"))],
|
||||
parts=[Part(text="Hello world")],
|
||||
)
|
||||
],
|
||||
)
|
||||
terminal_event = TaskStatusUpdateEvent(
|
||||
task_id="task-art-dup",
|
||||
context_id="ctx-art-dup",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
final=True,
|
||||
)
|
||||
|
||||
mock_a2a_client.responses.extend([
|
||||
(working_task, first_chunk),
|
||||
(working_task, second_chunk),
|
||||
(terminal_task, terminal_event),
|
||||
StreamResponse(artifact_update=first_chunk),
|
||||
StreamResponse(artifact_update=second_chunk),
|
||||
StreamResponse(task=terminal_task),
|
||||
])
|
||||
|
||||
stream = a2a_agent.run("Hello", stream=True)
|
||||
@@ -1378,21 +1323,15 @@ async def test_streaming_terminal_task_artifacts_are_emitted_when_terminal_event
|
||||
terminal_task = Task(
|
||||
id="task-art-final",
|
||||
context_id="ctx-art-final",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
artifacts=[
|
||||
Artifact(
|
||||
artifact_id="artifact-final",
|
||||
parts=[Part(root=TextPart(text="Final artifact"))],
|
||||
parts=[Part(text="Final artifact")],
|
||||
)
|
||||
],
|
||||
)
|
||||
terminal_event = TaskStatusUpdateEvent(
|
||||
task_id="task-art-final",
|
||||
context_id="ctx-art-final",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
final=True,
|
||||
)
|
||||
mock_a2a_client.responses.append((terminal_task, terminal_event))
|
||||
mock_a2a_client.responses.append(StreamResponse(task=terminal_task))
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in a2a_agent.run("Hello", stream=True):
|
||||
@@ -1407,41 +1346,34 @@ async def test_streaming_terminal_task_only_emits_unstreamed_artifacts(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Test that the terminal task only emits artifacts that were not already streamed incrementally."""
|
||||
working_task = Task(id="task-art-mixed", context_id="ctx-art-mixed", status=TaskStatus(state=TaskState.working))
|
||||
streamed_chunk = TaskArtifactUpdateEvent(
|
||||
task_id="task-art-mixed",
|
||||
context_id="ctx-art-mixed",
|
||||
artifact=Artifact(
|
||||
artifact_id="artifact-streamed",
|
||||
parts=[Part(root=TextPart(text="Hello"))],
|
||||
parts=[Part(text="Hello")],
|
||||
),
|
||||
append=False,
|
||||
)
|
||||
terminal_task = Task(
|
||||
id="task-art-mixed",
|
||||
context_id="ctx-art-mixed",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
artifacts=[
|
||||
Artifact(
|
||||
artifact_id="artifact-streamed",
|
||||
parts=[Part(root=TextPart(text="Hello"))],
|
||||
parts=[Part(text="Hello")],
|
||||
),
|
||||
Artifact(
|
||||
artifact_id="artifact-final",
|
||||
parts=[Part(root=TextPart(text="Goodbye"))],
|
||||
parts=[Part(text="Goodbye")],
|
||||
),
|
||||
],
|
||||
)
|
||||
terminal_event = TaskStatusUpdateEvent(
|
||||
task_id="task-art-mixed",
|
||||
context_id="ctx-art-mixed",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
final=True,
|
||||
)
|
||||
|
||||
mock_a2a_client.responses.extend([
|
||||
(working_task, streamed_chunk),
|
||||
(terminal_task, terminal_event),
|
||||
StreamResponse(artifact_update=streamed_chunk),
|
||||
StreamResponse(task=terminal_task),
|
||||
])
|
||||
|
||||
stream = a2a_agent.run("Hello", stream=True)
|
||||
@@ -1463,11 +1395,11 @@ async def test_message_metadata_propagated(a2a_agent: A2AAgent, mock_a2a_client:
|
||||
"""A2AMessage.metadata should appear on response.additional_properties."""
|
||||
msg = A2AMessage(
|
||||
message_id="msg-meta",
|
||||
role=A2ARole.agent,
|
||||
parts=[Part(root=TextPart(text="hi"))],
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="hi")],
|
||||
metadata={"source": "server", "trace_id": "abc"},
|
||||
)
|
||||
mock_a2a_client.responses.append(msg)
|
||||
mock_a2a_client.responses.append(StreamResponse(message=msg))
|
||||
|
||||
response = await a2a_agent.run("hello")
|
||||
assert response.additional_properties["a2a_metadata"]["source"] == "server"
|
||||
@@ -1479,16 +1411,16 @@ async def test_artifact_metadata_propagated(a2a_agent: A2AAgent, mock_a2a_client
|
||||
task = Task(
|
||||
id="task-art-meta",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
artifacts=[
|
||||
Artifact(
|
||||
artifact_id="a1",
|
||||
parts=[Part(root=TextPart(text="result"))],
|
||||
parts=[Part(text="result")],
|
||||
metadata={"artifact_key": "artifact_value"},
|
||||
),
|
||||
],
|
||||
)
|
||||
mock_a2a_client.responses.append((task, None))
|
||||
mock_a2a_client.responses.append(StreamResponse(task=task))
|
||||
|
||||
response = await a2a_agent.run("go")
|
||||
assert response.additional_properties["a2a_metadata"]["artifact_key"] == "artifact_value"
|
||||
@@ -1499,13 +1431,13 @@ async def test_task_metadata_propagated_to_response(a2a_agent: A2AAgent, mock_a2
|
||||
task = Task(
|
||||
id="task-meta",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
artifacts=[
|
||||
Artifact(artifact_id="a1", parts=[Part(root=TextPart(text="done"))]),
|
||||
Artifact(artifact_id="a1", parts=[Part(text="done")]),
|
||||
],
|
||||
metadata={"task_key": "task_value"},
|
||||
)
|
||||
mock_a2a_client.responses.append((task, None))
|
||||
mock_a2a_client.responses.append(StreamResponse(task=task))
|
||||
|
||||
response = await a2a_agent.run("go")
|
||||
assert response.additional_properties["a2a_metadata"]["task_key"] == "task_value"
|
||||
@@ -1518,33 +1450,22 @@ async def test_task_artifact_update_event_metadata_merged(a2a_agent: A2AAgent, m
|
||||
context_id="ctx",
|
||||
artifact=Artifact(
|
||||
artifact_id="a1",
|
||||
parts=[Part(root=TextPart(text="chunk"))],
|
||||
parts=[Part(text="chunk")],
|
||||
metadata={"from_artifact": True},
|
||||
),
|
||||
metadata={"from_event": True},
|
||||
)
|
||||
working_task = Task(
|
||||
id="task-ae",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(state=TaskState.working),
|
||||
)
|
||||
terminal_task = Task(
|
||||
id="task-ae",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
artifacts=[
|
||||
Artifact(artifact_id="a1", parts=[Part(root=TextPart(text="chunk"))]),
|
||||
Artifact(artifact_id="a1", parts=[Part(text="chunk")]),
|
||||
],
|
||||
)
|
||||
terminal_event = TaskStatusUpdateEvent(
|
||||
task_id="task-ae",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
final=True,
|
||||
)
|
||||
mock_a2a_client.responses.extend([
|
||||
(working_task, artifact_event),
|
||||
(terminal_task, terminal_event),
|
||||
StreamResponse(artifact_update=artifact_event),
|
||||
StreamResponse(task=terminal_task),
|
||||
])
|
||||
|
||||
stream = a2a_agent.run("hello", stream=True)
|
||||
@@ -1563,39 +1484,27 @@ async def test_task_status_update_event_metadata_merged(a2a_agent: A2AAgent, moc
|
||||
task_id="task-se",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(
|
||||
state=TaskState.working,
|
||||
state=TaskState.TASK_STATE_WORKING,
|
||||
message=A2AMessage(
|
||||
message_id="m1",
|
||||
role=A2ARole.agent,
|
||||
parts=[Part(root=TextPart(text="working..."))],
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="working...")],
|
||||
metadata={"msg_key": "msg_val"},
|
||||
),
|
||||
),
|
||||
final=False,
|
||||
metadata={"event_key": "event_val"},
|
||||
)
|
||||
working_task = Task(
|
||||
id="task-se",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(state=TaskState.working),
|
||||
)
|
||||
terminal_task = Task(
|
||||
id="task-se",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
artifacts=[
|
||||
Artifact(artifact_id="a1", parts=[Part(root=TextPart(text="done"))]),
|
||||
Artifact(artifact_id="a1", parts=[Part(text="done")]),
|
||||
],
|
||||
)
|
||||
terminal_event = TaskStatusUpdateEvent(
|
||||
task_id="task-se",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
final=True,
|
||||
)
|
||||
mock_a2a_client.responses.extend([
|
||||
(working_task, status_event),
|
||||
(terminal_task, terminal_event),
|
||||
StreamResponse(status_update=status_event),
|
||||
StreamResponse(task=terminal_task),
|
||||
])
|
||||
|
||||
stream = a2a_agent.run("hello", stream=True)
|
||||
@@ -1613,17 +1522,17 @@ async def test_history_message_metadata_propagated(a2a_agent: A2AAgent, mock_a2a
|
||||
task = Task(
|
||||
id="task-hist",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
history=[
|
||||
A2AMessage(
|
||||
message_id="h1",
|
||||
role=A2ARole.agent,
|
||||
parts=[Part(root=TextPart(text="reply"))],
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="reply")],
|
||||
metadata={"history_key": "history_value"},
|
||||
),
|
||||
],
|
||||
)
|
||||
mock_a2a_client.responses.append((task, None))
|
||||
mock_a2a_client.responses.append(StreamResponse(task=task))
|
||||
|
||||
response = await a2a_agent.run("go")
|
||||
assert response.additional_properties["a2a_metadata"]["history_key"] == "history_value"
|
||||
@@ -1636,10 +1545,10 @@ async def test_continuation_token_update_carries_task_metadata(
|
||||
task = Task(
|
||||
id="task-cont",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(state=TaskState.working),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
|
||||
metadata={"bg_key": "bg_value"},
|
||||
)
|
||||
mock_a2a_client.responses.append((task, None))
|
||||
mock_a2a_client.responses.append(StreamResponse(task=task))
|
||||
|
||||
response = await a2a_agent.run("go", background=True)
|
||||
assert response.continuation_token is not None
|
||||
@@ -1652,10 +1561,10 @@ async def test_none_metadata_leaves_additional_properties_empty(
|
||||
"""When A2A types have no metadata, additional_properties should remain empty/default."""
|
||||
msg = A2AMessage(
|
||||
message_id="msg-none",
|
||||
role=A2ARole.agent,
|
||||
parts=[Part(root=TextPart(text="no meta"))],
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="no meta")],
|
||||
)
|
||||
mock_a2a_client.responses.append(msg)
|
||||
mock_a2a_client.responses.append(StreamResponse(message=msg))
|
||||
|
||||
response = await a2a_agent.run("hello")
|
||||
assert not response.additional_properties
|
||||
|
||||
@@ -3,7 +3,7 @@ from asyncio import CancelledError
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from a2a.types import Task, TaskState, TextPart
|
||||
from a2a.types import Part, Task, TaskState
|
||||
from agent_framework import (
|
||||
AgentResponseUpdate,
|
||||
Content,
|
||||
@@ -48,7 +48,7 @@ def mock_task() -> Task:
|
||||
task = MagicMock(spec=Task)
|
||||
task.id = str(uuid4())
|
||||
task.context_id = str(uuid4())
|
||||
task.state = TaskState.completed
|
||||
task.state = TaskState.TASK_STATE_COMPLETED
|
||||
return task
|
||||
|
||||
|
||||
@@ -244,7 +244,7 @@ class TestA2AExecutorExecute:
|
||||
executor._agent.run = AsyncMock(return_value=response)
|
||||
executor._agent.create_session = MagicMock()
|
||||
|
||||
with patch("agent_framework_a2a._a2a_executor.new_task") as mock_new_task:
|
||||
with patch("agent_framework_a2a._a2a_executor.new_task_from_user_message") as mock_new_task:
|
||||
mock_task = MagicMock(spec=Task)
|
||||
mock_task.id = "task-new"
|
||||
mock_task.context_id = "ctx-123"
|
||||
@@ -341,9 +341,7 @@ class TestA2AExecutorExecute:
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called()
|
||||
call_args_list = mock_updater.update_status.call_args_list
|
||||
assert any(
|
||||
call[1].get("state") == TaskState.canceled and call[1].get("final") is True for call in call_args_list
|
||||
)
|
||||
assert any(call[1].get("state") == TaskState.TASK_STATE_CANCELED for call in call_args_list)
|
||||
|
||||
async def test_execute_handles_generic_exception(
|
||||
self,
|
||||
@@ -382,14 +380,12 @@ class TestA2AExecutorExecute:
|
||||
args, _ = mock_updater.new_agent_message.call_args
|
||||
parts = args[0]
|
||||
assert len(parts) == 1
|
||||
assert isinstance(parts[0].root, TextPart)
|
||||
assert parts[0].root.text == error_message
|
||||
assert isinstance(parts[0], Part)
|
||||
assert parts[0].text == error_message
|
||||
|
||||
call_args_list = mock_updater.update_status.call_args_list
|
||||
assert any(
|
||||
call[1].get("state") == TaskState.failed
|
||||
and call[1].get("final") is True
|
||||
and call[1].get("message") == "error_message_obj"
|
||||
call[1].get("state") == TaskState.TASK_STATE_FAILED and call[1].get("message") == "error_message_obj"
|
||||
for call in call_args_list
|
||||
)
|
||||
|
||||
@@ -630,7 +626,7 @@ class TestA2AExecutorHandleEvents:
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called_once()
|
||||
call_args = mock_updater.update_status.call_args
|
||||
assert call_args.kwargs["state"] == TaskState.working
|
||||
assert call_args.kwargs["state"] == TaskState.TASK_STATE_WORKING
|
||||
assert mock_updater.new_agent_message.called
|
||||
|
||||
async def test_handle_multiple_text_contents(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
@@ -666,7 +662,7 @@ class TestA2AExecutorHandleEvents:
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called_once()
|
||||
call_args = mock_updater.update_status.call_args
|
||||
assert call_args.kwargs["state"] == TaskState.working
|
||||
assert call_args.kwargs["state"] == TaskState.TASK_STATE_WORKING
|
||||
|
||||
async def test_handle_uri_content(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test handling messages with URI content."""
|
||||
@@ -683,7 +679,7 @@ class TestA2AExecutorHandleEvents:
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called_once()
|
||||
call_args = mock_updater.update_status.call_args
|
||||
assert call_args.kwargs["state"] == TaskState.working
|
||||
assert call_args.kwargs["state"] == TaskState.TASK_STATE_WORKING
|
||||
|
||||
async def test_handle_mixed_content_types(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test handling messages with mixed content types."""
|
||||
@@ -705,7 +701,7 @@ class TestA2AExecutorHandleEvents:
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called_once()
|
||||
call_args = mock_updater.update_status.call_args
|
||||
assert call_args.kwargs["state"] == TaskState.working
|
||||
assert call_args.kwargs["state"] == TaskState.TASK_STATE_WORKING
|
||||
|
||||
async def test_handle_with_additional_properties(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test handling messages with additional properties metadata."""
|
||||
@@ -778,7 +774,7 @@ class TestA2AExecutorHandleEvents:
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_updater.update_status.call_args.kwargs
|
||||
assert call_kwargs["state"] == TaskState.working
|
||||
assert call_kwargs["state"] == TaskState.TASK_STATE_WORKING
|
||||
|
||||
async def test_handle_agent_response_update_no_streamed_set(
|
||||
self, executor: A2AExecutor, mock_updater: MagicMock
|
||||
|
||||
Reference in New Issue
Block a user