mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
feat(a2a): add A2AAgentSession with reference_task_ids and input-required support (#5980)
* feat(a2a): link follow-up messages via reference_task_ids Track the task_id from A2A responses (task, status_update, artifact_update, and message payloads) on session.state and include it as reference_task_ids on subsequent outgoing messages. This enables remote agents to correlate follow-up messages as task refinements per the A2A spec. Resolves #5938 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * feat(a2a): add A2AAgentSession for typed protocol state tracking Introduce A2AAgentSession (subclass of AgentSession) with context_id, task_id, and task_state properties. This follows the DurableAgentSession pattern and mirrors the .NET A2AAgentSession design. - Track task_id, context_id, and task_state from all response payload types - Validate context_id consistency (raise on mismatch) - Auto-assign server-generated context_id when not set - Only A2AAgentSession gets reference tracking (no state dict fallback) - Plain AgentSession continues to work without reference tracking - Add serialization support (to_dict/from_dict) - Export via agent_framework.a2a and agent_framework_a2a Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * style: remove unnecessary string annotation (pyupgrade) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: use AgentSession.from_dict for state deserialization Avoids importing private _deserialize_state, matching the DurableAgentSession pattern. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: track context_id from message payloads in A2AAgentSession Previously, context_id was only captured from task, status_update, and artifact_update payloads. Message-only responses (which carry context_id but may lack task_id) were silently lost. 
This fix: - Captures msg.context_id in the message handler - Persists session state when either last_task_id or last_context_id is present (not only when task_id is truthy) - Only updates task_id/task_state when a task_id was actually returned - Adds a test for message-only context_id tracking Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * addressed comments * Gate status content to INPUT_REQUIRED/terminal states (match .NET) Match .NET's GetUserInputRequests pattern: only emit TaskStatusUpdateEvent message content when state is INPUT_REQUIRED or terminal. Intermediate status text (WORKING, SUBMITTED) is no longer surfaced to callers. When state is INPUT_REQUIRED, set additional_properties['input_required'] = True so callers can distinguish input requests from final responses. Closes #5937 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review: remove message task_id tracking, defensive fallbacks, and input_required flag - Do not track task_id from Message payloads (simple interactions without task tracking) - Remove 'or last_task_id' fallback from status_update and artifact_update handlers (spec guarantees task_id is always set) - Remove additional_properties['input_required'] flag (content gating to INPUT_REQUIRED/terminal states is the signal itself) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
371a869e44
commit
efdabd56dc
@@ -3,7 +3,7 @@
|
||||
import importlib.metadata
|
||||
|
||||
from ._a2a_executor import A2AExecutor
|
||||
from ._agent import A2AAgent, A2AContinuationToken
|
||||
from ._agent import A2AAgent, A2AAgentSession, A2AContinuationToken
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
@@ -12,6 +12,7 @@ except importlib.metadata.PackageNotFoundError:
|
||||
|
||||
__all__ = [
|
||||
"A2AAgent",
|
||||
"A2AAgentSession",
|
||||
"A2AContinuationToken",
|
||||
"A2AExecutor",
|
||||
"__version__",
|
||||
|
||||
@@ -43,11 +43,89 @@ from agent_framework._types import AgentRunInputs
|
||||
from agent_framework.observability import AgentTelemetryLayer
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
__all__ = ["A2AAgent", "A2AContinuationToken"]
|
||||
__all__ = ["A2AAgent", "A2AAgentSession", "A2AContinuationToken"]
|
||||
|
||||
from agent_framework_a2a._utils import get_uri_data
|
||||
|
||||
|
||||
class A2AAgentSession(AgentSession):
    """Session for A2A-based agents.

    Extends AgentSession with A2A protocol-specific state: context_id for
    conversation tracking, task_id for the most recent task, and task_state
    for detecting input-required continuations vs. task refinements.

    Attributes:
        context_id: The A2A conversation context identifier.
        task_id: The most recent task ID returned by the remote agent.
        task_state: The state of the most recent task (e.g., completed, input-required).
    """

    # Keys under which the A2A-specific fields are stored in the dict produced
    # by ``to_dict()`` and consumed by ``from_dict()``.
    _CONTEXT_ID_KEY = "a2a_context_id"
    _TASK_ID_KEY = "a2a_task_id"
    _TASK_STATE_KEY = "a2a_task_state"

    def __init__(
        self,
        *,
        context_id: str | None = None,
        task_id: str | None = None,
        task_state: TaskState | None = None,
    ) -> None:
        """Initialize the A2A agent session.

        Keyword Args:
            context_id: Optional A2A context ID for conversation tracking.
            task_id: Optional task ID from a previous interaction.
            task_state: Optional state of the most recent task.
        """
        # The A2A context_id is also passed to the base class as its
        # service_session_id.
        super().__init__(service_session_id=context_id)
        self.context_id: str | None = context_id
        self.task_id: str | None = task_id
        self.task_state: TaskState | None = task_state

    def to_dict(self) -> dict[str, Any]:
        """Serialize session to a plain dict for storage/transfer.

        Returns:
            The base-class dict, extended with the A2A fields that are set.
            Fields whose value is ``None`` are omitted.
        """
        data = super().to_dict()
        if self.context_id is not None:
            data[self._CONTEXT_ID_KEY] = self.context_id
        if self.task_id is not None:
            data[self._TASK_ID_KEY] = self.task_id
        if self.task_state is not None:
            data[self._TASK_STATE_KEY] = self.task_state
        return data

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> A2AAgentSession:
        """Restore session from a previously serialized dict.

        Args:
            data: Dict from a previous ``to_dict()`` call.

        Returns:
            Restored A2AAgentSession instance.
        """
        # Work on a shallow copy so the pops below do not mutate the caller's dict.
        data = dict(data)
        context_id = data.pop(cls._CONTEXT_ID_KEY, None)
        task_id = data.pop(cls._TASK_ID_KEY, None)

        # TaskState is a protobuf enum (int values); store and restore as-is
        task_state: TaskState | None = data.pop(cls._TASK_STATE_KEY, None)

        # Delegate state deserialization to the base class
        base_session = AgentSession.from_dict(data)

        session = cls(
            # Fall back to the base session's service_session_id when no
            # (non-empty) A2A context_id was stored.
            context_id=context_id or base_session.service_session_id,
            task_id=task_id,
            task_state=task_state,
        )
        # __init__ takes no session id, so carry over the restored one along
        # with the restored state entries.
        session._session_id = base_session.session_id
        session.state.update(base_session.state)
        return session
|
||||
|
||||
|
||||
class A2AContinuationToken(ContinuationToken):
|
||||
"""Continuation token for A2A protocol long-running tasks."""
|
||||
|
||||
@@ -314,10 +392,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
else:
|
||||
if not normalized_messages:
|
||||
raise ValueError("At least one message is required when starting a new task (no continuation_token).")
|
||||
a2a_message = self._prepare_message_for_a2a(
|
||||
normalized_messages[-1],
|
||||
context_id=session.service_session_id if session else None,
|
||||
)
|
||||
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1], session=session)
|
||||
request = SendMessageRequest(message=a2a_message)
|
||||
if background and not stream:
|
||||
# return_immediately only applies to non-streaming (message/send)
|
||||
@@ -392,6 +467,9 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
|
||||
all_updates: list[AgentResponseUpdate] = []
|
||||
streamed_artifact_ids_by_task: dict[str, set[str]] = {}
|
||||
last_task_id: str | None = None
|
||||
last_context_id: str | None = None
|
||||
last_task_state: TaskState | None = None
|
||||
# In non-streaming mode, accumulate intermediate status content so it
|
||||
# can be surfaced when the terminal event arrives (mirroring v0.3.x
|
||||
# behavior where the full Task history was available at completion).
|
||||
@@ -401,6 +479,8 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
if payload_type == "message":
|
||||
# Process A2A Message
|
||||
msg = item.message
|
||||
if msg.context_id:
|
||||
last_context_id = msg.context_id
|
||||
contents = self._parse_contents_from_a2a(msg.parts)
|
||||
metadata = MessageToDict(msg.metadata) if msg.metadata else None
|
||||
update = AgentResponseUpdate(
|
||||
@@ -414,6 +494,10 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
yield update
|
||||
elif payload_type == "task":
|
||||
task = item.task
|
||||
last_task_id = task.id
|
||||
if task.context_id:
|
||||
last_context_id = task.context_id
|
||||
last_task_state = task.status.state
|
||||
updates = self._updates_from_task(
|
||||
task,
|
||||
background=background,
|
||||
@@ -435,20 +519,25 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
yield update
|
||||
elif payload_type == "status_update":
|
||||
status_event = item.status_update
|
||||
last_task_id = status_event.task_id
|
||||
if status_event.context_id:
|
||||
last_context_id = status_event.context_id
|
||||
last_task_state = status_event.status.state
|
||||
updates = self._updates_from_task_update_event(status_event)
|
||||
is_terminal = status_event.status.state in TERMINAL_TASK_STATES
|
||||
is_input_required = status_event.status.state == TaskState.TASK_STATE_INPUT_REQUIRED
|
||||
if emit_intermediate:
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
elif is_terminal:
|
||||
elif is_terminal or is_input_required:
|
||||
if updates:
|
||||
# Terminal event with content — discard accumulated intermediates
|
||||
# Terminal/input-required event with content — discard accumulated intermediates
|
||||
pending_updates_by_task.pop(status_event.task_id, None)
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
else:
|
||||
elif is_terminal:
|
||||
# Terminal event with NO content — flush accumulated updates
|
||||
pending = pending_updates_by_task.pop(status_event.task_id, [])
|
||||
for update in pending:
|
||||
@@ -460,6 +549,9 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
pending_updates_by_task.setdefault(status_event.task_id, []).extend(updates)
|
||||
elif payload_type == "artifact_update":
|
||||
artifact_event = item.artifact_update
|
||||
last_task_id = artifact_event.task_id
|
||||
if artifact_event.context_id:
|
||||
last_context_id = artifact_event.context_id
|
||||
updates = self._updates_from_task_update_event(artifact_event)
|
||||
# Always yield artifact updates — they carry actual response
|
||||
# content (files, data). Track IDs so that a subsequent
|
||||
@@ -478,6 +570,22 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
if all_updates:
|
||||
session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment]
|
||||
|
||||
# Persist A2A protocol state on the session for follow-up message linking.
|
||||
if isinstance(session, A2AAgentSession) and (last_task_id or last_context_id):
|
||||
# Validate context_id consistency
|
||||
if session.context_id is not None and last_context_id and session.context_id != last_context_id:
|
||||
raise RuntimeError(
|
||||
f"The context_id returned from the A2A agent ('{last_context_id}') "
|
||||
f"differs from the session's context_id ('{session.context_id}')."
|
||||
)
|
||||
# Assign server-generated context_id if not already set
|
||||
if session.context_id is None and last_context_id:
|
||||
session.context_id = last_context_id
|
||||
session.service_session_id = last_context_id
|
||||
if last_task_id:
|
||||
session.task_id = last_task_id
|
||||
session.task_state = last_task_state
|
||||
|
||||
await self._run_after_providers(session=session, context=session_context)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -601,6 +709,10 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
if not update_event.status.HasField("message") or not update_event.status.message.parts:
|
||||
return []
|
||||
|
||||
state = update_event.status.state
|
||||
if state not in TERMINAL_TASK_STATES and state != TaskState.TASK_STATE_INPUT_REQUIRED:
|
||||
return []
|
||||
|
||||
message = update_event.status.message
|
||||
contents = self._parse_contents_from_a2a(message.parts)
|
||||
if not contents:
|
||||
@@ -609,6 +721,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
msg_meta = MessageToDict(message.metadata) if message.metadata else {}
|
||||
event_meta = MessageToDict(update_event.metadata) if update_event.metadata else {}
|
||||
merged_metadata = {**msg_meta, **event_meta} or None
|
||||
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=contents,
|
||||
@@ -647,7 +760,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
return AgentResponse.from_updates(updates)
|
||||
return AgentResponse(messages=[], response_id=task.id, raw_representation=task)
|
||||
|
||||
def _prepare_message_for_a2a(self, message: Message, *, context_id: str | None = None) -> A2AMessage:
|
||||
def _prepare_message_for_a2a(self, message: Message, *, session: AgentSession | None = None) -> A2AMessage:
|
||||
"""Prepare a Message for the A2A protocol.
|
||||
|
||||
Transforms Agent Framework Message objects into A2A protocol Messages by:
|
||||
@@ -656,14 +769,33 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
- Converting file references (URI/data/hosted_file) to FilePart objects
|
||||
- Preserving metadata and additional properties from the original message
|
||||
- Setting the role to 'user' as framework messages are treated as user input
|
||||
- Linking follow-up messages to previous tasks via reference_task_ids or task_id
|
||||
|
||||
When the session is an ``A2AAgentSession``, the method reads context_id,
|
||||
task_id, and task_state directly. If the task is in INPUT_REQUIRED state,
|
||||
the outbound message's ``task_id`` is set (continuing the same task);
|
||||
otherwise ``reference_task_ids`` is used for task refinement linking.
|
||||
|
||||
Args:
|
||||
message: The framework Message to convert.
|
||||
context_id: Optional fallback context identifier (e.g. derived from
|
||||
``AgentSession.service_session_id``). When the *message* already
|
||||
carries a ``context_id`` in its ``additional_properties`` that
|
||||
value takes precedence; otherwise this fallback is used.
|
||||
|
||||
Keyword Args:
|
||||
session: Optional session to read A2A state from. If an
|
||||
``A2AAgentSession``, context_id/task_id/task_state are used for
|
||||
linking. A plain ``AgentSession`` provides service_session_id as
|
||||
a fallback context_id.
|
||||
"""
|
||||
# Extract A2A state from the session
|
||||
context_id: str | None = None
|
||||
previous_task_id: str | None = None
|
||||
task_state: TaskState | None = None
|
||||
if isinstance(session, A2AAgentSession):
|
||||
context_id = session.context_id
|
||||
previous_task_id = session.task_id
|
||||
task_state = session.task_state
|
||||
elif session is not None:
|
||||
context_id = session.service_session_id
|
||||
|
||||
parts: list[A2APart] = []
|
||||
if not message.contents:
|
||||
raise ValueError("Message.contents is empty; cannot convert to A2AMessage.")
|
||||
@@ -722,14 +854,24 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
|
||||
a2a_metadata = message.additional_properties.get("a2a_metadata")
|
||||
|
||||
return A2AMessage(
|
||||
a2a_message = A2AMessage(
|
||||
role=A2ARole.ROLE_USER,
|
||||
parts=parts,
|
||||
message_id=message.message_id or uuid.uuid4().hex,
|
||||
context_id=message.additional_properties.get("context_id") or context_id,
|
||||
context_id=context_id,
|
||||
metadata=a2a_metadata or {},
|
||||
)
|
||||
|
||||
if previous_task_id:
|
||||
if task_state == TaskState.TASK_STATE_INPUT_REQUIRED:
|
||||
# Task is waiting for user input — set task_id to continue the same task
|
||||
a2a_message.task_id = previous_task_id
|
||||
else:
|
||||
# Link as a follow-up (task refinement)
|
||||
a2a_message.reference_task_ids.append(previous_task_id)
|
||||
|
||||
return a2a_message
|
||||
|
||||
def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Content]:
|
||||
"""Parse A2A Parts into Agent Framework Content.
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ from agent_framework import (
|
||||
from agent_framework.a2a import A2AAgent
|
||||
from pytest import fixture, mark, raises
|
||||
|
||||
from agent_framework_a2a import A2AContinuationToken
|
||||
from agent_framework_a2a import A2AAgentSession, A2AContinuationToken
|
||||
from agent_framework_a2a._utils import get_uri_data
|
||||
|
||||
|
||||
@@ -482,24 +482,25 @@ def test_prepare_message_for_a2a_with_multiple_contents() -> None:
|
||||
|
||||
|
||||
def test_prepare_message_for_a2a_forwards_context_id() -> None:
|
||||
"""Test conversion of Message preserves context_id without duplicating it in metadata."""
|
||||
"""Test conversion of Message uses context_id from A2AAgentSession."""
|
||||
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
message = Message(
|
||||
role="user",
|
||||
contents=[Content.from_text(text="Continue the task")],
|
||||
additional_properties={"context_id": "ctx-123", "a2a_metadata": {"trace_id": "trace-456"}},
|
||||
additional_properties={"a2a_metadata": {"trace_id": "trace-456"}},
|
||||
)
|
||||
|
||||
result = agent._prepare_message_for_a2a(message)
|
||||
session = A2AAgentSession(context_id="ctx-123")
|
||||
result = agent._prepare_message_for_a2a(message, session=session)
|
||||
|
||||
assert result.context_id == "ctx-123"
|
||||
assert result.metadata == {"trace_id": "trace-456"}
|
||||
|
||||
|
||||
def test_prepare_message_for_a2a_uses_fallback_context_id() -> None:
|
||||
"""Test that context_id kwarg is used when message has no context_id property."""
|
||||
"""Test that service_session_id from a plain session is used when message has no context_id property."""
|
||||
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
@@ -508,25 +509,26 @@ def test_prepare_message_for_a2a_uses_fallback_context_id() -> None:
|
||||
contents=[Content.from_text(text="Hello")],
|
||||
)
|
||||
|
||||
result = agent._prepare_message_for_a2a(message, context_id="session-ctx-1")
|
||||
session = AgentSession(service_session_id="session-ctx-1")
|
||||
result = agent._prepare_message_for_a2a(message, session=session)
|
||||
|
||||
assert result.context_id == "session-ctx-1"
|
||||
|
||||
|
||||
def test_prepare_message_for_a2a_message_context_id_takes_precedence() -> None:
|
||||
"""Test that message.additional_properties context_id wins over the fallback."""
|
||||
def test_prepare_message_for_a2a_a2a_session_context_id_takes_precedence() -> None:
|
||||
"""Test that A2AAgentSession.context_id is used over plain session service_session_id."""
|
||||
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
message = Message(
|
||||
role="user",
|
||||
contents=[Content.from_text(text="Hello")],
|
||||
additional_properties={"context_id": "explicit-ctx"},
|
||||
)
|
||||
|
||||
result = agent._prepare_message_for_a2a(message, context_id="session-ctx-1")
|
||||
session = A2AAgentSession(context_id="a2a-ctx")
|
||||
result = agent._prepare_message_for_a2a(message, session=session)
|
||||
|
||||
assert result.context_id == "explicit-ctx"
|
||||
assert result.context_id == "a2a-ctx"
|
||||
|
||||
|
||||
def test_parse_contents_from_a2a_with_data_part() -> None:
|
||||
@@ -963,21 +965,16 @@ async def test_run_passes_session_service_session_id_as_context_id(mock_a2a_clie
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_run_message_context_id_takes_precedence_over_session(mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that an explicit context_id on the message wins over session.service_session_id."""
|
||||
async def test_run_a2a_session_context_id_used_over_service_session_id(mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that A2AAgentSession.context_id is used for outbound messages."""
|
||||
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
|
||||
mock_a2a_client.add_message_response("msg-ctx2", "reply")
|
||||
|
||||
session = AgentSession(service_session_id="svc-session-42")
|
||||
message = Message(
|
||||
role="user",
|
||||
contents=[Content.from_text(text="Hello")],
|
||||
additional_properties={"context_id": "explicit-ctx"},
|
||||
)
|
||||
await agent.run(messages=[message], session=session)
|
||||
session = A2AAgentSession(context_id="a2a-ctx-99")
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
assert mock_a2a_client.last_message is not None
|
||||
assert mock_a2a_client.last_message.context_id == "explicit-ctx"
|
||||
assert mock_a2a_client.last_message.context_id == "a2a-ctx-99"
|
||||
|
||||
|
||||
# endregion
|
||||
@@ -1332,16 +1329,17 @@ async def test_streaming_artifact_update_event_yields_content(
|
||||
async def test_streaming_status_update_event_yields_content(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Test that streaming status update events surface message content directly from the update event."""
|
||||
"""Test that streaming status update events surface content for terminal/input-required states only."""
|
||||
# COMPLETED state should yield content (terminal)
|
||||
update_event = TaskStatusUpdateEvent(
|
||||
task_id="task-status",
|
||||
context_id="ctx-status",
|
||||
status=TaskStatus(
|
||||
state=TaskState.TASK_STATE_WORKING,
|
||||
state=TaskState.TASK_STATE_COMPLETED,
|
||||
message=A2AMessage(
|
||||
message_id=str(uuid4()),
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="Still working")],
|
||||
parts=[Part(text="Done")],
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -1352,11 +1350,64 @@ async def test_streaming_status_update_event_yields_content(
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 1
|
||||
assert updates[0].text == "Still working"
|
||||
assert updates[0].text == "Done"
|
||||
assert updates[0].role == "assistant"
|
||||
assert updates[0].raw_representation == update_event
|
||||
|
||||
|
||||
@mark.asyncio
async def test_streaming_input_required_emits_content(
    a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
    """Streaming an INPUT_REQUIRED status update yields the status message text."""
    # Build the event from the inside out: message -> status -> update event.
    question = A2AMessage(
        message_id=str(uuid4()),
        role=A2ARole.ROLE_AGENT,
        parts=[Part(text="What is your name?")],
    )
    input_required_status = TaskStatus(
        state=TaskState.TASK_STATE_INPUT_REQUIRED,
        message=question,
    )
    update_event = TaskStatusUpdateEvent(
        task_id="task-status",
        context_id="ctx-status",
        status=input_required_status,
    )
    mock_a2a_client.responses.append(StreamResponse(status_update=update_event))

    # Drain the stream into a list.
    updates: list[AgentResponseUpdate] = [item async for item in a2a_agent.run("Hello", stream=True)]

    assert len(updates) == 1
    assert updates[0].text == "What is your name?"
|
||||
|
||||
|
||||
@mark.asyncio
async def test_streaming_working_status_gates_content(
    a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
    """Streaming a WORKING status update with message text yields no updates."""
    # Build the event from the inside out: message -> status -> update event.
    progress_note = A2AMessage(
        message_id=str(uuid4()),
        role=A2ARole.ROLE_AGENT,
        parts=[Part(text="Processing...")],
    )
    working_status = TaskStatus(
        state=TaskState.TASK_STATE_WORKING,
        message=progress_note,
    )
    update_event = TaskStatusUpdateEvent(
        task_id="task-status",
        context_id="ctx-status",
        status=working_status,
    )
    mock_a2a_client.responses.append(StreamResponse(status_update=update_event))

    # Drain the stream into a list.
    updates: list[AgentResponseUpdate] = [item async for item in a2a_agent.run("Hello", stream=True)]

    assert len(updates) == 0
|
||||
|
||||
|
||||
async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_artifacts(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
@@ -1576,28 +1627,17 @@ async def test_task_status_update_event_metadata_merged(a2a_agent: A2AAgent, moc
|
||||
task_id="task-se",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(
|
||||
state=TaskState.TASK_STATE_WORKING,
|
||||
state=TaskState.TASK_STATE_INPUT_REQUIRED,
|
||||
message=A2AMessage(
|
||||
message_id="m1",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="working...")],
|
||||
parts=[Part(text="need input")],
|
||||
metadata={"msg_key": "msg_val"},
|
||||
),
|
||||
),
|
||||
metadata={"event_key": "event_val"},
|
||||
)
|
||||
terminal_task = Task(
|
||||
id="task-se",
|
||||
context_id="ctx",
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
artifacts=[
|
||||
Artifact(artifact_id="a1", parts=[Part(text="done")]),
|
||||
],
|
||||
)
|
||||
mock_a2a_client.responses.extend([
|
||||
StreamResponse(status_update=status_event),
|
||||
StreamResponse(task=terminal_task),
|
||||
])
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=status_event))
|
||||
|
||||
stream = a2a_agent.run("hello", stream=True)
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
@@ -1681,11 +1721,11 @@ async def test_non_streaming_terminal_status_update_surfaces_content(
|
||||
assert response.messages[0].text == "Done! Here is your answer."
|
||||
|
||||
|
||||
async def test_non_streaming_accumulates_working_content_for_empty_terminal(
|
||||
async def test_non_streaming_working_content_gated(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Non-streaming run() accumulates WORKING content and flushes on empty terminal event."""
|
||||
# Intermediate WORKING event with content
|
||||
"""Non-streaming: WORKING status content is gated and not surfaced to callers."""
|
||||
# Intermediate WORKING event with content — should be gated
|
||||
working_msg = A2AMessage(
|
||||
message_id="msg-working",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
@@ -1702,9 +1742,8 @@ async def test_non_streaming_accumulates_working_content_for_empty_terminal(
|
||||
|
||||
response = await a2a_agent.run("Hello")
|
||||
|
||||
# The accumulated WORKING content is flushed when terminal arrives empty
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Here is your answer from working state."
|
||||
# WORKING content is gated — nothing to accumulate or flush
|
||||
assert len(response.messages) == 0
|
||||
|
||||
|
||||
async def test_non_streaming_intermediate_discarded_when_terminal_has_content(
|
||||
@@ -1761,3 +1800,268 @@ async def test_non_streaming_artifact_update_surfaces_content(
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Reference Task IDs Tests
|
||||
|
||||
|
||||
@mark.asyncio
async def test_first_message_has_no_reference_task_ids(mock_a2a_client: MockA2AClient) -> None:
    """The first message sent on a fresh A2AAgentSession carries no reference_task_ids."""
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
    mock_a2a_client.add_task_response("task-first", [{"content": "Hello back"}])

    fresh_session = A2AAgentSession()
    await agent.run("Hello", session=fresh_session)

    # Inspect the outbound message the mock client recorded.
    sent = mock_a2a_client.last_message
    assert sent is not None
    assert list(sent.reference_task_ids) == []
|
||||
|
||||
|
||||
@mark.asyncio
async def test_follow_up_message_includes_reference_task_ids(mock_a2a_client: MockA2AClient) -> None:
    """A second run on the same session references the task_id of the first run."""
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
    shared_session = A2AAgentSession()

    # First turn: the returned task_id must be stored on the session.
    mock_a2a_client.add_task_response("task-abc-123", [{"content": "First reply"}])
    await agent.run("Hello", session=shared_session)
    assert shared_session.task_id == "task-abc-123"

    # Second turn: the outbound message must reference the first task.
    mock_a2a_client.add_task_response("task-def-456", [{"content": "Second reply"}])
    await agent.run("Follow up", session=shared_session)

    sent = mock_a2a_client.last_message
    assert sent is not None
    assert list(sent.reference_task_ids) == ["task-abc-123"]
|
||||
|
||||
|
||||
@mark.asyncio
async def test_reference_task_ids_updated_after_each_interaction(mock_a2a_client: MockA2AClient) -> None:
    """Test that reference_task_ids always points to the most recent task.

    Runs three turns on one session. After each turn the session's task_id
    must be the task just returned, and from the second turn on the outbound
    message must reference only the task from the previous turn.
    """
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)

    session = A2AAgentSession()

    # First interaction
    mock_a2a_client.add_task_response("task-1", [{"content": "Reply 1"}])
    await agent.run("Message 1", session=session)
    assert session.task_id == "task-1"

    # Second interaction references the first task
    mock_a2a_client.add_task_response("task-2", [{"content": "Reply 2"}])
    await agent.run("Message 2", session=session)
    # Guard against a missing message and compare as a plain list, matching
    # the other reference_task_ids tests in this region.
    assert mock_a2a_client.last_message is not None
    assert list(mock_a2a_client.last_message.reference_task_ids) == ["task-1"]
    assert session.task_id == "task-2"

    # Third interaction references the second task
    mock_a2a_client.add_task_response("task-3", [{"content": "Reply 3"}])
    await agent.run("Message 3", session=session)
    assert mock_a2a_client.last_message is not None
    assert list(mock_a2a_client.last_message.reference_task_ids) == ["task-2"]
    assert session.task_id == "task-3"
|
||||
|
||||
|
||||
@mark.asyncio
async def test_task_id_tracked_from_status_update_events(mock_a2a_client: MockA2AClient) -> None:
    """task_id and task_state are recorded when the only response is a status update event."""
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)

    # The mock returns a single status_update event and no full task payload.
    done_message = A2AMessage(
        message_id="msg-status",
        role=A2ARole.ROLE_AGENT,
        parts=[Part(text="Done")],
    )
    completed_status = TaskStatus(
        state=TaskState.TASK_STATE_COMPLETED,
        message=done_message,
    )
    status_event = TaskStatusUpdateEvent(
        task_id="task-from-status",
        context_id="ctx-1",
        status=completed_status,
    )
    mock_a2a_client.responses.append(StreamResponse(status_update=status_event))

    tracked_session = A2AAgentSession()
    await agent.run("Hello", session=tracked_session)

    assert tracked_session.task_id == "task-from-status"
    assert tracked_session.task_state == TaskState.TASK_STATE_COMPLETED
|
||||
|
||||
|
||||
@mark.asyncio
async def test_no_session_does_not_crash_reference_task_ids(mock_a2a_client: MockA2AClient) -> None:
    """Test that running without a session (no reference tracking) works fine.

    With no session passed to ``run``, the call must succeed and the outbound
    message must carry an empty reference_task_ids list.
    """
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
    mock_a2a_client.add_task_response("task-no-session", [{"content": "Reply"}])

    # Should not raise — no session means no reference_task_ids
    response = await agent.run("Hello")
    assert response is not None
    # Guard against a missing message and compare as a plain list, matching
    # the other reference_task_ids tests in this region.
    assert mock_a2a_client.last_message is not None
    assert list(mock_a2a_client.last_message.reference_task_ids) == []
|
||||
|
||||
|
||||
@mark.asyncio
async def test_task_id_not_tracked_from_message_payload(mock_a2a_client: MockA2AClient) -> None:
    """Verify that a task_id carried on a message-only response is not stored on the session."""
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)

    # The only queued response is a Message that happens to carry a task_id; there are no
    # task or status_update events. The assertion below expects that id to be ignored.
    reply = A2AMessage(
        message_id="msg-with-task",
        role=A2ARole.ROLE_AGENT,
        parts=[Part(text="Response")],
        task_id="task-from-message",
    )
    mock_a2a_client.responses.append(StreamResponse(message=reply))

    a2a_session = A2AAgentSession()
    await agent.run("Hello", session=a2a_session)

    assert a2a_session.task_id is None
|
||||
|
||||
|
||||
@mark.asyncio
async def test_context_id_assigned_from_response(mock_a2a_client: MockA2AClient) -> None:
    """Verify that a session with no context_id adopts the one returned in the response."""
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
    mock_a2a_client.add_task_response("task-ctx", [{"content": "Reply"}])

    fresh_session = A2AAgentSession()
    # Precondition: a newly created session starts without a context_id.
    assert fresh_session.context_id is None

    await agent.run("Hello", session=fresh_session)

    # After the run, context_id and service_session_id both hold the response's context.
    expected_context = "test-context"
    assert fresh_session.context_id == expected_context
    assert fresh_session.service_session_id == expected_context
|
||||
|
||||
|
||||
@mark.asyncio
async def test_context_id_tracked_from_message_payload(mock_a2a_client: MockA2AClient) -> None:
    """Verify that context_id is picked up from a response containing only a message."""
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)

    # Queue a single Message that sets context_id and leaves task_id unset.
    context_only_reply = A2AMessage(
        message_id="msg-ctx-only",
        role=A2ARole.ROLE_AGENT,
        parts=[Part(text="Hello!")],
        context_id="server-ctx-123",
    )
    mock_a2a_client.responses.append(StreamResponse(message=context_only_reply))

    a2a_session = A2AAgentSession()
    await agent.run("Hi", session=a2a_session)

    # The context is recorded on the session while task_id stays unset.
    server_context = "server-ctx-123"
    assert a2a_session.context_id == server_context
    assert a2a_session.service_session_id == server_context
    assert a2a_session.task_id is None
|
||||
|
||||
|
||||
@mark.asyncio
async def test_context_id_mismatch_raises_error(mock_a2a_client: MockA2AClient) -> None:
    """Verify that a response context_id conflicting with the session's raises RuntimeError."""
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)

    # The add_task_response helper is expected to stamp context_id="test-context" on the task.
    mock_a2a_client.add_task_response("task-mismatch", [{"content": "Reply"}])

    # Pre-seed the session with a context that cannot match the queued response.
    conflicting_session = A2AAgentSession(context_id="different-context")

    with raises(RuntimeError, match="differs from the session's context_id"):
        await agent.run("Hello", session=conflicting_session)
|
||||
|
||||
|
||||
@mark.asyncio
async def test_task_state_tracked_on_session(mock_a2a_client: MockA2AClient) -> None:
    """Verify that A2AAgentSession exposes the task's id and state after a run."""
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)

    # Queue a task whose reported state is INPUT_REQUIRED.
    awaiting_input = TaskState.TASK_STATE_INPUT_REQUIRED
    mock_a2a_client.add_in_progress_task_response(
        "task-input",
        context_id="ctx-input",
        state=awaiting_input,
        text="What is your name?",
    )

    a2a_session = A2AAgentSession()
    await agent.run("Start", session=a2a_session)

    assert a2a_session.task_id == "task-input"
    assert a2a_session.task_state == awaiting_input
|
||||
|
||||
|
||||
@mark.asyncio
async def test_plain_agent_session_no_reference_tracking(mock_a2a_client: MockA2AClient) -> None:
    """Verify that a base AgentSession runs fine but never produces reference_task_ids."""
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
    mock_a2a_client.add_task_response("task-plain", [{"content": "Reply"}])

    plain_session = AgentSession()
    await agent.run("Hello", session=plain_session)

    # Nothing task-related is written into the plain session's state dict.
    assert "a2a_task_id" not in plain_session.state

    # A second turn on the same session goes out with no task references attached.
    mock_a2a_client.add_task_response("task-plain-2", [{"content": "Reply 2"}])
    await agent.run("Follow up", session=plain_session)

    sent_references = list(mock_a2a_client.last_message.reference_task_ids)
    assert sent_references == []
|
||||
|
||||
|
||||
@mark.asyncio
async def test_a2a_agent_session_serialization() -> None:
    """Verify that A2AAgentSession survives a to_dict / from_dict round trip."""
    original = A2AAgentSession(
        context_id="ctx-456",
        task_id="task-789",
        task_state=TaskState.TASK_STATE_COMPLETED,
    )

    # Round trip: serialize to a dict, then rebuild a session from that dict.
    payload = original.to_dict()
    roundtripped = A2AAgentSession.from_dict(payload)

    # The session identity and every A2A-specific field must be preserved.
    assert roundtripped.session_id == original.session_id
    assert roundtripped.context_id == "ctx-456"
    assert roundtripped.task_id == "task-789"
    assert roundtripped.task_state == TaskState.TASK_STATE_COMPLETED
|
||||
|
||||
|
||||
@mark.asyncio
async def test_input_required_sets_task_id_instead_of_reference(mock_a2a_client: MockA2AClient) -> None:
    """Verify the follow-up to an INPUT_REQUIRED task carries task_id, not reference_task_ids."""
    agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)

    # Turn one: the remote task stops in INPUT_REQUIRED, asking the user a question.
    mock_a2a_client.add_in_progress_task_response(
        "task-ir",
        context_id="ctx-ir",
        state=TaskState.TASK_STATE_INPUT_REQUIRED,
        text="What is your name?",
    )

    a2a_session = A2AAgentSession()
    await agent.run("Start", session=a2a_session)

    assert a2a_session.task_state == TaskState.TASK_STATE_INPUT_REQUIRED
    assert a2a_session.task_id == "task-ir"

    # Turn two: answer the question on the same session; the mock replies with a completed task.
    mock_a2a_client.add_in_progress_task_response(
        "task-ir-2",
        context_id="ctx-ir",
        state=TaskState.TASK_STATE_COMPLETED,
        text="Thanks!",
    )
    await agent.run("My name is Alice", session=a2a_session)

    # The outgoing message targets the pending task directly and lists no references.
    outbound = mock_a2a_client.last_message
    assert outbound.task_id == "task-ir"
    assert list(outbound.reference_task_ids) == []
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user