feat(a2a): add A2AAgentSession with reference_task_ids and input-required support (#5980)

* feat(a2a): link follow-up messages via reference_task_ids

Track the task_id from A2A responses (task, status_update, artifact_update,
and message payloads) on session.state and include it as reference_task_ids
on subsequent outgoing messages. This enables remote agents to correlate
follow-up messages as task refinements per the A2A spec.

Resolves #5938

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* feat(a2a): add A2AAgentSession for typed protocol state tracking

Introduce A2AAgentSession (subclass of AgentSession) with context_id,
task_id, and task_state properties. This follows the DurableAgentSession
pattern and mirrors the .NET A2AAgentSession design.

- Track task_id, context_id, and task_state from all response payload types
- Validate context_id consistency (raise on mismatch)
- Auto-assign server-generated context_id when not set
- Only A2AAgentSession gets reference tracking (no state dict fallback)
- Plain AgentSession continues to work without reference tracking
- Add serialization support (to_dict/from_dict)
- Export via agent_framework.a2a and agent_framework_a2a

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* style: remove unnecessary string annotation (pyupgrade)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: use AgentSession.from_dict for state deserialization

Avoids importing private _deserialize_state, matching the
DurableAgentSession pattern.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: track context_id from message payloads in A2AAgentSession

Previously, context_id was only captured from task, status_update, and
artifact_update payloads. Message-only responses (which carry context_id
but may lack task_id) were silently lost. This fix:

- Captures msg.context_id in the message handler
- Persists session state when either last_task_id or last_context_id is
  present (not only when task_id is truthy)
- Only updates task_id/task_state when a task_id was actually returned
- Adds a test for message-only context_id tracking

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* addressed comments

* Gate status content to INPUT_REQUIRED/terminal states (match .NET)

Match .NET's GetUserInputRequests pattern: only emit TaskStatusUpdateEvent
message content when state is INPUT_REQUIRED or terminal. Intermediate
status text (WORKING, SUBMITTED) is no longer surfaced to callers.

When state is INPUT_REQUIRED, set additional_properties['input_required']
= True so callers can distinguish input requests from final responses.

Closes #5937

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review: remove message task_id tracking, defensive fallbacks, and input_required flag

- Do not track task_id from Message payloads (simple interactions
  without task tracking)
- Remove 'or last_task_id' fallback from status_update and
  artifact_update handlers (spec guarantees task_id is always set)
- Remove additional_properties['input_required'] flag (content gating
  to INPUT_REQUIRED/terminal states is the signal itself)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Giles Odigwe
2026-05-28 01:36:49 -07:00
committed by GitHub
Unverified
parent 371a869e44
commit efdabd56dc
5 changed files with 512 additions and 64 deletions
@@ -3,7 +3,7 @@
import importlib.metadata
from ._a2a_executor import A2AExecutor
from ._agent import A2AAgent, A2AContinuationToken
from ._agent import A2AAgent, A2AAgentSession, A2AContinuationToken
try:
__version__ = importlib.metadata.version(__name__)
@@ -12,6 +12,7 @@ except importlib.metadata.PackageNotFoundError:
__all__ = [
"A2AAgent",
"A2AAgentSession",
"A2AContinuationToken",
"A2AExecutor",
"__version__",
+157 -15
View File
@@ -43,11 +43,89 @@ from agent_framework._types import AgentRunInputs
from agent_framework.observability import AgentTelemetryLayer
from google.protobuf.json_format import MessageToDict
__all__ = ["A2AAgent", "A2AContinuationToken"]
__all__ = ["A2AAgent", "A2AAgentSession", "A2AContinuationToken"]
from agent_framework_a2a._utils import get_uri_data
class A2AAgentSession(AgentSession):
"""Session for A2A-based agents.
Extends AgentSession with A2A protocol-specific state: context_id for
conversation tracking, task_id for the most recent task, and task_state
for detecting input-required continuations vs. task refinements.
Attributes:
context_id: The A2A conversation context identifier.
task_id: The most recent task ID returned by the remote agent.
task_state: The state of the most recent task (e.g., completed, input-required).
"""
_CONTEXT_ID_KEY = "a2a_context_id"
_TASK_ID_KEY = "a2a_task_id"
_TASK_STATE_KEY = "a2a_task_state"
def __init__(
self,
*,
context_id: str | None = None,
task_id: str | None = None,
task_state: TaskState | None = None,
) -> None:
"""Initialize the A2A agent session.
Keyword Args:
context_id: Optional A2A context ID for conversation tracking.
task_id: Optional task ID from a previous interaction.
task_state: Optional state of the most recent task.
"""
super().__init__(service_session_id=context_id)
self.context_id: str | None = context_id
self.task_id: str | None = task_id
self.task_state: TaskState | None = task_state
def to_dict(self) -> dict[str, Any]:
"""Serialize session to a plain dict for storage/transfer."""
data = super().to_dict()
if self.context_id is not None:
data[self._CONTEXT_ID_KEY] = self.context_id
if self.task_id is not None:
data[self._TASK_ID_KEY] = self.task_id
if self.task_state is not None:
data[self._TASK_STATE_KEY] = self.task_state
return data
@classmethod
def from_dict(cls, data: dict[str, Any]) -> A2AAgentSession:
"""Restore session from a previously serialized dict.
Args:
data: Dict from a previous ``to_dict()`` call.
Returns:
Restored A2AAgentSession instance.
"""
data = dict(data) # defensive copy
context_id = data.pop(cls._CONTEXT_ID_KEY, None)
task_id = data.pop(cls._TASK_ID_KEY, None)
task_state_value = data.pop(cls._TASK_STATE_KEY, None)
# TaskState is a protobuf enum (int values); store and restore as-is
task_state: TaskState | None = task_state_value if task_state_value is not None else None
# Delegate state deserialization to the base class
base_session = AgentSession.from_dict(data)
session = cls(
context_id=context_id or base_session.service_session_id,
task_id=task_id,
task_state=task_state,
)
session._session_id = base_session.session_id
session.state.update(base_session.state)
return session
class A2AContinuationToken(ContinuationToken):
"""Continuation token for A2A protocol long-running tasks."""
@@ -314,10 +392,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
else:
if not normalized_messages:
raise ValueError("At least one message is required when starting a new task (no continuation_token).")
a2a_message = self._prepare_message_for_a2a(
normalized_messages[-1],
context_id=session.service_session_id if session else None,
)
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1], session=session)
request = SendMessageRequest(message=a2a_message)
if background and not stream:
# return_immediately only applies to non-streaming (message/send)
@@ -392,6 +467,9 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
all_updates: list[AgentResponseUpdate] = []
streamed_artifact_ids_by_task: dict[str, set[str]] = {}
last_task_id: str | None = None
last_context_id: str | None = None
last_task_state: TaskState | None = None
# In non-streaming mode, accumulate intermediate status content so it
# can be surfaced when the terminal event arrives (mirroring v0.3.x
# behavior where the full Task history was available at completion).
@@ -401,6 +479,8 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
if payload_type == "message":
# Process A2A Message
msg = item.message
if msg.context_id:
last_context_id = msg.context_id
contents = self._parse_contents_from_a2a(msg.parts)
metadata = MessageToDict(msg.metadata) if msg.metadata else None
update = AgentResponseUpdate(
@@ -414,6 +494,10 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
yield update
elif payload_type == "task":
task = item.task
last_task_id = task.id
if task.context_id:
last_context_id = task.context_id
last_task_state = task.status.state
updates = self._updates_from_task(
task,
background=background,
@@ -435,20 +519,25 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
yield update
elif payload_type == "status_update":
status_event = item.status_update
last_task_id = status_event.task_id
if status_event.context_id:
last_context_id = status_event.context_id
last_task_state = status_event.status.state
updates = self._updates_from_task_update_event(status_event)
is_terminal = status_event.status.state in TERMINAL_TASK_STATES
is_input_required = status_event.status.state == TaskState.TASK_STATE_INPUT_REQUIRED
if emit_intermediate:
for update in updates:
all_updates.append(update)
yield update
elif is_terminal:
elif is_terminal or is_input_required:
if updates:
# Terminal event with content — discard accumulated intermediates
# Terminal/input-required event with content — discard accumulated intermediates
pending_updates_by_task.pop(status_event.task_id, None)
for update in updates:
all_updates.append(update)
yield update
else:
elif is_terminal:
# Terminal event with NO content — flush accumulated updates
pending = pending_updates_by_task.pop(status_event.task_id, [])
for update in pending:
@@ -460,6 +549,9 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
pending_updates_by_task.setdefault(status_event.task_id, []).extend(updates)
elif payload_type == "artifact_update":
artifact_event = item.artifact_update
last_task_id = artifact_event.task_id
if artifact_event.context_id:
last_context_id = artifact_event.context_id
updates = self._updates_from_task_update_event(artifact_event)
# Always yield artifact updates — they carry actual response
# content (files, data). Track IDs so that a subsequent
@@ -478,6 +570,22 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
if all_updates:
session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment]
# Persist A2A protocol state on the session for follow-up message linking.
if isinstance(session, A2AAgentSession) and (last_task_id or last_context_id):
# Validate context_id consistency
if session.context_id is not None and last_context_id and session.context_id != last_context_id:
raise RuntimeError(
f"The context_id returned from the A2A agent ('{last_context_id}') "
f"differs from the session's context_id ('{session.context_id}')."
)
# Assign server-generated context_id if not already set
if session.context_id is None and last_context_id:
session.context_id = last_context_id
session.service_session_id = last_context_id
if last_task_id:
session.task_id = last_task_id
session.task_state = last_task_state
await self._run_after_providers(session=session, context=session_context)
# ------------------------------------------------------------------
@@ -601,6 +709,10 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
if not update_event.status.HasField("message") or not update_event.status.message.parts:
return []
state = update_event.status.state
if state not in TERMINAL_TASK_STATES and state != TaskState.TASK_STATE_INPUT_REQUIRED:
return []
message = update_event.status.message
contents = self._parse_contents_from_a2a(message.parts)
if not contents:
@@ -609,6 +721,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
msg_meta = MessageToDict(message.metadata) if message.metadata else {}
event_meta = MessageToDict(update_event.metadata) if update_event.metadata else {}
merged_metadata = {**msg_meta, **event_meta} or None
return [
AgentResponseUpdate(
contents=contents,
@@ -647,7 +760,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
return AgentResponse.from_updates(updates)
return AgentResponse(messages=[], response_id=task.id, raw_representation=task)
def _prepare_message_for_a2a(self, message: Message, *, context_id: str | None = None) -> A2AMessage:
def _prepare_message_for_a2a(self, message: Message, *, session: AgentSession | None = None) -> A2AMessage:
"""Prepare a Message for the A2A protocol.
Transforms Agent Framework Message objects into A2A protocol Messages by:
@@ -656,14 +769,33 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
- Converting file references (URI/data/hosted_file) to FilePart objects
- Preserving metadata and additional properties from the original message
- Setting the role to 'user' as framework messages are treated as user input
- Linking follow-up messages to previous tasks via reference_task_ids or task_id
When the session is an ``A2AAgentSession``, the method reads context_id,
task_id, and task_state directly. If the task is in INPUT_REQUIRED state,
the outbound message's ``task_id`` is set (continuing the same task);
otherwise ``reference_task_ids`` is used for task refinement linking.
Args:
message: The framework Message to convert.
context_id: Optional fallback context identifier (e.g. derived from
``AgentSession.service_session_id``). When the *message* already
carries a ``context_id`` in its ``additional_properties`` that
value takes precedence; otherwise this fallback is used.
Keyword Args:
session: Optional session to read A2A state from. If an
``A2AAgentSession``, context_id/task_id/task_state are used for
linking. A plain ``AgentSession`` provides service_session_id as
a fallback context_id.
"""
# Extract A2A state from the session
context_id: str | None = None
previous_task_id: str | None = None
task_state: TaskState | None = None
if isinstance(session, A2AAgentSession):
context_id = session.context_id
previous_task_id = session.task_id
task_state = session.task_state
elif session is not None:
context_id = session.service_session_id
parts: list[A2APart] = []
if not message.contents:
raise ValueError("Message.contents is empty; cannot convert to A2AMessage.")
@@ -722,14 +854,24 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
a2a_metadata = message.additional_properties.get("a2a_metadata")
return A2AMessage(
a2a_message = A2AMessage(
role=A2ARole.ROLE_USER,
parts=parts,
message_id=message.message_id or uuid.uuid4().hex,
context_id=message.additional_properties.get("context_id") or context_id,
context_id=context_id,
metadata=a2a_metadata or {},
)
if previous_task_id:
if task_state == TaskState.TASK_STATE_INPUT_REQUIRED:
# Task is waiting for user input — set task_id to continue the same task
a2a_message.task_id = previous_task_id
else:
# Link as a follow-up (task refinement)
a2a_message.reference_task_ids.append(previous_task_id)
return a2a_message
def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Content]:
"""Parse A2A Parts into Agent Framework Content.
+349 -45
View File
@@ -31,7 +31,7 @@ from agent_framework import (
from agent_framework.a2a import A2AAgent
from pytest import fixture, mark, raises
from agent_framework_a2a import A2AContinuationToken
from agent_framework_a2a import A2AAgentSession, A2AContinuationToken
from agent_framework_a2a._utils import get_uri_data
@@ -482,24 +482,25 @@ def test_prepare_message_for_a2a_with_multiple_contents() -> None:
def test_prepare_message_for_a2a_forwards_context_id() -> None:
"""Test conversion of Message preserves context_id without duplicating it in metadata."""
"""Test conversion of Message uses context_id from A2AAgentSession."""
agent = A2AAgent(client=MagicMock(), http_client=None)
message = Message(
role="user",
contents=[Content.from_text(text="Continue the task")],
additional_properties={"context_id": "ctx-123", "a2a_metadata": {"trace_id": "trace-456"}},
additional_properties={"a2a_metadata": {"trace_id": "trace-456"}},
)
result = agent._prepare_message_for_a2a(message)
session = A2AAgentSession(context_id="ctx-123")
result = agent._prepare_message_for_a2a(message, session=session)
assert result.context_id == "ctx-123"
assert result.metadata == {"trace_id": "trace-456"}
def test_prepare_message_for_a2a_uses_fallback_context_id() -> None:
"""Test that context_id kwarg is used when message has no context_id property."""
"""Test that service_session_id from a plain session is used when message has no context_id property."""
agent = A2AAgent(client=MagicMock(), http_client=None)
@@ -508,25 +509,26 @@ def test_prepare_message_for_a2a_uses_fallback_context_id() -> None:
contents=[Content.from_text(text="Hello")],
)
result = agent._prepare_message_for_a2a(message, context_id="session-ctx-1")
session = AgentSession(service_session_id="session-ctx-1")
result = agent._prepare_message_for_a2a(message, session=session)
assert result.context_id == "session-ctx-1"
def test_prepare_message_for_a2a_message_context_id_takes_precedence() -> None:
"""Test that message.additional_properties context_id wins over the fallback."""
def test_prepare_message_for_a2a_a2a_session_context_id_takes_precedence() -> None:
"""Test that A2AAgentSession.context_id is used over plain session service_session_id."""
agent = A2AAgent(client=MagicMock(), http_client=None)
message = Message(
role="user",
contents=[Content.from_text(text="Hello")],
additional_properties={"context_id": "explicit-ctx"},
)
result = agent._prepare_message_for_a2a(message, context_id="session-ctx-1")
session = A2AAgentSession(context_id="a2a-ctx")
result = agent._prepare_message_for_a2a(message, session=session)
assert result.context_id == "explicit-ctx"
assert result.context_id == "a2a-ctx"
def test_parse_contents_from_a2a_with_data_part() -> None:
@@ -963,21 +965,16 @@ async def test_run_passes_session_service_session_id_as_context_id(mock_a2a_clie
@mark.asyncio
async def test_run_message_context_id_takes_precedence_over_session(mock_a2a_client: MockA2AClient) -> None:
"""Test that an explicit context_id on the message wins over session.service_session_id."""
async def test_run_a2a_session_context_id_used_over_service_session_id(mock_a2a_client: MockA2AClient) -> None:
"""Test that A2AAgentSession.context_id is used for outbound messages."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
mock_a2a_client.add_message_response("msg-ctx2", "reply")
session = AgentSession(service_session_id="svc-session-42")
message = Message(
role="user",
contents=[Content.from_text(text="Hello")],
additional_properties={"context_id": "explicit-ctx"},
)
await agent.run(messages=[message], session=session)
session = A2AAgentSession(context_id="a2a-ctx-99")
await agent.run("Hello", session=session)
assert mock_a2a_client.last_message is not None
assert mock_a2a_client.last_message.context_id == "explicit-ctx"
assert mock_a2a_client.last_message.context_id == "a2a-ctx-99"
# endregion
@@ -1332,16 +1329,17 @@ async def test_streaming_artifact_update_event_yields_content(
async def test_streaming_status_update_event_yields_content(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that streaming status update events surface message content directly from the update event."""
"""Test that streaming status update events surface content for terminal/input-required states only."""
# COMPLETED state should yield content (terminal)
update_event = TaskStatusUpdateEvent(
task_id="task-status",
context_id="ctx-status",
status=TaskStatus(
state=TaskState.TASK_STATE_WORKING,
state=TaskState.TASK_STATE_COMPLETED,
message=A2AMessage(
message_id=str(uuid4()),
role=A2ARole.ROLE_AGENT,
parts=[Part(text="Still working")],
parts=[Part(text="Done")],
),
),
)
@@ -1352,11 +1350,64 @@ async def test_streaming_status_update_event_yields_content(
updates.append(update)
assert len(updates) == 1
assert updates[0].text == "Still working"
assert updates[0].text == "Done"
assert updates[0].role == "assistant"
assert updates[0].raw_representation == update_event
@mark.asyncio
async def test_streaming_input_required_emits_content(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that input-required status updates emit content (gated states that pass through)."""
update_event = TaskStatusUpdateEvent(
task_id="task-status",
context_id="ctx-status",
status=TaskStatus(
state=TaskState.TASK_STATE_INPUT_REQUIRED,
message=A2AMessage(
message_id=str(uuid4()),
role=A2ARole.ROLE_AGENT,
parts=[Part(text="What is your name?")],
),
),
)
mock_a2a_client.responses.append(StreamResponse(status_update=update_event))
updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)
assert len(updates) == 1
assert updates[0].text == "What is your name?"
@mark.asyncio
async def test_streaming_working_status_gates_content(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that intermediate WORKING status updates do NOT emit content (gated like .NET)."""
update_event = TaskStatusUpdateEvent(
task_id="task-status",
context_id="ctx-status",
status=TaskStatus(
state=TaskState.TASK_STATE_WORKING,
message=A2AMessage(
message_id=str(uuid4()),
role=A2ARole.ROLE_AGENT,
parts=[Part(text="Processing...")],
),
),
)
mock_a2a_client.responses.append(StreamResponse(status_update=update_event))
updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)
assert len(updates) == 0
async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_artifacts(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
@@ -1576,28 +1627,17 @@ async def test_task_status_update_event_metadata_merged(a2a_agent: A2AAgent, moc
task_id="task-se",
context_id="ctx",
status=TaskStatus(
state=TaskState.TASK_STATE_WORKING,
state=TaskState.TASK_STATE_INPUT_REQUIRED,
message=A2AMessage(
message_id="m1",
role=A2ARole.ROLE_AGENT,
parts=[Part(text="working...")],
parts=[Part(text="need input")],
metadata={"msg_key": "msg_val"},
),
),
metadata={"event_key": "event_val"},
)
terminal_task = Task(
id="task-se",
context_id="ctx",
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
artifacts=[
Artifact(artifact_id="a1", parts=[Part(text="done")]),
],
)
mock_a2a_client.responses.extend([
StreamResponse(status_update=status_event),
StreamResponse(task=terminal_task),
])
mock_a2a_client.responses.append(StreamResponse(status_update=status_event))
stream = a2a_agent.run("hello", stream=True)
updates: list[AgentResponseUpdate] = []
@@ -1681,11 +1721,11 @@ async def test_non_streaming_terminal_status_update_surfaces_content(
assert response.messages[0].text == "Done! Here is your answer."
async def test_non_streaming_accumulates_working_content_for_empty_terminal(
async def test_non_streaming_working_content_gated(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Non-streaming run() accumulates WORKING content and flushes on empty terminal event."""
# Intermediate WORKING event with content
"""Non-streaming: WORKING status content is gated and not surfaced to callers."""
# Intermediate WORKING event with content — should be gated
working_msg = A2AMessage(
message_id="msg-working",
role=A2ARole.ROLE_AGENT,
@@ -1702,9 +1742,8 @@ async def test_non_streaming_accumulates_working_content_for_empty_terminal(
response = await a2a_agent.run("Hello")
# The accumulated WORKING content is flushed when terminal arrives empty
assert len(response.messages) == 1
assert response.messages[0].text == "Here is your answer from working state."
# WORKING content is gated — nothing to accumulate or flush
assert len(response.messages) == 0
async def test_non_streaming_intermediate_discarded_when_terminal_has_content(
@@ -1761,3 +1800,268 @@ async def test_non_streaming_artifact_update_surfaces_content(
# endregion
# region Reference Task IDs Tests
@mark.asyncio
async def test_first_message_has_no_reference_task_ids(mock_a2a_client: MockA2AClient) -> None:
"""Test that the first message sent has no reference_task_ids."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
mock_a2a_client.add_task_response("task-first", [{"content": "Hello back"}])
session = A2AAgentSession()
await agent.run("Hello", session=session)
assert mock_a2a_client.last_message is not None
assert list(mock_a2a_client.last_message.reference_task_ids) == []
@mark.asyncio
async def test_follow_up_message_includes_reference_task_ids(mock_a2a_client: MockA2AClient) -> None:
"""Test that a follow-up message references the previous task_id."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
mock_a2a_client.add_task_response("task-abc-123", [{"content": "First reply"}])
session = A2AAgentSession()
await agent.run("Hello", session=session)
# Verify task_id was persisted on session
assert session.task_id == "task-abc-123"
# Send a follow-up message
mock_a2a_client.add_task_response("task-def-456", [{"content": "Second reply"}])
await agent.run("Follow up", session=session)
assert mock_a2a_client.last_message is not None
assert list(mock_a2a_client.last_message.reference_task_ids) == ["task-abc-123"]
@mark.asyncio
async def test_reference_task_ids_updated_after_each_interaction(mock_a2a_client: MockA2AClient) -> None:
"""Test that reference_task_ids always points to the most recent task."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
session = A2AAgentSession()
# First interaction
mock_a2a_client.add_task_response("task-1", [{"content": "Reply 1"}])
await agent.run("Message 1", session=session)
assert session.task_id == "task-1"
# Second interaction
mock_a2a_client.add_task_response("task-2", [{"content": "Reply 2"}])
await agent.run("Message 2", session=session)
assert mock_a2a_client.last_message.reference_task_ids == ["task-1"]
assert session.task_id == "task-2"
# Third interaction references the second task
mock_a2a_client.add_task_response("task-3", [{"content": "Reply 3"}])
await agent.run("Message 3", session=session)
assert mock_a2a_client.last_message.reference_task_ids == ["task-2"]
assert session.task_id == "task-3"
@mark.asyncio
async def test_task_id_tracked_from_status_update_events(mock_a2a_client: MockA2AClient) -> None:
"""Test that task_id is tracked even when response only contains status update events."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
# Simulate a stream that only has status_update events (no full task payload)
status_event = TaskStatusUpdateEvent(
task_id="task-from-status",
context_id="ctx-1",
status=TaskStatus(
state=TaskState.TASK_STATE_COMPLETED,
message=A2AMessage(
message_id="msg-status",
role=A2ARole.ROLE_AGENT,
parts=[Part(text="Done")],
),
),
)
mock_a2a_client.responses.append(StreamResponse(status_update=status_event))
session = A2AAgentSession()
await agent.run("Hello", session=session)
assert session.task_id == "task-from-status"
assert session.task_state == TaskState.TASK_STATE_COMPLETED
@mark.asyncio
async def test_no_session_does_not_crash_reference_task_ids(mock_a2a_client: MockA2AClient) -> None:
"""Test that running without a session (no reference tracking) works fine."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
mock_a2a_client.add_task_response("task-no-session", [{"content": "Reply"}])
# Should not raise — no session means no reference_task_ids
response = await agent.run("Hello")
assert response is not None
assert mock_a2a_client.last_message.reference_task_ids == []
@mark.asyncio
async def test_task_id_not_tracked_from_message_payload(mock_a2a_client: MockA2AClient) -> None:
"""Test that task_id is NOT tracked from message payloads (simple interactions without task tracking)."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
# Simulate a response that is a message with task_id set (no task/status_update events).
# Per A2A spec, a Message response indicates simple interaction — task_id should not be persisted.
message_with_task = A2AMessage(
message_id="msg-with-task",
role=A2ARole.ROLE_AGENT,
parts=[Part(text="Response")],
task_id="task-from-message",
)
mock_a2a_client.responses.append(StreamResponse(message=message_with_task))
session = A2AAgentSession()
await agent.run("Hello", session=session)
assert session.task_id is None
@mark.asyncio
async def test_context_id_assigned_from_response(mock_a2a_client: MockA2AClient) -> None:
"""Test that context_id is assigned from the response when not set on session."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
mock_a2a_client.add_task_response("task-ctx", [{"content": "Reply"}])
session = A2AAgentSession()
assert session.context_id is None
await agent.run("Hello", session=session)
# context_id from the task response should be assigned
assert session.context_id == "test-context"
assert session.service_session_id == "test-context"
@mark.asyncio
async def test_context_id_tracked_from_message_payload(mock_a2a_client: MockA2AClient) -> None:
"""Test that context_id is captured from message-only responses (no task payload)."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
# Simulate a response with only a message that has context_id but no task_id
message_with_context = A2AMessage(
message_id="msg-ctx-only",
role=A2ARole.ROLE_AGENT,
parts=[Part(text="Hello!")],
context_id="server-ctx-123",
)
mock_a2a_client.responses.append(StreamResponse(message=message_with_context))
session = A2AAgentSession()
await agent.run("Hi", session=session)
# context_id should be captured even without a task_id
assert session.context_id == "server-ctx-123"
assert session.service_session_id == "server-ctx-123"
assert session.task_id is None
@mark.asyncio
async def test_context_id_mismatch_raises_error(mock_a2a_client: MockA2AClient) -> None:
"""Test that a context_id mismatch between session and response raises an error."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
# Task response has context_id="test-context" (from add_task_response helper)
mock_a2a_client.add_task_response("task-mismatch", [{"content": "Reply"}])
# Session already has a different context_id
session = A2AAgentSession(context_id="different-context")
with raises(RuntimeError, match="differs from the session's context_id"):
await agent.run("Hello", session=session)
@mark.asyncio
async def test_task_state_tracked_on_session(mock_a2a_client: MockA2AClient) -> None:
"""Test that task_state is tracked on A2AAgentSession."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
# Add a task that ends in INPUT_REQUIRED
mock_a2a_client.add_in_progress_task_response(
"task-input",
context_id="ctx-input",
state=TaskState.TASK_STATE_INPUT_REQUIRED,
text="What is your name?",
)
session = A2AAgentSession()
await agent.run("Start", session=session)
assert session.task_id == "task-input"
assert session.task_state == TaskState.TASK_STATE_INPUT_REQUIRED
@mark.asyncio
async def test_plain_agent_session_no_reference_tracking(mock_a2a_client: MockA2AClient) -> None:
"""Test that a plain AgentSession works but does not get reference_task_ids tracking."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
mock_a2a_client.add_task_response("task-plain", [{"content": "Reply"}])
session = AgentSession()
await agent.run("Hello", session=session)
# Plain session does not get task_id tracking
assert "a2a_task_id" not in session.state
# Follow-up has no reference_task_ids (no tracking on plain session)
mock_a2a_client.add_task_response("task-plain-2", [{"content": "Reply 2"}])
await agent.run("Follow up", session=session)
assert list(mock_a2a_client.last_message.reference_task_ids) == []
@mark.asyncio
async def test_a2a_agent_session_serialization() -> None:
"""Test A2AAgentSession serialization and deserialization."""
session = A2AAgentSession(
context_id="ctx-456",
task_id="task-789",
task_state=TaskState.TASK_STATE_COMPLETED,
)
data = session.to_dict()
restored = A2AAgentSession.from_dict(data)
assert restored.session_id == session.session_id
assert restored.context_id == "ctx-456"
assert restored.task_id == "task-789"
assert restored.task_state == TaskState.TASK_STATE_COMPLETED
@mark.asyncio
async def test_input_required_sets_task_id_instead_of_reference(mock_a2a_client: MockA2AClient) -> None:
"""Test that when task_state is INPUT_REQUIRED, follow-up sets task_id (not reference_task_ids)."""
agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
# First turn: task ends in INPUT_REQUIRED
mock_a2a_client.add_in_progress_task_response(
"task-ir",
context_id="ctx-ir",
state=TaskState.TASK_STATE_INPUT_REQUIRED,
text="What is your name?",
)
session = A2AAgentSession()
await agent.run("Start", session=session)
assert session.task_state == TaskState.TASK_STATE_INPUT_REQUIRED
assert session.task_id == "task-ir"
# Second turn: follow-up should set task_id (not reference_task_ids)
mock_a2a_client.add_in_progress_task_response(
"task-ir-2", context_id="ctx-ir", state=TaskState.TASK_STATE_COMPLETED, text="Thanks!"
)
await agent.run("My name is Alice", session=session)
# The outbound message should have task_id set, not reference_task_ids
last_msg = mock_a2a_client.last_message
assert last_msg.task_id == "task-ir"
assert list(last_msg.reference_task_ids) == []
# endregion