Python: handle streamed A2A update events (#4919)

* Python: handle streamed A2A update events

* Python: preserve terminal A2A artifacts during streaming

* Python: harden streamed A2A update event handling

* Python: simplify streamed A2A update guard

---------

Co-authored-by: sztoplover-bit <253473756+sztoplover-bit@users.noreply.github.com>
Co-authored-by: Giles Odigwe <79032838+giles17@users.noreply.github.com>
This commit is contained in:
perry
2026-04-02 23:20:47 +08:00
committed by GitHub
Unverified
parent 524c0216e4
commit 5f06b68535
2 changed files with 266 additions and 2 deletions
@@ -365,6 +365,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
)
all_updates: list[AgentResponseUpdate] = []
streamed_artifact_ids_by_task: dict[str, set[str]] = {}
async for item in a2a_stream:
if isinstance(item, A2AMessage):
# Process A2A Message
@@ -378,12 +379,21 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
all_updates.append(update)
yield update
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task):
task, _update_event = item
for update in self._updates_from_task(
task, update_event = item
updates = self._updates_from_task(
task,
update_event=update_event,
background=background,
emit_intermediate=emit_intermediate,
streamed_artifact_ids=streamed_artifact_ids_by_task.get(task.id),
)
if isinstance(update_event, TaskArtifactUpdateEvent) and any(
update.raw_representation is update_event for update in updates
):
streamed_artifact_ids_by_task.setdefault(task.id, set()).add(update_event.artifact.artifact_id)
if task.status.state in TERMINAL_TASK_STATES:
streamed_artifact_ids_by_task.pop(task.id, None)
for update in updates:
all_updates.append(update)
yield update
else:
@@ -403,8 +413,10 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
self,
task: Task,
*,
update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None,
background: bool = False,
emit_intermediate: bool = False,
streamed_artifact_ids: set[str] | None = None,
) -> list[AgentResponseUpdate]:
"""Convert an A2A Task into AgentResponseUpdate(s).
@@ -418,8 +430,21 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
"""
status = task.status
if (
emit_intermediate
and update_event is not None
and (event_updates := self._updates_from_task_update_event(update_event))
):
return event_updates
if status.state in TERMINAL_TASK_STATES:
task_messages = self._parse_messages_from_task(task)
if task.artifacts is not None and streamed_artifact_ids:
task_messages = [
message
for message in task_messages
if getattr(message.raw_representation, "artifact_id", None) not in streamed_artifact_ids
]
if task_messages:
return [
AgentResponseUpdate(
@@ -431,6 +456,8 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
)
for message in task_messages
]
if task.artifacts is not None:
return []
return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)]
if background and status.state in IN_PROGRESS_TASK_STATES:
@@ -467,6 +494,44 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
return []
def _updates_from_task_update_event(
self, update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent
) -> list[AgentResponseUpdate]:
"""Convert A2A task update events into streaming AgentResponseUpdates."""
if isinstance(update_event, TaskArtifactUpdateEvent):
contents = self._parse_contents_from_a2a(update_event.artifact.parts)
if not contents:
return []
return [
AgentResponseUpdate(
contents=contents,
role="assistant",
response_id=update_event.task_id,
message_id=update_event.artifact.artifact_id,
raw_representation=update_event,
)
]
if not isinstance(update_event, TaskStatusUpdateEvent):
return []
message = update_event.status.message
if message is None or not message.parts:
return []
contents = self._parse_contents_from_a2a(message.parts)
if not contents:
return []
return [
AgentResponseUpdate(
contents=contents,
role="assistant" if message.role == A2ARole.agent else "user",
response_id=update_event.task_id,
raw_representation=update_event,
)
]
@staticmethod
def _build_continuation_token(task: Task) -> A2AContinuationToken | None:
"""Build an A2AContinuationToken from an A2A Task if it is still in progress."""
+199
View File
@@ -14,8 +14,10 @@ from a2a.types import (
FileWithUri,
Part,
Task,
TaskArtifactUpdateEvent,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
TextPart,
)
from a2a.types import Message as A2AMessage
@@ -1189,4 +1191,201 @@ async def test_streaming_working_update_with_empty_parts_is_skipped(
assert updates[0].contents[0].text == "Result"
async def test_streaming_artifact_update_event_yields_content(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that streaming artifact update events yield incremental content."""
task = Task(id="task-art", context_id="ctx-art", status=TaskStatus(state=TaskState.working, message=None))
artifact = Artifact(
artifact_id="artifact-1",
parts=[Part(root=TextPart(text="Hello"))],
)
update_event = TaskArtifactUpdateEvent(task_id="task-art", context_id="ctx-art", artifact=artifact, append=False)
mock_a2a_client.responses.append((task, update_event))
updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)
assert len(updates) == 1
assert updates[0].text == "Hello"
assert updates[0].message_id == "artifact-1"
assert updates[0].raw_representation == update_event
async def test_streaming_status_update_event_yields_content(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that streaming status update events surface message content directly from the update event."""
update_event = TaskStatusUpdateEvent(
task_id="task-status",
context_id="ctx-status",
status=TaskStatus(
state=TaskState.working,
message=A2AMessage(
message_id=str(uuid4()),
role=A2ARole.agent,
parts=[Part(root=TextPart(text="Still working"))],
),
),
final=False,
)
task = Task(id="task-status", context_id="ctx-status", status=TaskStatus(state=TaskState.working, message=None))
mock_a2a_client.responses.append((task, update_event))
updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)
assert len(updates) == 1
assert updates[0].text == "Still working"
assert updates[0].role == "assistant"
assert updates[0].raw_representation == update_event
async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_artifacts(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that streamed artifact chunks are not re-emitted from the final terminal task."""
working_task = Task(id="task-art-dup", context_id="ctx-art-dup", status=TaskStatus(state=TaskState.working))
first_chunk = TaskArtifactUpdateEvent(
task_id="task-art-dup",
context_id="ctx-art-dup",
artifact=Artifact(
artifact_id="artifact-dup",
parts=[Part(root=TextPart(text="Hello "))],
),
append=False,
)
second_chunk = TaskArtifactUpdateEvent(
task_id="task-art-dup",
context_id="ctx-art-dup",
artifact=Artifact(
artifact_id="artifact-dup",
parts=[Part(root=TextPart(text="world"))],
),
append=True,
)
terminal_task = Task(
id="task-art-dup",
context_id="ctx-art-dup",
status=TaskStatus(state=TaskState.completed, message=None),
artifacts=[
Artifact(
artifact_id="artifact-dup",
parts=[Part(root=TextPart(text="Hello world"))],
)
],
)
terminal_event = TaskStatusUpdateEvent(
task_id="task-art-dup",
context_id="ctx-art-dup",
status=TaskStatus(state=TaskState.completed, message=None),
final=True,
)
mock_a2a_client.responses.extend(
[
(working_task, first_chunk),
(working_task, second_chunk),
(terminal_task, terminal_event),
]
)
stream = a2a_agent.run("Hello", stream=True)
updates: list[AgentResponseUpdate] = []
async for update in stream:
updates.append(update)
response = await stream.get_final_response()
assert [update.text for update in updates] == ["Hello ", "world"]
assert response.text == "Hello world"
assert len(response.messages) == 1
async def test_streaming_terminal_task_artifacts_are_emitted_when_terminal_event_has_no_content(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that terminal task artifacts are still emitted when the final status event has no message."""
terminal_task = Task(
id="task-art-final",
context_id="ctx-art-final",
status=TaskStatus(state=TaskState.completed, message=None),
artifacts=[
Artifact(
artifact_id="artifact-final",
parts=[Part(root=TextPart(text="Final artifact"))],
)
],
)
terminal_event = TaskStatusUpdateEvent(
task_id="task-art-final",
context_id="ctx-art-final",
status=TaskStatus(state=TaskState.completed, message=None),
final=True,
)
mock_a2a_client.responses.append((terminal_task, terminal_event))
updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)
assert len(updates) == 1
assert updates[0].text == "Final artifact"
assert updates[0].message_id == "artifact-final"
async def test_streaming_terminal_task_only_emits_unstreamed_artifacts(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that the terminal task only emits artifacts that were not already streamed incrementally."""
working_task = Task(id="task-art-mixed", context_id="ctx-art-mixed", status=TaskStatus(state=TaskState.working))
streamed_chunk = TaskArtifactUpdateEvent(
task_id="task-art-mixed",
context_id="ctx-art-mixed",
artifact=Artifact(
artifact_id="artifact-streamed",
parts=[Part(root=TextPart(text="Hello"))],
),
append=False,
)
terminal_task = Task(
id="task-art-mixed",
context_id="ctx-art-mixed",
status=TaskStatus(state=TaskState.completed, message=None),
artifacts=[
Artifact(
artifact_id="artifact-streamed",
parts=[Part(root=TextPart(text="Hello"))],
),
Artifact(
artifact_id="artifact-final",
parts=[Part(root=TextPart(text="Goodbye"))],
),
],
)
terminal_event = TaskStatusUpdateEvent(
task_id="task-art-mixed",
context_id="ctx-art-mixed",
status=TaskStatus(state=TaskState.completed, message=None),
final=True,
)
mock_a2a_client.responses.extend(
[
(working_task, streamed_chunk),
(terminal_task, terminal_event),
]
)
stream = a2a_agent.run("Hello", stream=True)
updates: list[AgentResponseUpdate] = []
async for update in stream:
updates.append(update)
response = await stream.get_final_response()
assert [update.text for update in updates] == ["Hello", "Goodbye"]
assert [message.text for message in response.messages] == ["Hello", "Goodbye"]
# endregion