mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: handle streamed A2A update events (#4919)
* Python: handle streamed A2A update events * Python: preserve terminal A2A artifacts during streaming * Python: harden streamed A2A update event handling * Python: simplify streamed A2A update guard --------- Co-authored-by: sztoplover-bit <253473756+sztoplover-bit@users.noreply.github.com> Co-authored-by: Giles Odigwe <79032838+giles17@users.noreply.github.com>
This commit is contained in:
@@ -365,6 +365,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
)
|
||||
|
||||
all_updates: list[AgentResponseUpdate] = []
|
||||
streamed_artifact_ids_by_task: dict[str, set[str]] = {}
|
||||
async for item in a2a_stream:
|
||||
if isinstance(item, A2AMessage):
|
||||
# Process A2A Message
|
||||
@@ -378,12 +379,21 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task):
|
||||
task, _update_event = item
|
||||
for update in self._updates_from_task(
|
||||
task, update_event = item
|
||||
updates = self._updates_from_task(
|
||||
task,
|
||||
update_event=update_event,
|
||||
background=background,
|
||||
emit_intermediate=emit_intermediate,
|
||||
streamed_artifact_ids=streamed_artifact_ids_by_task.get(task.id),
|
||||
)
|
||||
if isinstance(update_event, TaskArtifactUpdateEvent) and any(
|
||||
update.raw_representation is update_event for update in updates
|
||||
):
|
||||
streamed_artifact_ids_by_task.setdefault(task.id, set()).add(update_event.artifact.artifact_id)
|
||||
if task.status.state in TERMINAL_TASK_STATES:
|
||||
streamed_artifact_ids_by_task.pop(task.id, None)
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
else:
|
||||
@@ -403,8 +413,10 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
self,
|
||||
task: Task,
|
||||
*,
|
||||
update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None,
|
||||
background: bool = False,
|
||||
emit_intermediate: bool = False,
|
||||
streamed_artifact_ids: set[str] | None = None,
|
||||
) -> list[AgentResponseUpdate]:
|
||||
"""Convert an A2A Task into AgentResponseUpdate(s).
|
||||
|
||||
@@ -418,8 +430,21 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
"""
|
||||
status = task.status
|
||||
|
||||
if (
|
||||
emit_intermediate
|
||||
and update_event is not None
|
||||
and (event_updates := self._updates_from_task_update_event(update_event))
|
||||
):
|
||||
return event_updates
|
||||
|
||||
if status.state in TERMINAL_TASK_STATES:
|
||||
task_messages = self._parse_messages_from_task(task)
|
||||
if task.artifacts is not None and streamed_artifact_ids:
|
||||
task_messages = [
|
||||
message
|
||||
for message in task_messages
|
||||
if getattr(message.raw_representation, "artifact_id", None) not in streamed_artifact_ids
|
||||
]
|
||||
if task_messages:
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
@@ -431,6 +456,8 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
)
|
||||
for message in task_messages
|
||||
]
|
||||
if task.artifacts is not None:
|
||||
return []
|
||||
return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)]
|
||||
|
||||
if background and status.state in IN_PROGRESS_TASK_STATES:
|
||||
@@ -467,6 +494,44 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
|
||||
return []
|
||||
|
||||
def _updates_from_task_update_event(
|
||||
self, update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent
|
||||
) -> list[AgentResponseUpdate]:
|
||||
"""Convert A2A task update events into streaming AgentResponseUpdates."""
|
||||
if isinstance(update_event, TaskArtifactUpdateEvent):
|
||||
contents = self._parse_contents_from_a2a(update_event.artifact.parts)
|
||||
if not contents:
|
||||
return []
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=contents,
|
||||
role="assistant",
|
||||
response_id=update_event.task_id,
|
||||
message_id=update_event.artifact.artifact_id,
|
||||
raw_representation=update_event,
|
||||
)
|
||||
]
|
||||
|
||||
if not isinstance(update_event, TaskStatusUpdateEvent):
|
||||
return []
|
||||
|
||||
message = update_event.status.message
|
||||
if message is None or not message.parts:
|
||||
return []
|
||||
|
||||
contents = self._parse_contents_from_a2a(message.parts)
|
||||
if not contents:
|
||||
return []
|
||||
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=contents,
|
||||
role="assistant" if message.role == A2ARole.agent else "user",
|
||||
response_id=update_event.task_id,
|
||||
raw_representation=update_event,
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _build_continuation_token(task: Task) -> A2AContinuationToken | None:
|
||||
"""Build an A2AContinuationToken from an A2A Task if it is still in progress."""
|
||||
|
||||
@@ -14,8 +14,10 @@ from a2a.types import (
|
||||
FileWithUri,
|
||||
Part,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
from a2a.types import Message as A2AMessage
|
||||
@@ -1189,4 +1191,201 @@ async def test_streaming_working_update_with_empty_parts_is_skipped(
|
||||
assert updates[0].contents[0].text == "Result"
|
||||
|
||||
|
||||
async def test_streaming_artifact_update_event_yields_content(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Test that streaming artifact update events yield incremental content."""
|
||||
task = Task(id="task-art", context_id="ctx-art", status=TaskStatus(state=TaskState.working, message=None))
|
||||
artifact = Artifact(
|
||||
artifact_id="artifact-1",
|
||||
parts=[Part(root=TextPart(text="Hello"))],
|
||||
)
|
||||
update_event = TaskArtifactUpdateEvent(task_id="task-art", context_id="ctx-art", artifact=artifact, append=False)
|
||||
mock_a2a_client.responses.append((task, update_event))
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in a2a_agent.run("Hello", stream=True):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 1
|
||||
assert updates[0].text == "Hello"
|
||||
assert updates[0].message_id == "artifact-1"
|
||||
assert updates[0].raw_representation == update_event
|
||||
|
||||
|
||||
async def test_streaming_status_update_event_yields_content(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Test that streaming status update events surface message content directly from the update event."""
|
||||
update_event = TaskStatusUpdateEvent(
|
||||
task_id="task-status",
|
||||
context_id="ctx-status",
|
||||
status=TaskStatus(
|
||||
state=TaskState.working,
|
||||
message=A2AMessage(
|
||||
message_id=str(uuid4()),
|
||||
role=A2ARole.agent,
|
||||
parts=[Part(root=TextPart(text="Still working"))],
|
||||
),
|
||||
),
|
||||
final=False,
|
||||
)
|
||||
task = Task(id="task-status", context_id="ctx-status", status=TaskStatus(state=TaskState.working, message=None))
|
||||
mock_a2a_client.responses.append((task, update_event))
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in a2a_agent.run("Hello", stream=True):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 1
|
||||
assert updates[0].text == "Still working"
|
||||
assert updates[0].role == "assistant"
|
||||
assert updates[0].raw_representation == update_event
|
||||
|
||||
|
||||
async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_artifacts(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Test that streamed artifact chunks are not re-emitted from the final terminal task."""
|
||||
working_task = Task(id="task-art-dup", context_id="ctx-art-dup", status=TaskStatus(state=TaskState.working))
|
||||
first_chunk = TaskArtifactUpdateEvent(
|
||||
task_id="task-art-dup",
|
||||
context_id="ctx-art-dup",
|
||||
artifact=Artifact(
|
||||
artifact_id="artifact-dup",
|
||||
parts=[Part(root=TextPart(text="Hello "))],
|
||||
),
|
||||
append=False,
|
||||
)
|
||||
second_chunk = TaskArtifactUpdateEvent(
|
||||
task_id="task-art-dup",
|
||||
context_id="ctx-art-dup",
|
||||
artifact=Artifact(
|
||||
artifact_id="artifact-dup",
|
||||
parts=[Part(root=TextPart(text="world"))],
|
||||
),
|
||||
append=True,
|
||||
)
|
||||
terminal_task = Task(
|
||||
id="task-art-dup",
|
||||
context_id="ctx-art-dup",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
artifacts=[
|
||||
Artifact(
|
||||
artifact_id="artifact-dup",
|
||||
parts=[Part(root=TextPart(text="Hello world"))],
|
||||
)
|
||||
],
|
||||
)
|
||||
terminal_event = TaskStatusUpdateEvent(
|
||||
task_id="task-art-dup",
|
||||
context_id="ctx-art-dup",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
final=True,
|
||||
)
|
||||
|
||||
mock_a2a_client.responses.extend(
|
||||
[
|
||||
(working_task, first_chunk),
|
||||
(working_task, second_chunk),
|
||||
(terminal_task, terminal_event),
|
||||
]
|
||||
)
|
||||
|
||||
stream = a2a_agent.run("Hello", stream=True)
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in stream:
|
||||
updates.append(update)
|
||||
response = await stream.get_final_response()
|
||||
|
||||
assert [update.text for update in updates] == ["Hello ", "world"]
|
||||
assert response.text == "Hello world"
|
||||
assert len(response.messages) == 1
|
||||
|
||||
|
||||
async def test_streaming_terminal_task_artifacts_are_emitted_when_terminal_event_has_no_content(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Test that terminal task artifacts are still emitted when the final status event has no message."""
|
||||
terminal_task = Task(
|
||||
id="task-art-final",
|
||||
context_id="ctx-art-final",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
artifacts=[
|
||||
Artifact(
|
||||
artifact_id="artifact-final",
|
||||
parts=[Part(root=TextPart(text="Final artifact"))],
|
||||
)
|
||||
],
|
||||
)
|
||||
terminal_event = TaskStatusUpdateEvent(
|
||||
task_id="task-art-final",
|
||||
context_id="ctx-art-final",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
final=True,
|
||||
)
|
||||
mock_a2a_client.responses.append((terminal_task, terminal_event))
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in a2a_agent.run("Hello", stream=True):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 1
|
||||
assert updates[0].text == "Final artifact"
|
||||
assert updates[0].message_id == "artifact-final"
|
||||
|
||||
|
||||
async def test_streaming_terminal_task_only_emits_unstreamed_artifacts(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Test that the terminal task only emits artifacts that were not already streamed incrementally."""
|
||||
working_task = Task(id="task-art-mixed", context_id="ctx-art-mixed", status=TaskStatus(state=TaskState.working))
|
||||
streamed_chunk = TaskArtifactUpdateEvent(
|
||||
task_id="task-art-mixed",
|
||||
context_id="ctx-art-mixed",
|
||||
artifact=Artifact(
|
||||
artifact_id="artifact-streamed",
|
||||
parts=[Part(root=TextPart(text="Hello"))],
|
||||
),
|
||||
append=False,
|
||||
)
|
||||
terminal_task = Task(
|
||||
id="task-art-mixed",
|
||||
context_id="ctx-art-mixed",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
artifacts=[
|
||||
Artifact(
|
||||
artifact_id="artifact-streamed",
|
||||
parts=[Part(root=TextPart(text="Hello"))],
|
||||
),
|
||||
Artifact(
|
||||
artifact_id="artifact-final",
|
||||
parts=[Part(root=TextPart(text="Goodbye"))],
|
||||
),
|
||||
],
|
||||
)
|
||||
terminal_event = TaskStatusUpdateEvent(
|
||||
task_id="task-art-mixed",
|
||||
context_id="ctx-art-mixed",
|
||||
status=TaskStatus(state=TaskState.completed, message=None),
|
||||
final=True,
|
||||
)
|
||||
|
||||
mock_a2a_client.responses.extend(
|
||||
[
|
||||
(working_task, streamed_chunk),
|
||||
(terminal_task, terminal_event),
|
||||
]
|
||||
)
|
||||
|
||||
stream = a2a_agent.run("Hello", stream=True)
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in stream:
|
||||
updates.append(update)
|
||||
response = await stream.get_final_response()
|
||||
|
||||
assert [update.text for update in updates] == ["Hello", "Goodbye"]
|
||||
assert [message.text for message in response.messages] == ["Hello", "Goodbye"]
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user