mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
35097d8c75
* Python: Add long-running agents and background responses support - Add ContinuationToken TypedDict to core types - Add continuation_token field to ChatResponse, ChatResponseUpdate, AgentResponse, and AgentResponseUpdate - Add background and continuation_token options to OpenAIResponsesOptions - Implement polling via responses.retrieve() and streaming resumption in RawOpenAIResponsesClient - Propagate continuation tokens through agent run() and map_chat_to_agent_update - Fix streaming telemetry 'Failed to detach context' error in both ChatTelemetryLayer and AgentTelemetryLayer by avoiding trace.use_span() context attachment for async-managed spans - Add 14 unit tests for continuation token types and background flows - Add background_responses sample showing polling and stream resumption Fixes #2478 * Python: Add A2A long-running task support via ContinuationToken - Make ContinuationToken provider-agnostic (total=False, optional task_id/context_id fields) - Add background param to A2AAgent.run() controlling token emission - Add poll_task() for single-request task state retrieval - Add resubscribe support via continuation_token param on run() - Extract _updates_from_task() and _map_a2a_stream() for cleaner code - Streamline run()/streaming by removing intermediate _stream_updates wrapper - Update A2A sample to show background=False (default) with link to background_responses sample - Remove stale BareAgent from __all__ - Add 12 new A2A continuation token tests * fix logic for overriding continuation token when done * refactored ContinuationToken setup
787 lines
29 KiB
Python
787 lines
29 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from uuid import uuid4
|
|
|
|
import httpx
|
|
from a2a.types import (
|
|
AgentCard,
|
|
Artifact,
|
|
DataPart,
|
|
FilePart,
|
|
FileWithUri,
|
|
Message,
|
|
Part,
|
|
Task,
|
|
TaskState,
|
|
TaskStatus,
|
|
TextPart,
|
|
)
|
|
from a2a.types import Role as A2ARole
|
|
from agent_framework import (
|
|
AgentResponse,
|
|
AgentResponseUpdate,
|
|
ChatMessage,
|
|
Content,
|
|
)
|
|
from agent_framework.a2a import A2AAgent
|
|
from pytest import fixture, raises
|
|
|
|
from agent_framework_a2a import A2AContinuationToken
|
|
from agent_framework_a2a._agent import _get_uri_data # type: ignore
|
|
|
|
|
|
class MockA2AClient:
|
|
"""Mock implementation of A2A Client for testing."""
|
|
|
|
def __init__(self) -> None:
|
|
self.call_count: int = 0
|
|
self.responses: list[Any] = []
|
|
self.resubscribe_responses: list[Any] = []
|
|
self.get_task_response: Task | None = None
|
|
|
|
def add_message_response(self, message_id: str, text: str, role: str = "agent") -> None:
|
|
"""Add a mock Message response."""
|
|
|
|
# Create actual TextPart instance and wrap it in Part
|
|
text_part = Part(root=TextPart(text=text))
|
|
|
|
# Create actual Message instance
|
|
message = Message(
|
|
message_id=message_id, role=A2ARole.agent if role == "agent" else A2ARole.user, parts=[text_part]
|
|
)
|
|
self.responses.append(message)
|
|
|
|
def add_task_response(self, task_id: str, artifacts: list[dict[str, Any]]) -> None:
|
|
"""Add a mock Task response."""
|
|
# Create mock artifacts
|
|
mock_artifacts = []
|
|
for artifact_data in artifacts:
|
|
# Create actual TextPart instance and wrap it in Part
|
|
text_part = Part(root=TextPart(text=artifact_data.get("content", "Test content")))
|
|
|
|
artifact = Artifact(
|
|
artifact_id=artifact_data.get("id", str(uuid4())),
|
|
name=artifact_data.get("name", "test-artifact"),
|
|
description=artifact_data.get("description", "Test artifact"),
|
|
parts=[text_part],
|
|
)
|
|
mock_artifacts.append(artifact)
|
|
|
|
# Create task status
|
|
status = TaskStatus(state=TaskState.completed, message=None)
|
|
|
|
# Create actual Task instance
|
|
task = Task(
|
|
id=task_id, context_id="test-context", status=status, artifacts=mock_artifacts if mock_artifacts else None
|
|
)
|
|
|
|
# Mock the ClientEvent tuple format
|
|
update_event = None # No specific update event for completed tasks
|
|
client_event = (task, update_event)
|
|
self.responses.append(client_event)
|
|
|
|
def add_in_progress_task_response(
|
|
self,
|
|
task_id: str,
|
|
context_id: str = "test-context",
|
|
state: TaskState = TaskState.working,
|
|
) -> None:
|
|
"""Add a mock in-progress Task response (non-terminal)."""
|
|
status = TaskStatus(state=state, message=None)
|
|
task = Task(id=task_id, context_id=context_id, status=status)
|
|
client_event = (task, None)
|
|
self.responses.append(client_event)
|
|
|
|
async def send_message(self, message: Any) -> AsyncIterator[Any]:
|
|
"""Mock send_message method that yields responses."""
|
|
self.call_count += 1
|
|
|
|
if self.responses:
|
|
response = self.responses.pop(0)
|
|
yield response
|
|
|
|
async def resubscribe(self, request: Any) -> AsyncIterator[Any]:
|
|
"""Mock resubscribe method that yields responses."""
|
|
self.call_count += 1
|
|
|
|
for response in self.resubscribe_responses:
|
|
yield response
|
|
self.resubscribe_responses.clear()
|
|
|
|
async def get_task(self, request: Any) -> Task:
|
|
"""Mock get_task method that returns a task."""
|
|
self.call_count += 1
|
|
if self.get_task_response is not None:
|
|
return self.get_task_response
|
|
msg = "No get_task response configured"
|
|
raise ValueError(msg)
|
|
|
|
|
|
@fixture
|
|
def mock_a2a_client() -> MockA2AClient:
|
|
"""Fixture that provides a mock A2A client."""
|
|
return MockA2AClient()
|
|
|
|
|
|
@fixture
|
|
def a2a_agent(mock_a2a_client: MockA2AClient) -> A2AAgent:
|
|
"""Fixture that provides an A2AAgent with a mock client."""
|
|
return A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None)
|
|
|
|
|
|
def test_a2a_agent_initialization_with_client(mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test A2AAgent initialization with provided client."""
|
|
# Use model_construct to bypass Pydantic validation for mock objects
|
|
agent = A2AAgent(
|
|
name="Test Agent", id="test-agent-123", description="A test agent", client=mock_a2a_client, http_client=None
|
|
)
|
|
|
|
assert agent.name == "Test Agent"
|
|
assert agent.id == "test-agent-123"
|
|
assert agent.description == "A test agent"
|
|
assert agent.client == mock_a2a_client
|
|
|
|
|
|
def test_a2a_agent_initialization_without_client_raises_error() -> None:
|
|
"""Test A2AAgent initialization without client or URL raises ValueError."""
|
|
with raises(ValueError, match="Either agent_card or url must be provided"):
|
|
A2AAgent(name="Test Agent")
|
|
|
|
|
|
async def test_run_with_message_response(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test run() method with immediate Message response."""
|
|
mock_a2a_client.add_message_response("msg-123", "Hello from agent!", "agent")
|
|
|
|
response = await a2a_agent.run("Hello agent")
|
|
|
|
assert isinstance(response, AgentResponse)
|
|
assert len(response.messages) == 1
|
|
assert response.messages[0].role == "assistant"
|
|
assert response.messages[0].text == "Hello from agent!"
|
|
assert response.response_id == "msg-123"
|
|
assert mock_a2a_client.call_count == 1
|
|
|
|
|
|
async def test_run_with_task_response_single_artifact(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test run() method with Task response containing single artifact."""
|
|
artifacts = [{"id": "art-1", "content": "Generated report content"}]
|
|
mock_a2a_client.add_task_response("task-456", artifacts)
|
|
|
|
response = await a2a_agent.run("Generate a report")
|
|
|
|
assert isinstance(response, AgentResponse)
|
|
assert len(response.messages) == 1
|
|
assert response.messages[0].role == "assistant"
|
|
assert response.messages[0].text == "Generated report content"
|
|
assert response.response_id == "task-456"
|
|
assert mock_a2a_client.call_count == 1
|
|
|
|
|
|
async def test_run_with_task_response_multiple_artifacts(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test run() method with Task response containing multiple artifacts."""
|
|
artifacts = [
|
|
{"id": "art-1", "content": "First artifact content"},
|
|
{"id": "art-2", "content": "Second artifact content"},
|
|
{"id": "art-3", "content": "Third artifact content"},
|
|
]
|
|
mock_a2a_client.add_task_response("task-789", artifacts)
|
|
|
|
response = await a2a_agent.run("Generate multiple outputs")
|
|
|
|
assert isinstance(response, AgentResponse)
|
|
assert len(response.messages) == 3
|
|
|
|
assert response.messages[0].text == "First artifact content"
|
|
assert response.messages[1].text == "Second artifact content"
|
|
assert response.messages[2].text == "Third artifact content"
|
|
|
|
# All should be assistant messages
|
|
for message in response.messages:
|
|
assert message.role == "assistant"
|
|
|
|
assert response.response_id == "task-789"
|
|
|
|
|
|
async def test_run_with_task_response_no_artifacts(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test run() method with Task response containing no artifacts."""
|
|
mock_a2a_client.add_task_response("task-empty", [])
|
|
|
|
response = await a2a_agent.run("Do something with no output")
|
|
|
|
assert isinstance(response, AgentResponse)
|
|
assert response.response_id == "task-empty"
|
|
|
|
|
|
async def test_run_with_unknown_response_type_raises_error(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test run() method with unknown response type raises NotImplementedError."""
|
|
mock_a2a_client.responses.append("invalid_response")
|
|
|
|
with raises(NotImplementedError, match="Only Message and Task responses are supported"):
|
|
await a2a_agent.run("Test message")
|
|
|
|
|
|
def test_parse_messages_from_task_empty_artifacts(a2a_agent: A2AAgent) -> None:
|
|
"""Test _parse_messages_from_task with task containing no artifacts."""
|
|
task = MagicMock()
|
|
task.artifacts = None
|
|
|
|
result = a2a_agent._parse_messages_from_task(task)
|
|
|
|
assert len(result) == 0
|
|
|
|
|
|
def test_parse_messages_from_task_with_artifacts(a2a_agent: A2AAgent) -> None:
|
|
"""Test _parse_messages_from_task with task containing artifacts."""
|
|
task = MagicMock()
|
|
|
|
# Create mock artifacts
|
|
artifact1 = MagicMock()
|
|
artifact1.artifact_id = "art-1"
|
|
text_part1 = MagicMock()
|
|
text_part1.root = MagicMock()
|
|
text_part1.root.kind = "text"
|
|
text_part1.root.text = "Content 1"
|
|
text_part1.root.metadata = None
|
|
artifact1.parts = [text_part1]
|
|
|
|
artifact2 = MagicMock()
|
|
artifact2.artifact_id = "art-2"
|
|
text_part2 = MagicMock()
|
|
text_part2.root = MagicMock()
|
|
text_part2.root.kind = "text"
|
|
text_part2.root.text = "Content 2"
|
|
text_part2.root.metadata = None
|
|
artifact2.parts = [text_part2]
|
|
|
|
task.artifacts = [artifact1, artifact2]
|
|
|
|
result = a2a_agent._parse_messages_from_task(task)
|
|
|
|
assert len(result) == 2
|
|
assert result[0].text == "Content 1"
|
|
assert result[1].text == "Content 2"
|
|
assert all(msg.role == "assistant" for msg in result)
|
|
|
|
|
|
def test_parse_message_from_artifact(a2a_agent: A2AAgent) -> None:
|
|
"""Test _parse_message_from_artifact conversion."""
|
|
artifact = MagicMock()
|
|
artifact.artifact_id = "test-artifact"
|
|
|
|
text_part = MagicMock()
|
|
text_part.root = MagicMock()
|
|
text_part.root.kind = "text"
|
|
text_part.root.text = "Artifact content"
|
|
text_part.root.metadata = None
|
|
|
|
artifact.parts = [text_part]
|
|
|
|
result = a2a_agent._parse_message_from_artifact(artifact)
|
|
|
|
assert isinstance(result, ChatMessage)
|
|
assert result.role == "assistant"
|
|
assert result.text == "Artifact content"
|
|
assert result.raw_representation == artifact
|
|
|
|
|
|
def test_get_uri_data_valid_uri() -> None:
|
|
"""Test _get_uri_data with valid data URI."""
|
|
|
|
uri = "data:application/json;base64,eyJ0ZXN0IjoidmFsdWUifQ=="
|
|
result = _get_uri_data(uri)
|
|
assert result == "eyJ0ZXN0IjoidmFsdWUifQ=="
|
|
|
|
|
|
def test_get_uri_data_invalid_uri() -> None:
|
|
"""Test _get_uri_data with invalid URI format."""
|
|
|
|
with raises(ValueError, match="Invalid data URI format"):
|
|
_get_uri_data("not-a-valid-data-uri")
|
|
|
|
|
|
def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None:
|
|
"""Test A2A parts to contents conversion."""
|
|
|
|
agent = A2AAgent(name="Test Agent", client=MockA2AClient(), _http_client=None)
|
|
|
|
# Create A2A parts
|
|
parts = [Part(root=TextPart(text="First part")), Part(root=TextPart(text="Second part"))]
|
|
|
|
# Convert to contents
|
|
contents = agent._parse_contents_from_a2a(parts)
|
|
|
|
# Verify conversion
|
|
assert len(contents) == 2
|
|
assert contents[0].type == "text"
|
|
assert contents[1].type == "text"
|
|
assert contents[0].text == "First part"
|
|
assert contents[1].text == "Second part"
|
|
|
|
|
|
def test_prepare_message_for_a2a_with_error_content(a2a_agent: A2AAgent) -> None:
|
|
"""Test _prepare_message_for_a2a with ErrorContent."""
|
|
|
|
# Create ChatMessage with ErrorContent
|
|
error_content = Content.from_error(message="Test error message")
|
|
message = ChatMessage(role="user", contents=[error_content])
|
|
|
|
# Convert to A2A message
|
|
a2a_message = a2a_agent._prepare_message_for_a2a(message)
|
|
|
|
# Verify conversion
|
|
assert len(a2a_message.parts) == 1
|
|
assert a2a_message.parts[0].root.text == "Test error message"
|
|
|
|
|
|
def test_prepare_message_for_a2a_with_uri_content(a2a_agent: A2AAgent) -> None:
|
|
"""Test _prepare_message_for_a2a with UriContent."""
|
|
|
|
# Create ChatMessage with UriContent
|
|
uri_content = Content.from_uri(uri="http://example.com/file.pdf", media_type="application/pdf")
|
|
message = ChatMessage(role="user", contents=[uri_content])
|
|
|
|
# Convert to A2A message
|
|
a2a_message = a2a_agent._prepare_message_for_a2a(message)
|
|
|
|
# Verify conversion
|
|
assert len(a2a_message.parts) == 1
|
|
assert a2a_message.parts[0].root.file.uri == "http://example.com/file.pdf"
|
|
assert a2a_message.parts[0].root.file.mime_type == "application/pdf"
|
|
|
|
|
|
def test_prepare_message_for_a2a_with_data_content(a2a_agent: A2AAgent) -> None:
|
|
"""Test _prepare_message_for_a2a with DataContent."""
|
|
|
|
# Create ChatMessage with DataContent (base64 data URI)
|
|
data_content = Content.from_uri(uri="data:text/plain;base64,SGVsbG8gV29ybGQ=", media_type="text/plain")
|
|
message = ChatMessage(role="user", contents=[data_content])
|
|
|
|
# Convert to A2A message
|
|
a2a_message = a2a_agent._prepare_message_for_a2a(message)
|
|
|
|
# Verify conversion
|
|
assert len(a2a_message.parts) == 1
|
|
assert a2a_message.parts[0].root.file.bytes == "SGVsbG8gV29ybGQ="
|
|
assert a2a_message.parts[0].root.file.mime_type == "text/plain"
|
|
|
|
|
|
def test_prepare_message_for_a2a_empty_contents_raises_error(a2a_agent: A2AAgent) -> None:
|
|
"""Test _prepare_message_for_a2a with empty contents raises ValueError."""
|
|
# Create ChatMessage with no contents
|
|
message = ChatMessage(role="user", contents=[])
|
|
|
|
# Should raise ValueError for empty contents
|
|
with raises(ValueError, match="ChatMessage.contents is empty"):
|
|
a2a_agent._prepare_message_for_a2a(message)
|
|
|
|
|
|
async def test_run_streaming_with_message_response(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test run(stream=True) method with immediate Message response."""
|
|
mock_a2a_client.add_message_response("msg-stream-123", "Streaming response from agent!", "agent")
|
|
|
|
# Collect streaming updates
|
|
updates: list[AgentResponseUpdate] = []
|
|
async for update in a2a_agent.run("Hello agent", stream=True):
|
|
updates.append(update)
|
|
|
|
# Verify streaming response
|
|
assert len(updates) == 1
|
|
assert isinstance(updates[0], AgentResponseUpdate)
|
|
assert updates[0].role == "assistant"
|
|
assert len(updates[0].contents) == 1
|
|
|
|
content = updates[0].contents[0]
|
|
assert content.type == "text"
|
|
assert content.text == "Streaming response from agent!"
|
|
|
|
assert updates[0].response_id == "msg-stream-123"
|
|
assert mock_a2a_client.call_count == 1
|
|
|
|
|
|
async def test_context_manager_cleanup() -> None:
|
|
"""Test context manager cleanup of http client."""
|
|
|
|
# Create mock http client that tracks aclose calls
|
|
mock_http_client = AsyncMock()
|
|
mock_a2a_client = MagicMock()
|
|
|
|
agent = A2AAgent(client=mock_a2a_client)
|
|
agent._http_client = mock_http_client
|
|
|
|
# Test context manager cleanup
|
|
async with agent:
|
|
pass
|
|
|
|
# Verify aclose was called
|
|
mock_http_client.aclose.assert_called_once()
|
|
|
|
|
|
async def test_context_manager_no_cleanup_when_no_http_client() -> None:
|
|
"""Test context manager when _http_client is None."""
|
|
|
|
mock_a2a_client = MagicMock()
|
|
|
|
agent = A2AAgent(client=mock_a2a_client, _http_client=None)
|
|
|
|
# This should not raise any errors
|
|
async with agent:
|
|
pass
|
|
|
|
|
|
def test_prepare_message_for_a2a_with_multiple_contents() -> None:
|
|
"""Test conversion of ChatMessage with multiple contents."""
|
|
|
|
agent = A2AAgent(client=MagicMock(), _http_client=None)
|
|
|
|
# Create message with multiple content types
|
|
message = ChatMessage(
|
|
role="user",
|
|
contents=[
|
|
Content.from_text(text="Here's the analysis:"),
|
|
Content.from_data(data=b"binary data", media_type="application/octet-stream"),
|
|
Content.from_uri(uri="https://example.com/image.png", media_type="image/png"),
|
|
Content.from_text(text='{"structured": "data"}'),
|
|
],
|
|
)
|
|
|
|
result = agent._prepare_message_for_a2a(message)
|
|
|
|
# Should have converted all 4 contents to parts
|
|
assert len(result.parts) == 4
|
|
|
|
# Check each part type
|
|
assert result.parts[0].root.kind == "text" # Regular text
|
|
assert result.parts[1].root.kind == "file" # Binary data
|
|
assert result.parts[2].root.kind == "file" # URI content
|
|
assert result.parts[3].root.kind == "text" # JSON text remains as text (no parsing)
|
|
|
|
|
|
def test_parse_contents_from_a2a_with_data_part() -> None:
|
|
"""Test conversion of A2A DataPart."""
|
|
|
|
agent = A2AAgent(client=MagicMock(), _http_client=None)
|
|
|
|
# Create DataPart
|
|
data_part = Part(root=DataPart(data={"key": "value", "number": 42}, metadata={"source": "test"}))
|
|
|
|
contents = agent._parse_contents_from_a2a([data_part])
|
|
|
|
assert len(contents) == 1
|
|
|
|
assert contents[0].type == "text"
|
|
assert contents[0].text == '{"key": "value", "number": 42}'
|
|
assert contents[0].additional_properties == {"source": "test"}
|
|
|
|
|
|
def test_parse_contents_from_a2a_unknown_part_kind() -> None:
|
|
"""Test error handling for unknown A2A part kind."""
|
|
agent = A2AAgent(client=MagicMock(), _http_client=None)
|
|
|
|
# Create a mock part with unknown kind
|
|
mock_part = MagicMock()
|
|
mock_part.root.kind = "unknown_kind"
|
|
|
|
with raises(ValueError, match="Unknown Part kind: unknown_kind"):
|
|
agent._parse_contents_from_a2a([mock_part])
|
|
|
|
|
|
def test_prepare_message_for_a2a_with_hosted_file() -> None:
|
|
"""Test conversion of ChatMessage with HostedFileContent to A2A message."""
|
|
|
|
agent = A2AAgent(client=MagicMock(), _http_client=None)
|
|
|
|
# Create message with hosted file content
|
|
message = ChatMessage(
|
|
role="user",
|
|
contents=[Content.from_hosted_file(file_id="hosted://storage/document.pdf")],
|
|
)
|
|
|
|
result = agent._prepare_message_for_a2a(message) # noqa: SLF001
|
|
|
|
# Verify the conversion
|
|
assert len(result.parts) == 1
|
|
part = result.parts[0]
|
|
assert part.root.kind == "file"
|
|
|
|
# Verify it's a FilePart with FileWithUri
|
|
|
|
assert isinstance(part.root, FilePart)
|
|
assert isinstance(part.root.file, FileWithUri)
|
|
assert part.root.file.uri == "hosted://storage/document.pdf"
|
|
assert part.root.file.mime_type is None # HostedFileContent doesn't specify media_type
|
|
|
|
|
|
def test_parse_contents_from_a2a_with_hosted_file_uri() -> None:
|
|
"""Test conversion of A2A FilePart with hosted file URI back to UriContent."""
|
|
|
|
agent = A2AAgent(client=MagicMock(), _http_client=None)
|
|
|
|
# Create FilePart with hosted file URI (simulating what A2A would send back)
|
|
file_part = Part(
|
|
root=FilePart(
|
|
file=FileWithUri(
|
|
uri="hosted://storage/document.pdf",
|
|
mime_type=None,
|
|
)
|
|
)
|
|
)
|
|
|
|
contents = agent._parse_contents_from_a2a([file_part]) # noqa: SLF001
|
|
|
|
assert len(contents) == 1
|
|
|
|
assert contents[0].type == "uri"
|
|
assert contents[0].uri == "hosted://storage/document.pdf"
|
|
assert contents[0].media_type == "" # Converted None to empty string
|
|
|
|
|
|
def test_auth_interceptor_parameter() -> None:
|
|
"""Test that auth_interceptor parameter is accepted without errors."""
|
|
# Create a mock auth interceptor
|
|
mock_auth_interceptor = MagicMock()
|
|
|
|
# Test that A2AAgent can be created with auth_interceptor parameter
|
|
# Using url parameter for simplicity
|
|
agent = A2AAgent(
|
|
name="test-agent",
|
|
url="https://test-agent.example.com",
|
|
auth_interceptor=mock_auth_interceptor,
|
|
)
|
|
|
|
# Verify the agent was created successfully
|
|
assert agent.name == "test-agent"
|
|
assert agent.client is not None
|
|
|
|
|
|
def test_transport_negotiation_both_fail() -> None:
|
|
"""Test that RuntimeError is raised when both primary and fallback transport negotiation fail."""
|
|
# Create a mock agent card
|
|
mock_agent_card = MagicMock(spec=AgentCard)
|
|
mock_agent_card.url = "http://test-agent.example.com"
|
|
|
|
# Mock the factory to simulate both primary and fallback failures
|
|
mock_factory = MagicMock()
|
|
|
|
# Both calls to factory.create() fail
|
|
primary_error = Exception("no compatible transports found")
|
|
fallback_error = Exception("fallback also failed")
|
|
mock_factory.create.side_effect = [primary_error, fallback_error]
|
|
|
|
with (
|
|
patch("agent_framework_a2a._agent.ClientFactory", return_value=mock_factory),
|
|
patch("agent_framework_a2a._agent.minimal_agent_card"),
|
|
patch("agent_framework_a2a._agent.httpx.AsyncClient"),
|
|
raises(RuntimeError, match="A2A transport negotiation failed"),
|
|
):
|
|
# Attempt to create A2AAgent - should raise RuntimeError
|
|
A2AAgent(
|
|
name="test-agent",
|
|
agent_card=mock_agent_card,
|
|
)
|
|
|
|
|
|
def test_create_timeout_config_httpx_timeout() -> None:
|
|
"""Test _create_timeout_config with httpx.Timeout object returns it unchanged."""
|
|
agent = A2AAgent(name="Test Agent", client=MockA2AClient(), http_client=None)
|
|
|
|
custom_timeout = httpx.Timeout(connect=15.0, read=180.0, write=20.0, pool=8.0)
|
|
timeout_config = agent._create_timeout_config(custom_timeout)
|
|
|
|
assert timeout_config is custom_timeout # Same object reference
|
|
assert timeout_config.connect == 15.0
|
|
assert timeout_config.read == 180.0
|
|
assert timeout_config.write == 20.0
|
|
assert timeout_config.pool == 8.0
|
|
|
|
|
|
def test_create_timeout_config_invalid_type() -> None:
|
|
"""Test _create_timeout_config with invalid type raises TypeError."""
|
|
agent = A2AAgent(name="Test Agent", client=MockA2AClient(), http_client=None)
|
|
|
|
with raises(TypeError, match="Invalid timeout type: <class 'str'>. Expected float, httpx.Timeout, or None."):
|
|
agent._create_timeout_config("invalid")
|
|
|
|
|
|
def test_a2a_agent_initialization_with_timeout_parameter() -> None:
|
|
"""Test A2AAgent initialization with timeout parameter."""
|
|
# Test with URL to trigger httpx client creation
|
|
with (
|
|
patch("agent_framework_a2a._agent.httpx.AsyncClient") as mock_async_client,
|
|
patch("agent_framework_a2a._agent.ClientFactory") as mock_factory,
|
|
):
|
|
# Mock the factory and client creation
|
|
mock_client_instance = MagicMock()
|
|
mock_factory.return_value.create.return_value = mock_client_instance
|
|
|
|
# Create agent with custom timeout
|
|
A2AAgent(name="Test Agent", url="https://test-agent.example.com", timeout=120.0)
|
|
|
|
# Verify httpx.AsyncClient was called with the configured timeout
|
|
mock_async_client.assert_called_once()
|
|
call_args = mock_async_client.call_args
|
|
|
|
# Check that timeout parameter was passed
|
|
assert "timeout" in call_args.kwargs
|
|
timeout_arg = call_args.kwargs["timeout"]
|
|
|
|
# Verify it's an httpx.Timeout object with our custom timeout applied to all components
|
|
assert isinstance(timeout_arg, httpx.Timeout)
|
|
|
|
|
|
# region Continuation Token Tests
|
|
|
|
|
|
async def test_working_task_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test that a working (non-terminal) task yields an update with a continuation token when background=True."""
|
|
mock_a2a_client.add_in_progress_task_response("task-wip", context_id="ctx-1", state=TaskState.working)
|
|
|
|
response = await a2a_agent.run("Start long task", background=True)
|
|
|
|
assert isinstance(response, AgentResponse)
|
|
assert response.continuation_token is not None
|
|
assert response.continuation_token["task_id"] == "task-wip"
|
|
assert response.continuation_token["context_id"] == "ctx-1"
|
|
|
|
|
|
async def test_submitted_task_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test that a submitted task yields a continuation token when background=True."""
|
|
mock_a2a_client.add_in_progress_task_response("task-sub", state=TaskState.submitted)
|
|
|
|
response = await a2a_agent.run("Submit task", background=True)
|
|
|
|
assert response.continuation_token is not None
|
|
assert response.continuation_token["task_id"] == "task-sub"
|
|
|
|
|
|
async def test_input_required_task_emits_continuation_token(
|
|
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
|
) -> None:
|
|
"""Test that an input_required task yields a continuation token when background=True."""
|
|
mock_a2a_client.add_in_progress_task_response("task-input", state=TaskState.input_required)
|
|
|
|
response = await a2a_agent.run("Need input", background=True)
|
|
|
|
assert response.continuation_token is not None
|
|
assert response.continuation_token["task_id"] == "task-input"
|
|
|
|
|
|
async def test_working_task_no_token_without_background(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test that background=False (default) does not emit continuation tokens for in-progress tasks."""
|
|
mock_a2a_client.add_in_progress_task_response("task-fg", context_id="ctx-fg", state=TaskState.working)
|
|
|
|
response = await a2a_agent.run("Foreground task")
|
|
|
|
assert response.continuation_token is None
|
|
|
|
|
|
async def test_completed_task_has_no_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test that a completed task does not set a continuation token."""
|
|
mock_a2a_client.add_task_response("task-done", [{"id": "art-1", "content": "Result"}])
|
|
|
|
response = await a2a_agent.run("Quick task")
|
|
|
|
assert response.continuation_token is None
|
|
assert len(response.messages) == 1
|
|
assert response.messages[0].text == "Result"
|
|
|
|
|
|
async def test_streaming_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test that streaming with background=True yields updates with continuation tokens."""
|
|
mock_a2a_client.add_in_progress_task_response("task-stream", context_id="ctx-s", state=TaskState.working)
|
|
|
|
updates: list[AgentResponseUpdate] = []
|
|
async for update in a2a_agent.run("Stream task", stream=True, background=True):
|
|
updates.append(update)
|
|
|
|
assert len(updates) == 1
|
|
assert updates[0].continuation_token is not None
|
|
assert updates[0].continuation_token["task_id"] == "task-stream"
|
|
assert updates[0].continuation_token["context_id"] == "ctx-s"
|
|
|
|
|
|
async def test_resume_via_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test that run() with continuation_token uses resubscribe instead of send_message."""
|
|
# Set up the resubscribe response (completed task)
|
|
status = TaskStatus(state=TaskState.completed, message=None)
|
|
artifact = Artifact(
|
|
artifact_id="art-resume",
|
|
name="result",
|
|
parts=[Part(root=TextPart(text="Resumed result"))],
|
|
)
|
|
task = Task(id="task-resume", context_id="ctx-r", status=status, artifacts=[artifact])
|
|
mock_a2a_client.resubscribe_responses.append((task, None))
|
|
|
|
token = A2AContinuationToken(task_id="task-resume", context_id="ctx-r")
|
|
response = await a2a_agent.run(continuation_token=token)
|
|
|
|
assert isinstance(response, AgentResponse)
|
|
assert len(response.messages) == 1
|
|
assert response.messages[0].text == "Resumed result"
|
|
assert response.continuation_token is None
|
|
|
|
|
|
async def test_resume_streaming_via_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test that streaming run() with continuation_token and background=True uses resubscribe."""
|
|
# Still working
|
|
status_wip = TaskStatus(state=TaskState.working, message=None)
|
|
task_wip = Task(id="task-rs", context_id="ctx-rs", status=status_wip)
|
|
# Then completed
|
|
status_done = TaskStatus(state=TaskState.completed, message=None)
|
|
artifact = Artifact(
|
|
artifact_id="art-rs",
|
|
name="result",
|
|
parts=[Part(root=TextPart(text="Stream resumed"))],
|
|
)
|
|
task_done = Task(id="task-rs", context_id="ctx-rs", status=status_done, artifacts=[artifact])
|
|
mock_a2a_client.resubscribe_responses.extend([(task_wip, None), (task_done, None)])
|
|
|
|
token = A2AContinuationToken(task_id="task-rs", context_id="ctx-rs")
|
|
updates: list[AgentResponseUpdate] = []
|
|
async for update in a2a_agent.run(stream=True, continuation_token=token, background=True):
|
|
updates.append(update)
|
|
|
|
# First update: in-progress with token, second: completed with content
|
|
assert len(updates) == 2
|
|
assert updates[0].continuation_token is not None
|
|
assert updates[0].continuation_token["task_id"] == "task-rs"
|
|
assert updates[1].continuation_token is None
|
|
assert updates[1].contents[0].text == "Stream resumed"
|
|
|
|
|
|
async def test_poll_task_in_progress(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test poll_task returns continuation token when task is still in progress."""
|
|
status = TaskStatus(state=TaskState.working, message=None)
|
|
mock_a2a_client.get_task_response = Task(id="task-poll", context_id="ctx-p", status=status)
|
|
|
|
token = A2AContinuationToken(task_id="task-poll", context_id="ctx-p")
|
|
response = await a2a_agent.poll_task(token)
|
|
|
|
assert response.continuation_token is not None
|
|
assert response.continuation_token["task_id"] == "task-poll"
|
|
|
|
|
|
async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
|
"""Test poll_task returns result with no continuation token when task is complete."""
|
|
status = TaskStatus(state=TaskState.completed, message=None)
|
|
artifact = Artifact(
|
|
artifact_id="art-poll",
|
|
name="result",
|
|
parts=[Part(root=TextPart(text="Poll result"))],
|
|
)
|
|
mock_a2a_client.get_task_response = Task(
|
|
id="task-poll-done", context_id="ctx-pd", status=status, artifacts=[artifact]
|
|
)
|
|
|
|
token = A2AContinuationToken(task_id="task-poll-done", context_id="ctx-pd")
|
|
response = await a2a_agent.poll_task(token)
|
|
|
|
assert response.continuation_token is None
|
|
assert len(response.messages) == 1
|
|
assert response.messages[0].text == "Poll result"
|
|
|
|
|
|
# endregion
|