mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix A2A v1.0 non-streaming response and sample runtime issues (#5849)
- Fix non-streaming empty response by accumulating intermediate WORKING status updates and flushing them when an empty terminal event arrives - Fix sample agent_executor.py to enqueue Task before status events (required by v1.0 ActiveTask validation) - Fix create_jsonrpc_routes() calls to include required rpc_url param - Fix TYPE_CHECKING imports in sample agent_definitions.py - Add tests for non-streaming content accumulation behavior Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
410268b624
commit
68357b0250
@@ -42,7 +42,7 @@ request_handler = DefaultRequestHandler(
|
||||
app = Starlette(
|
||||
routes=[
|
||||
*create_agent_card_routes(my_agent_card),
|
||||
*create_jsonrpc_routes(request_handler),
|
||||
*create_jsonrpc_routes(request_handler, "/"),
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
@@ -78,7 +78,7 @@ class A2AExecutor(AgentExecutor):
|
||||
app = Starlette(
|
||||
routes=[
|
||||
*create_agent_card_routes(public_agent_card),
|
||||
*create_jsonrpc_routes(request_handler),
|
||||
*create_jsonrpc_routes(request_handler, "/"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -365,6 +365,10 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
|
||||
all_updates: list[AgentResponseUpdate] = []
|
||||
streamed_artifact_ids_by_task: dict[str, set[str]] = {}
|
||||
# In non-streaming mode, accumulate intermediate status content so it
|
||||
# can be surfaced when the terminal event arrives (mirroring v0.3.x
|
||||
# behavior where the full Task history was available at completion).
|
||||
pending_updates_by_task: dict[str, list[AgentResponseUpdate]] = {}
|
||||
async for item in a2a_stream:
|
||||
payload_type = item.WhichOneof("payload")
|
||||
if payload_type == "message":
|
||||
@@ -391,27 +395,55 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
)
|
||||
if task.status.state in TERMINAL_TASK_STATES:
|
||||
streamed_artifact_ids_by_task.pop(task.id, None)
|
||||
# If the terminal Task has no content, flush accumulated updates
|
||||
if not updates or all(not u.contents for u in updates):
|
||||
pending = pending_updates_by_task.pop(task.id, [])
|
||||
for update in pending:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
else:
|
||||
pending_updates_by_task.pop(task.id, None)
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
elif payload_type == "status_update":
|
||||
status_event = item.status_update
|
||||
updates = self._updates_from_task_update_event(status_event)
|
||||
is_terminal = status_event.status.state in TERMINAL_TASK_STATES
|
||||
if emit_intermediate:
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
elif is_terminal:
|
||||
if updates:
|
||||
# Terminal event with content — discard accumulated intermediates
|
||||
pending_updates_by_task.pop(status_event.task_id, None)
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
else:
|
||||
# Terminal event with NO content — flush accumulated updates
|
||||
pending = pending_updates_by_task.pop(status_event.task_id, [])
|
||||
for update in pending:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
else:
|
||||
# Non-streaming intermediate: accumulate for later
|
||||
if updates:
|
||||
pending_updates_by_task.setdefault(status_event.task_id, []).extend(updates)
|
||||
elif payload_type == "artifact_update":
|
||||
artifact_event = item.artifact_update
|
||||
updates = self._updates_from_task_update_event(artifact_event)
|
||||
# Always yield artifact updates — they carry actual response
|
||||
# content (files, data). Track IDs so that a subsequent
|
||||
# terminal Task doesn't duplicate the same artifacts.
|
||||
if updates:
|
||||
streamed_artifact_ids_by_task.setdefault(artifact_event.task_id, set()).add(
|
||||
artifact_event.artifact.artifact_id
|
||||
)
|
||||
if emit_intermediate:
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported StreamResponse payload: {payload_type}")
|
||||
|
||||
|
||||
@@ -1570,4 +1570,102 @@ async def test_none_metadata_leaves_additional_properties_empty(
|
||||
assert not response.additional_properties
|
||||
|
||||
|
||||
async def test_non_streaming_terminal_status_update_surfaces_content(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Non-streaming run() should surface content from terminal status_update events."""
|
||||
completed_msg = A2AMessage(
|
||||
message_id="msg-complete",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="Done! Here is your answer.")],
|
||||
)
|
||||
status = TaskStatus(state=TaskState.TASK_STATE_COMPLETED, message=completed_msg)
|
||||
event = TaskStatusUpdateEvent(task_id="task-ts", context_id="ctx-ts", status=status)
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=event))
|
||||
|
||||
response = await a2a_agent.run("Hello")
|
||||
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Done! Here is your answer."
|
||||
|
||||
|
||||
async def test_non_streaming_accumulates_working_content_for_empty_terminal(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Non-streaming run() accumulates WORKING content and flushes on empty terminal event."""
|
||||
# Intermediate WORKING event with content
|
||||
working_msg = A2AMessage(
|
||||
message_id="msg-working",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="Here is your answer from working state.")],
|
||||
)
|
||||
working_status = TaskStatus(state=TaskState.TASK_STATE_WORKING, message=working_msg)
|
||||
working_event = TaskStatusUpdateEvent(task_id="task-acc", context_id="ctx-acc", status=working_status)
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=working_event))
|
||||
|
||||
# Terminal COMPLETED event with NO content
|
||||
completed_status = TaskStatus(state=TaskState.TASK_STATE_COMPLETED)
|
||||
completed_event = TaskStatusUpdateEvent(task_id="task-acc", context_id="ctx-acc", status=completed_status)
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=completed_event))
|
||||
|
||||
response = await a2a_agent.run("Hello")
|
||||
|
||||
# The accumulated WORKING content is flushed when terminal arrives empty
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Here is your answer from working state."
|
||||
|
||||
|
||||
async def test_non_streaming_intermediate_discarded_when_terminal_has_content(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Non-streaming: if terminal event has content, intermediate content is discarded."""
|
||||
# Intermediate WORKING event
|
||||
working_msg = A2AMessage(
|
||||
message_id="msg-working",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="Still thinking...")],
|
||||
)
|
||||
working_status = TaskStatus(state=TaskState.TASK_STATE_WORKING, message=working_msg)
|
||||
working_event = TaskStatusUpdateEvent(task_id="task-wi", context_id="ctx-wi", status=working_status)
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=working_event))
|
||||
|
||||
# Terminal COMPLETED event WITH content
|
||||
completed_msg = A2AMessage(
|
||||
message_id="msg-final",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="Final answer")],
|
||||
)
|
||||
completed_status = TaskStatus(state=TaskState.TASK_STATE_COMPLETED, message=completed_msg)
|
||||
completed_event = TaskStatusUpdateEvent(task_id="task-wi", context_id="ctx-wi", status=completed_status)
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=completed_event))
|
||||
|
||||
response = await a2a_agent.run("Hello")
|
||||
|
||||
# Terminal content supersedes accumulated intermediates
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Final answer"
|
||||
|
||||
|
||||
async def test_non_streaming_artifact_update_surfaces_content(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Non-streaming run() should surface content from artifact_update events."""
|
||||
artifact = Artifact(
|
||||
artifact_id="art-ns",
|
||||
parts=[Part(text="Artifact content")],
|
||||
)
|
||||
event = TaskArtifactUpdateEvent(task_id="task-anu", context_id="ctx-anu", artifact=artifact, append=False)
|
||||
mock_a2a_client.responses.append(StreamResponse(artifact_update=event))
|
||||
|
||||
# Terminal task with the same artifact ID — should be deduped
|
||||
mock_a2a_client.add_task_response("task-anu", [{"id": "art-ns", "content": "Artifact content"}])
|
||||
|
||||
response = await a2a_agent.run("Hello")
|
||||
|
||||
# Artifact update + terminal task with same artifact ID = content emitted once from
|
||||
# the artifact_update, then the duplicate from the task is filtered by streamed_artifact_ids
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Artifact content"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -103,7 +103,7 @@ def main() -> None:
|
||||
app = Starlette(
|
||||
routes=[
|
||||
*create_agent_card_routes(agent_card),
|
||||
*create_jsonrpc_routes(request_handler),
|
||||
*create_jsonrpc_routes(request_handler, "/"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -8,16 +8,11 @@ AgentCards for the invoice, policy, and logistics agent types.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from a2a.types import AgentCapabilities, AgentCard, AgentInterface, AgentSkill
|
||||
from agent_framework import Agent
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from invoice_data import query_by_invoice_id, query_by_transaction_id, query_invoices
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import Agent
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent instructions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -10,18 +10,12 @@ published back through the a2a-sdk event queue.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from a2a.helpers import new_task_from_user_message
|
||||
from a2a.server.agent_execution.agent_executor import AgentExecutor
|
||||
from a2a.types import (
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
from a2a.server.tasks import TaskUpdater
|
||||
from a2a.types import Part, TaskState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.server.agent_execution.context import RequestContext
|
||||
@@ -47,17 +41,17 @@ class AgentFrameworkExecutor(AgentExecutor):
|
||||
if not user_text:
|
||||
user_text = "Hello"
|
||||
|
||||
task_id = context.task_id or str(uuid.uuid4())
|
||||
context_id = context.context_id or str(uuid.uuid4())
|
||||
# v1.0 requires a Task object in the queue before any TaskStatusUpdateEvent
|
||||
task = context.current_task
|
||||
if not task and context.message:
|
||||
task = new_task_from_user_message(context.message)
|
||||
await event_queue.enqueue_event(task)
|
||||
|
||||
task_id = task.id if task else context.task_id
|
||||
updater = TaskUpdater(event_queue, task_id, context.context_id)
|
||||
|
||||
# Signal that the agent is working
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
|
||||
)
|
||||
)
|
||||
await updater.start_work()
|
||||
|
||||
try:
|
||||
response = await self.agent.run(user_text)
|
||||
@@ -71,48 +65,19 @@ class AgentFrameworkExecutor(AgentExecutor):
|
||||
if not response_parts:
|
||||
response_parts.append(Part(text=str(response)))
|
||||
|
||||
# Publish the agent's response as a completed message
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(
|
||||
state=TaskState.TASK_STATE_COMPLETED,
|
||||
message=Message(
|
||||
message_id=str(uuid.uuid4()),
|
||||
role=Role.ROLE_AGENT,
|
||||
parts=response_parts,
|
||||
),
|
||||
),
|
||||
)
|
||||
# Publish the agent's response and mark as completed
|
||||
await updater.complete(
|
||||
message=updater.new_agent_message(response_parts),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(
|
||||
state=TaskState.TASK_STATE_FAILED,
|
||||
message=Message(
|
||||
message_id=str(uuid.uuid4()),
|
||||
role=Role.ROLE_AGENT,
|
||||
parts=[Part(text=f"Agent error: {e}")],
|
||||
),
|
||||
),
|
||||
)
|
||||
await updater.update_status(
|
||||
state=TaskState.TASK_STATE_FAILED,
|
||||
message=updater.new_agent_message([Part(text=f"Agent error: {e}")]),
|
||||
)
|
||||
|
||||
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
|
||||
"""Handle cancellation by publishing a canceled status."""
|
||||
task_id = context.task_id or str(uuid.uuid4())
|
||||
context_id = context.context_id or str(uuid.uuid4())
|
||||
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_CANCELED),
|
||||
)
|
||||
)
|
||||
updater = TaskUpdater(event_queue, context.task_id, context.context_id)
|
||||
await updater.update_status(state=TaskState.TASK_STATE_CANCELED)
|
||||
|
||||
@@ -65,7 +65,7 @@ if __name__ == "__main__":
|
||||
server = Starlette(
|
||||
routes=[
|
||||
*create_agent_card_routes(public_agent_card),
|
||||
*create_jsonrpc_routes(request_handler),
|
||||
*create_jsonrpc_routes(request_handler, "/"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user