mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: feat: Add Agent Framework to A2A bridge support (#2403)
* feat: Add Agent Framework to A2A bridge support - Implement A2A event adapter for converting agent messages to A2A protocol - Add A2A execution context for managing agent execution state - Implement A2A executor for running agents in A2A environment - Add comprehensive unit tests for event adapter, execution context, and executor - Update agent framework core A2A module exports and type stubs - Integrate thread management utilities for async execution - Add getting started sample for A2A agent framework integration - Update dependencies in uv.lock This integration enables agent framework agents to communicate and execute within the A2A (Agent to Agent) infrastructure. * fix: Update references from agent_thread_storage to _agent_thread_storage in A2A executor tests * Refactor A2A agent framework and improve code structure - Reordered imports in various files for consistency and clarity. - Updated `__all__` definitions to maintain a consistent order across modules. - Simplified method signatures by removing unnecessary line breaks. - Enhanced readability by adjusting formatting in several sections. - Removed redundant comments and example scenarios in the execution context. - Improved handling of agent messages in the event adapter. - Added type hints for better clarity and type checking. - Cleaned up test cases for better organization and readability. * fix: Lint fix new line added * test: Add unit tests for AgentThreadStorage and InMemoryAgentThreadStorage * refactor: Update type hints to use new syntax for Union and List * fix: Validate RequestContext for context_id and message before execution * Refactor tests and remove A2aExecutionContext references - Deleted the test file for A2aExecutionContext as it is no longer needed. - Updated A2aExecutor tests to remove dependencies on A2aExecutionContext and adjusted method calls accordingly. - Modified event adapter tests to use ChatMessage instead of AgentRunResponseUpdate. - Removed A2aExecutionContext from imports in agent_framework.a2a module and updated type hints accordingly. * Refactor A2AExecutor tests and remove event adapter - Updated test cases to use A2AExecutor instead of A2aExecutor for consistency. - Removed mock_event_adapter fixture and related tests as A2aEventAdapter is deprecated. - Consolidated event handling tests into TestA2AExecutorEventAdapter. - Adjusted imports in various files to reflect the removal of deprecated components. - Ensured all references to A2aExecutor are updated to A2AExecutor across the codebase. * refactor: Remove AgentThreadStorage and InMemoryAgentThreadStorage classes from threads and tests * feat: A2AExecutor to have its own override able save and get threads methods for persistent storage. * fix: linter bugs * removed unnecessary changes form core package * new line added * Refactor A2AExecutor tests and update imports - Consolidated mock agent fixtures in test_a2a_executor.py to simplify agent mocking. - Removed redundant tests related to thread storage and agent types, focusing on A2AExecutor's core functionality. - Updated test assertions to reflect changes in message handling with new Message and Content classes. - Enhanced integration tests to ensure compatibility with the new agent framework structure. - Added A2AExecutor to the module exports in __init__.py and __init__.pyi for better accessibility. * Update A2A documentation: enhance usage examples for A2AAgent and A2AExecutor * Updated uv lock * Fix metadata assertion in TestA2AExecutorHandleEvents and reorder load_dotenv call in agent_framework_to_a2a.py * Update agent card configuration: add default input and output modes, and fix agent creation method * Fix assertion for metadata in TestA2AExecutorHandleEvents * Fix formatting issues in TestA2AExecutorExecute and TestA2AExecutorIntegration * Enhance A2AExecutor documentation with examples and clarify agent execution process * Revert uv lock to main * Refactor A2AExecutor: Improve formatting and streamline constructor parameters * Apply suggestions from code review Co-authored-by: Eduard van Valkenburg <eavanvalkenburg@users.noreply.github.com> * Refactor A2AExecutor to use SupportsAgentRun and enhance logging; update agent framework sample for flight and hotel booking capabilities * Enhance A2AExecutor with streaming support and custom run arguments; update tests for initialization and execution scenarios * Enhance A2AExecutor event handling with streamed artifact tracking; update tests for new behavior * Refactor A2AExecutor to enforce type hints for stream and run_kwargs attributes * Refactor A2AExecutor and tests: replace AsyncMock with MagicMock for response stream handling; clean up imports in agent_framework_to_a2a.py * refactor: streamline imports and improve code readability across multiple files * feat: enhance A2AExecutor cancel method with context validation and fixed review comments * feat: implement get_uri_data utility function for extracting base64 data from data URIs and update references * fix: update import path for get_uri_data utility function in A2AExecutor and A2AAgent * fix: correct error message handling in A2AExecutor and update test assertions --------- Co-authored-by: Eduard van Valkenburg <eavanvalkenburg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
4adfd244ac
commit
b00465d7be
@@ -4,20 +4,48 @@ Agent-to-Agent (A2A) protocol support for inter-agent communication.
|
||||
|
||||
## Main Classes
|
||||
|
||||
- **`A2AAgent`** - Agent wrapper that exposes an agent via the A2A protocol
|
||||
- **`A2AAgent`** - Client to connect to remote A2A-compliant agents.
|
||||
- **`A2AExecutor`** - Bridge to expose Agent Framework agents via the A2A protocol.
|
||||
|
||||
## Usage
|
||||
|
||||
### A2AAgent (Client)
|
||||
|
||||
```python
|
||||
from agent_framework.a2a import A2AAgent
|
||||
|
||||
a2a_agent = A2AAgent(agent=my_agent)
|
||||
# Connect to a remote A2A agent
|
||||
a2a_agent = A2AAgent(url="http://remote-agent/a2a")
|
||||
response = await a2a_agent.run("Hello!")
|
||||
```
|
||||
|
||||
### A2AExecutor (Server/Bridge)
|
||||
|
||||
```python
|
||||
from agent_framework.a2a import A2AExecutor
|
||||
from a2a.server.apps import A2AStarletteApplication
|
||||
from a2a.server.request_handlers import DefaultRequestHandler
|
||||
from a2a.server.tasks import InMemoryTaskStore
|
||||
|
||||
# Create an A2A executor for your agent
|
||||
executor = A2AExecutor(agent=my_agent)
|
||||
|
||||
# Set up the request handler and server application
|
||||
request_handler = DefaultRequestHandler(
|
||||
agent_executor=executor,
|
||||
task_store=InMemoryTaskStore(),
|
||||
)
|
||||
|
||||
app = A2AStarletteApplication(
|
||||
agent_card=my_agent_card,
|
||||
http_handler=request_handler,
|
||||
).build()
|
||||
```
|
||||
|
||||
## Import Path
|
||||
|
||||
```python
|
||||
from agent_framework.a2a import A2AAgent
|
||||
from agent_framework.a2a import A2AAgent, A2AExecutor
|
||||
# or directly:
|
||||
from agent_framework_a2a import A2AAgent
|
||||
from agent_framework_a2a import A2AAgent, A2AExecutor
|
||||
```
|
||||
|
||||
@@ -10,11 +10,49 @@ pip install agent-framework-a2a --pre
|
||||
|
||||
The A2A agent integration enables communication with remote A2A-compliant agents using the standardized A2A protocol. This allows your Agent Framework applications to connect to agents running on different platforms, languages, or services.
|
||||
|
||||
### A2AAgent (Client)
|
||||
|
||||
The `A2AAgent` class is a client that wraps an A2A Client to connect the Agent Framework with external A2A-compliant agents.
|
||||
|
||||
```python
|
||||
from agent_framework.a2a import A2AAgent
|
||||
|
||||
# Connect to a remote A2A agent
|
||||
a2a_agent = A2AAgent(url="http://remote-agent/a2a")
|
||||
response = await a2a_agent.run("Hello!")
|
||||
```
|
||||
|
||||
### A2AExecutor (Hosting)
|
||||
|
||||
The `A2AExecutor` class bridges local AI agents built with the `agent_framework` library to the A2A protocol, allowing them to be hosted and accessed by other A2A-compliant clients.
|
||||
|
||||
```python
|
||||
from agent_framework.a2a import A2AExecutor
|
||||
from a2a.server.apps import A2AStarletteApplication
|
||||
from a2a.server.request_handlers import DefaultRequestHandler
|
||||
from a2a.server.tasks import InMemoryTaskStore
|
||||
|
||||
# Create an A2A executor for your agent
|
||||
executor = A2AExecutor(agent=my_agent)
|
||||
|
||||
# Set up the request handler and server application
|
||||
request_handler = DefaultRequestHandler(
|
||||
agent_executor=executor,
|
||||
task_store=InMemoryTaskStore(),
|
||||
)
|
||||
|
||||
app = A2AStarletteApplication(
|
||||
agent_card=my_agent_card,
|
||||
http_handler=request_handler,
|
||||
).build()
|
||||
```
|
||||
|
||||
### Basic Usage Example
|
||||
|
||||
See the [A2A agent examples](../../samples/04-hosting/a2a/) which demonstrate:
|
||||
|
||||
- Connecting to remote A2A agents
|
||||
- Hosting local agents via A2A protocol
|
||||
- Sending messages and receiving responses
|
||||
- Handling different content types (text, files, data)
|
||||
- Streaming responses and real-time interaction
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from ._a2a_executor import A2AExecutor
|
||||
from ._agent import A2AAgent, A2AContinuationToken
|
||||
|
||||
try:
|
||||
@@ -12,5 +13,6 @@ except importlib.metadata.PackageNotFoundError:
|
||||
__all__ = [
|
||||
"A2AAgent",
|
||||
"A2AContinuationToken",
|
||||
"A2AExecutor",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,275 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
from asyncio import CancelledError
|
||||
from collections.abc import Mapping
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from a2a.server.agent_execution import AgentExecutor, RequestContext
|
||||
from a2a.server.events import EventQueue
|
||||
from a2a.server.tasks import TaskUpdater
|
||||
from a2a.types import FilePart, FileWithBytes, FileWithUri, Part, TaskState, TextPart
|
||||
from a2a.utils import new_task
|
||||
from agent_framework import (
|
||||
AgentResponseUpdate,
|
||||
AgentSession,
|
||||
Message,
|
||||
SupportsAgentRun,
|
||||
)
|
||||
from typing_extensions import override
|
||||
|
||||
from agent_framework_a2a._utils import get_uri_data
|
||||
|
||||
logger = logging.getLogger("agent_framework.a2a")
|
||||
|
||||
|
||||
class A2AExecutor(AgentExecutor):
|
||||
"""Execute AI agents using the A2A (Agent-to-Agent) protocol.
|
||||
|
||||
The A2AExecutor bridges AI agents built with the agent_framework library and the A2A protocol,
|
||||
enabling structured agent execution with event-driven communication. It handles execution
|
||||
contexts, delegates history management to the agent's session, and converts agent
|
||||
responses into A2A protocol events.
|
||||
|
||||
The executor supports executing an Agent or WorkflowAgent. It provides comprehensive
|
||||
error handling with task status updates and supports various content types including text,
|
||||
binary data, and URI-based content.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from a2a.server.apps import A2AStarletteApplication
|
||||
from a2a.server.request_handlers import DefaultRequestHandler
|
||||
from a2a.server.tasks import InMemoryTaskStore
|
||||
from a2a.types import AgentCapabilities, AgentCard
|
||||
from agent_framework.a2a import A2AExecutor
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
|
||||
public_agent_card = AgentCard(
|
||||
name="Food Agent",
|
||||
description="A simple agent that provides food-related information.",
|
||||
url="http://localhost:9999/",
|
||||
version="1.0.0",
|
||||
defaultInputModes=["text"],
|
||||
defaultOutputModes=["text"],
|
||||
capabilities=AgentCapabilities(streaming=True),
|
||||
skills=[],
|
||||
)
|
||||
|
||||
# Create an agent
|
||||
agent = OpenAIResponsesClient().as_agent(
|
||||
name="Food Agent",
|
||||
instructions="A simple agent that provides food-related information.",
|
||||
)
|
||||
|
||||
# Set up the A2A server with the A2AExecutor enabled for streaming
|
||||
# and passing custom keyword arguments to the agent's run method.
|
||||
request_handler = DefaultRequestHandler(
|
||||
agent_executor=A2AExecutor(agent, stream=True, run_kwargs={"client_kwargs": {"max_tokens": 500}}),
|
||||
task_store=InMemoryTaskStore(),
|
||||
)
|
||||
|
||||
server = A2AStarletteApplication(
|
||||
agent_card=public_agent_card,
|
||||
http_handler=request_handler,
|
||||
).build()
|
||||
|
||||
Args:
|
||||
agent: The AI agent to execute.
|
||||
stream: Whether to stream the agent response. Defaults to False.
|
||||
run_kwargs: Additional keyword arguments to pass to the agent's run method.
|
||||
"""
|
||||
|
||||
def __init__(self, agent: SupportsAgentRun, stream: bool = False, run_kwargs: Mapping[str, Any] | None = None):
|
||||
"""Initialize the A2AExecutor with the specified agent.
|
||||
|
||||
Args:
|
||||
agent: The AI agent or workflow to execute.
|
||||
stream: Whether to stream the agent response. Defaults to False.
|
||||
run_kwargs: Additional keyword arguments to pass to the agent's run method.
|
||||
Cannot contain 'session' or 'stream' as these are managed by the executor.
|
||||
|
||||
Raises:
|
||||
ValueError: If run_kwargs contains 'session' or 'stream'.
|
||||
"""
|
||||
super().__init__()
|
||||
self._agent: SupportsAgentRun = agent
|
||||
self._stream: bool = stream
|
||||
if run_kwargs:
|
||||
if "session" in run_kwargs:
|
||||
raise ValueError("run_kwargs cannot contain 'session' as it is managed by the executor.")
|
||||
if "stream" in run_kwargs:
|
||||
raise ValueError("run_kwargs cannot contain 'stream' as it is managed by the executor.")
|
||||
self._run_kwargs: Mapping[str, Any] = run_kwargs or {}
|
||||
|
||||
@override
|
||||
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
|
||||
"""Cancel agent execution for the given request context.
|
||||
|
||||
Uses a TaskUpdater to send a cancellation event through the provided event queue.
|
||||
|
||||
Args:
|
||||
context: The request context identifying the task to cancel.
|
||||
event_queue: The event queue to publish the cancellation event to.
|
||||
|
||||
Raises:
|
||||
ValueError: If context_id is not provided in the RequestContext.
|
||||
"""
|
||||
if context.context_id is None:
|
||||
raise ValueError("Context ID must be provided in the RequestContext")
|
||||
|
||||
updater = TaskUpdater(
|
||||
event_queue=event_queue,
|
||||
task_id=context.task_id or "",
|
||||
context_id=context.context_id,
|
||||
)
|
||||
|
||||
await updater.cancel()
|
||||
|
||||
@override
|
||||
async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
|
||||
"""Execute the agent with the given context and event queue.
|
||||
|
||||
Orchestrates the agent execution process: sets up the agent session,
|
||||
executes the agent, processes response messages, and handles errors with appropriate task status updates.
|
||||
"""
|
||||
if context.context_id is None:
|
||||
raise ValueError("Context ID must be provided in the RequestContext")
|
||||
if context.message is None:
|
||||
raise ValueError("Message must be provided in the RequestContext")
|
||||
|
||||
query = context.get_user_input()
|
||||
task = context.current_task
|
||||
|
||||
if not task:
|
||||
task = new_task(context.message)
|
||||
await event_queue.enqueue_event(task)
|
||||
|
||||
updater = TaskUpdater(event_queue, task.id, context.context_id)
|
||||
await updater.submit()
|
||||
|
||||
try:
|
||||
await updater.start_work()
|
||||
|
||||
session = self._agent.create_session(session_id=task.context_id)
|
||||
|
||||
if self._stream:
|
||||
await self._run_stream(query, session, updater)
|
||||
else:
|
||||
await self._run(query, session, updater)
|
||||
|
||||
# Mark as complete
|
||||
await updater.complete()
|
||||
except CancelledError:
|
||||
await updater.update_status(state=TaskState.canceled, final=True)
|
||||
except Exception as e:
|
||||
logger.exception("A2AExecutor encountered an error during execution.", exc_info=e)
|
||||
await updater.update_status(
|
||||
state=TaskState.failed,
|
||||
final=True,
|
||||
message=updater.new_agent_message([Part(root=TextPart(text=str(e)))]),
|
||||
)
|
||||
|
||||
async def _run_stream(self, query: Any, session: AgentSession, updater: TaskUpdater) -> None:
|
||||
"""Run the agent in streaming mode and publish updates to the task updater."""
|
||||
response_stream = self._agent.run(query, session=session, stream=True, **self._run_kwargs)
|
||||
streamed_artifact_ids: set[str] = set()
|
||||
await (
|
||||
response_stream.with_transform_hook(
|
||||
partial(self.handle_events, updater=updater, streamed_artifact_ids=streamed_artifact_ids)
|
||||
)
|
||||
).get_final_response()
|
||||
|
||||
async def _run(self, query: Any, session: AgentSession, updater: TaskUpdater) -> None:
|
||||
"""Run the agent in non-streaming mode and publish messages to the task updater."""
|
||||
response = await self._agent.run(query, session=session, stream=False, **self._run_kwargs)
|
||||
response_messages = response.messages
|
||||
|
||||
if not isinstance(response_messages, list):
|
||||
response_messages = [response_messages]
|
||||
|
||||
for message in response_messages:
|
||||
await self.handle_events(message, updater)
|
||||
|
||||
async def handle_events(
|
||||
self, item: Message | AgentResponseUpdate, updater: TaskUpdater, streamed_artifact_ids: set[str] | None = None
|
||||
) -> None:
|
||||
"""Convert agent response items (Messages or Updates) to A2A protocol events.
|
||||
|
||||
Processes Message or AgentResponseUpdate objects and converts them into A2A protocol format.
|
||||
Handles text, data, and URI content. USER role messages are skipped.
|
||||
|
||||
Users can override this method in a subclass to implement custom transformations
|
||||
from their agent's output format to A2A protocol events.
|
||||
|
||||
Args:
|
||||
item: The agent response item (Message or AgentResponseUpdate) to process.
|
||||
updater: The task updater to publish events to.
|
||||
streamed_artifact_ids: A set of artifact IDs that have already been streamed.
|
||||
Used to prevent duplicate updates for the same artifact.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class CustomA2AExecutor(A2AExecutor):
|
||||
async def handle_events(
|
||||
self,
|
||||
item: Message | AgentResponseUpdate,
|
||||
updater: TaskUpdater,
|
||||
streamed_artifact_ids: set[str] | None = None,
|
||||
) -> None:
|
||||
# Custom logic to transform item contents
|
||||
if item.role == "assistant" and item.contents:
|
||||
parts = [Part(root=TextPart(text=f"Custom: {item.contents[0].text}"))]
|
||||
await updater.update_status(
|
||||
state=TaskState.working,
|
||||
message=updater.new_agent_message(parts=parts),
|
||||
)
|
||||
else:
|
||||
await super().handle_events(item, updater)
|
||||
"""
|
||||
role = getattr(item, "role", None)
|
||||
if role == "user":
|
||||
# This is a user message, we can ignore it in the context of task updates
|
||||
return
|
||||
|
||||
parts: list[Part] = []
|
||||
metadata = getattr(item, "additional_properties", None)
|
||||
|
||||
# AgentResponseUpdate uses 'contents', Message uses 'contents'
|
||||
contents = getattr(item, "contents", [])
|
||||
|
||||
for content in contents:
|
||||
if content.type == "text" and content.text:
|
||||
parts.append(Part(root=TextPart(text=content.text)))
|
||||
elif content.type == "data" and content.uri:
|
||||
base64_str = get_uri_data(content.uri)
|
||||
parts.append(Part(root=FilePart(file=FileWithBytes(bytes=base64_str, mime_type=content.media_type))))
|
||||
elif content.type == "uri" and content.uri:
|
||||
parts.append(Part(root=FilePart(file=FileWithUri(uri=content.uri, mime_type=content.media_type))))
|
||||
else:
|
||||
# Silently skip unsupported content types
|
||||
logger.warning("A2AExecutor does not yet support content type: %s. Omitted.", content.type)
|
||||
|
||||
if parts:
|
||||
if isinstance(item, AgentResponseUpdate):
|
||||
# For streaming updates, we send TaskArtifactUpdateEvent via add_artifact
|
||||
await updater.add_artifact(
|
||||
parts=parts,
|
||||
artifact_id=item.message_id,
|
||||
metadata=metadata,
|
||||
append=(
|
||||
True
|
||||
if streamed_artifact_ids is not None and item.message_id in (streamed_artifact_ids or set())
|
||||
else None
|
||||
),
|
||||
)
|
||||
if item.message_id and streamed_artifact_ids is not None:
|
||||
streamed_artifact_ids.add(item.message_id)
|
||||
else:
|
||||
# For final messages, we send TaskStatusUpdateEvent with 'working' state
|
||||
await updater.update_status(
|
||||
state=TaskState.working,
|
||||
message=updater.new_agent_message(parts=parts, metadata=metadata),
|
||||
)
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
|
||||
from typing import Any, Final, Literal, TypeAlias, overload
|
||||
@@ -49,7 +48,7 @@ from agent_framework.observability import AgentTelemetryLayer
|
||||
|
||||
__all__ = ["A2AAgent", "A2AContinuationToken"]
|
||||
|
||||
URI_PATTERN = re.compile(r"^data:(?P<media_type>[^;]+);base64,(?P<base64_data>[A-Za-z0-9+/=]+)$")
|
||||
from agent_framework_a2a._utils import get_uri_data
|
||||
|
||||
|
||||
class A2AContinuationToken(ContinuationToken):
|
||||
@@ -78,14 +77,6 @@ A2AClientEvent: TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpda
|
||||
A2AStreamItem: TypeAlias = A2AMessage | A2AClientEvent
|
||||
|
||||
|
||||
def _get_uri_data(uri: str) -> str:
|
||||
match = URI_PATTERN.match(uri)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid data URI format: {uri}")
|
||||
|
||||
return match.group("base64_data")
|
||||
|
||||
|
||||
class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
"""Agent2Agent (A2A) protocol implementation.
|
||||
|
||||
@@ -652,7 +643,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
A2APart(
|
||||
root=FilePart(
|
||||
file=FileWithBytes(
|
||||
bytes=_get_uri_data(content.uri),
|
||||
bytes=get_uri_data(content.uri),
|
||||
mime_type=content.media_type,
|
||||
),
|
||||
metadata=content.additional_properties,
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import re
|
||||
|
||||
URI_PATTERN = re.compile(r"^data:(?P<media_type>[^;]+);base64,(?P<base64_data>[A-Za-z0-9+/=]+)$")
|
||||
|
||||
|
||||
def get_uri_data(uri: str) -> str:
|
||||
"""Extracts the base64-encoded data from a data URI.
|
||||
|
||||
Args:
|
||||
uri: The data URI to parse.
|
||||
|
||||
Returns:
|
||||
The base64-encoded data part of the URI.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URI format is invalid.
|
||||
"""
|
||||
match = URI_PATTERN.match(uri)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid data URI format: {uri}")
|
||||
|
||||
return match.group("base64_data")
|
||||
@@ -35,7 +35,7 @@ from agent_framework.a2a import A2AAgent
|
||||
from pytest import fixture, mark, raises
|
||||
|
||||
from agent_framework_a2a import A2AContinuationToken
|
||||
from agent_framework_a2a._agent import _get_uri_data # type: ignore
|
||||
from agent_framework_a2a._utils import get_uri_data
|
||||
|
||||
|
||||
class MockA2AClient:
|
||||
@@ -353,18 +353,18 @@ def test_parse_message_from_artifact(a2a_agent: A2AAgent) -> None:
|
||||
|
||||
|
||||
def test_get_uri_data_valid_uri() -> None:
|
||||
"""Test _get_uri_data with valid data URI."""
|
||||
"""Test get_uri_data with valid data URI."""
|
||||
|
||||
uri = "data:application/json;base64,eyJ0ZXN0IjoidmFsdWUifQ=="
|
||||
result = _get_uri_data(uri)
|
||||
result = get_uri_data(uri)
|
||||
assert result == "eyJ0ZXN0IjoidmFsdWUifQ=="
|
||||
|
||||
|
||||
def test_get_uri_data_invalid_uri() -> None:
|
||||
"""Test _get_uri_data with invalid URI format."""
|
||||
"""Test get_uri_data with invalid URI format."""
|
||||
|
||||
with raises(ValueError, match="Invalid data URI format"):
|
||||
_get_uri_data("not-a-valid-data-uri")
|
||||
get_uri_data("not-a-valid-data-uri")
|
||||
|
||||
|
||||
def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None:
|
||||
|
||||
@@ -0,0 +1,910 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
from asyncio import CancelledError
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from a2a.types import Task, TaskState, TextPart
|
||||
from agent_framework import (
|
||||
AgentResponseUpdate,
|
||||
Content,
|
||||
Message,
|
||||
SupportsAgentRun,
|
||||
)
|
||||
from agent_framework._types import AgentResponse
|
||||
from agent_framework.a2a import A2AExecutor
|
||||
from pytest import fixture, raises
|
||||
|
||||
|
||||
@fixture
|
||||
def mock_agent() -> MagicMock:
|
||||
"""Fixture that provides a mock SupportsAgentRun."""
|
||||
agent = MagicMock(spec=SupportsAgentRun)
|
||||
agent.run = AsyncMock()
|
||||
return agent
|
||||
|
||||
|
||||
@fixture
|
||||
def mock_request_context() -> MagicMock:
|
||||
"""Fixture that provides a mock RequestContext."""
|
||||
request_context = MagicMock()
|
||||
request_context.context_id = str(uuid4())
|
||||
request_context.get_user_input = MagicMock(return_value="Test query")
|
||||
request_context.current_task = None
|
||||
request_context.message = None
|
||||
return request_context
|
||||
|
||||
|
||||
@fixture
|
||||
def mock_event_queue() -> MagicMock:
|
||||
"""Fixture that provides a mock EventQueue."""
|
||||
queue = AsyncMock()
|
||||
queue.enqueue_event = AsyncMock()
|
||||
return queue
|
||||
|
||||
|
||||
@fixture
|
||||
def mock_task() -> Task:
|
||||
"""Fixture that provides a mock Task."""
|
||||
task = MagicMock(spec=Task)
|
||||
task.id = str(uuid4())
|
||||
task.context_id = str(uuid4())
|
||||
task.state = TaskState.completed
|
||||
return task
|
||||
|
||||
|
||||
@fixture
|
||||
def mock_task_updater() -> MagicMock:
|
||||
"""Fixture that provides a mock TaskUpdater."""
|
||||
updater = MagicMock()
|
||||
updater.submit = AsyncMock()
|
||||
updater.start_work = AsyncMock()
|
||||
updater.complete = AsyncMock()
|
||||
updater.update_status = AsyncMock()
|
||||
updater.new_agent_message = MagicMock()
|
||||
return updater
|
||||
|
||||
|
||||
@fixture
|
||||
def executor(mock_agent: MagicMock) -> A2AExecutor:
|
||||
"""Fixture that provides an A2AExecutor."""
|
||||
return A2AExecutor(agent=mock_agent)
|
||||
|
||||
|
||||
class TestA2AExecutorInitialization:
|
||||
"""Tests for A2AExecutor initialization."""
|
||||
|
||||
def test_initialization_with_agent_only(self, mock_agent: MagicMock) -> None:
|
||||
"""Arrange: Create mock agent
|
||||
Act: Initialize A2AExecutor with only agent
|
||||
Assert: Executor is created with default values
|
||||
"""
|
||||
# Act
|
||||
executor = A2AExecutor(agent=mock_agent)
|
||||
|
||||
# Assert
|
||||
assert executor._agent is mock_agent
|
||||
assert executor._stream is False
|
||||
assert executor._run_kwargs == {}
|
||||
|
||||
def test_initialization_with_stream_and_kwargs(self, mock_agent: MagicMock) -> None:
|
||||
"""Arrange: Create mock agent
|
||||
Act: Initialize A2AExecutor with stream and run_kwargs
|
||||
Assert: Executor is created with specified values
|
||||
"""
|
||||
# Arrange
|
||||
run_kwargs = {"temperature": 0.5}
|
||||
|
||||
# Act
|
||||
executor = A2AExecutor(agent=mock_agent, stream=True, run_kwargs=run_kwargs)
|
||||
|
||||
# Assert
|
||||
assert executor._agent is mock_agent
|
||||
assert executor._stream is True
|
||||
assert executor._run_kwargs == run_kwargs
|
||||
|
||||
def test_initialization_with_invalid_run_kwargs(self, mock_agent: MagicMock) -> None:
|
||||
"""Arrange: Create mock agent
|
||||
Act: Initialize A2AExecutor with reserved keys in run_kwargs
|
||||
Assert: ValueError is raised
|
||||
"""
|
||||
# Act & Assert
|
||||
with raises(ValueError, match="run_kwargs cannot contain 'session'"):
|
||||
A2AExecutor(agent=mock_agent, run_kwargs={"session": "something"})
|
||||
|
||||
with raises(ValueError, match="run_kwargs cannot contain 'stream'"):
|
||||
A2AExecutor(agent=mock_agent, run_kwargs={"stream": True})
|
||||
|
||||
|
||||
class TestA2AExecutorCancel:
|
||||
"""Tests for the cancel method."""
|
||||
|
||||
async def test_cancel_method_completes(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
) -> None:
|
||||
"""Arrange: Create executor with dependencies
|
||||
Act: Call cancel method
|
||||
Assert: Method completes without raising error
|
||||
"""
|
||||
# Arrange
|
||||
mock_request_context.task_id = "task-123"
|
||||
|
||||
# Act & Assert (should not raise)
|
||||
await executor.cancel(mock_request_context, mock_event_queue) # type: ignore
|
||||
|
||||
async def test_cancel_handles_different_contexts(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_event_queue: MagicMock,
|
||||
) -> None:
|
||||
"""Arrange: Create executor with multiple request contexts
|
||||
Act: Call cancel with different contexts
|
||||
Assert: Each cancel completes successfully
|
||||
"""
|
||||
# Arrange
|
||||
context1 = MagicMock()
|
||||
context1.context_id = "ctx-1"
|
||||
context1.task_id = "task-1"
|
||||
context2 = MagicMock()
|
||||
context2.context_id = "ctx-2"
|
||||
context2.task_id = "task-2"
|
||||
|
||||
# Act & Assert
|
||||
await executor.cancel(context1, mock_event_queue) # type: ignore
|
||||
await executor.cancel(context2, mock_event_queue) # type: ignore
|
||||
|
||||
async def test_cancel_raises_error_when_context_id_missing(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_event_queue: MagicMock,
|
||||
) -> None:
|
||||
"""Arrange: Create context without context_id
|
||||
Act: Call cancel method
|
||||
Assert: ValueError is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_context = MagicMock()
|
||||
mock_context.context_id = None
|
||||
|
||||
# Act & Assert
|
||||
with raises(ValueError) as excinfo:
|
||||
await executor.cancel(mock_context, mock_event_queue) # type: ignore
|
||||
|
||||
# Assert
|
||||
assert "Context ID" in str(excinfo.value)
|
||||
|
||||
|
||||
class TestA2AExecutorExecute:
|
||||
"""Tests for the execute method."""
|
||||
|
||||
async def test_execute_with_existing_task_succeeds(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
mock_task: Task,
|
||||
) -> None:
|
||||
"""Arrange: Create executor with mocked dependencies and existing task
|
||||
Act: Call execute method
|
||||
Assert: Execution completes successfully
|
||||
"""
|
||||
# Arrange
|
||||
mock_request_context.get_user_input = MagicMock(return_value="Hello")
|
||||
mock_request_context.current_task = mock_task
|
||||
mock_request_context.context_id = "ctx-123"
|
||||
mock_request_context.message = MagicMock()
|
||||
|
||||
response_message = Message(role="assistant", contents=[Content.from_text(text="Hello back")])
|
||||
response = MagicMock(spec=AgentResponse)
|
||||
response.messages = [response_message]
|
||||
executor._agent.run = AsyncMock(return_value=response)
|
||||
executor._agent.create_session = MagicMock()
|
||||
|
||||
with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class:
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.submit = AsyncMock()
|
||||
mock_updater.start_work = AsyncMock()
|
||||
mock_updater.complete = AsyncMock()
|
||||
mock_updater.update_status = AsyncMock()
|
||||
mock_updater.new_agent_message = MagicMock(return_value="message_obj")
|
||||
mock_updater_class.return_value = mock_updater
|
||||
|
||||
# Act
|
||||
await executor.execute(mock_request_context, mock_event_queue)
|
||||
|
||||
# Assert
|
||||
mock_updater.submit.assert_called_once()
|
||||
mock_updater.start_work.assert_called_once()
|
||||
mock_updater.complete.assert_called_once()
|
||||
executor._agent.create_session.assert_called_once()
|
||||
executor._agent.run.assert_called_once()
|
||||
|
||||
async def test_execute_creates_task_when_not_exists(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
) -> None:
|
||||
"""Arrange: Create executor with request context without task
|
||||
Act: Call execute method
|
||||
Assert: New task is created and enqueued
|
||||
"""
|
||||
# Arrange
|
||||
mock_message = MagicMock()
|
||||
mock_request_context.get_user_input = MagicMock(return_value="Hello")
|
||||
mock_request_context.current_task = None
|
||||
mock_request_context.message = mock_message
|
||||
mock_request_context.context_id = "ctx-123"
|
||||
|
||||
response_message = Message(role="assistant", contents=[Content.from_text(text="Response")])
|
||||
response = MagicMock(spec=AgentResponse)
|
||||
response.messages = [response_message]
|
||||
executor._agent.run = AsyncMock(return_value=response)
|
||||
executor._agent.create_session = MagicMock()
|
||||
|
||||
with patch("agent_framework_a2a._a2a_executor.new_task") as mock_new_task:
|
||||
mock_task = MagicMock(spec=Task)
|
||||
mock_task.id = "task-new"
|
||||
mock_task.context_id = "ctx-123"
|
||||
mock_new_task.return_value = mock_task
|
||||
|
||||
with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class:
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.submit = AsyncMock()
|
||||
mock_updater.start_work = AsyncMock()
|
||||
mock_updater.complete = AsyncMock()
|
||||
mock_updater.update_status = AsyncMock()
|
||||
mock_updater.new_agent_message = MagicMock(return_value="message_obj")
|
||||
mock_updater_class.return_value = mock_updater
|
||||
|
||||
# Act
|
||||
await executor.execute(mock_request_context, mock_event_queue)
|
||||
|
||||
# Assert
|
||||
mock_new_task.assert_called_once()
|
||||
mock_event_queue.enqueue_event.assert_called_once()
|
||||
|
||||
async def test_execute_raises_error_when_context_id_missing(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
) -> None:
|
||||
"""Arrange: Create context without context_id
|
||||
Act: Call execute method
|
||||
Assert: ValueError is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_request_context.context_id = None
|
||||
mock_request_context.message = MagicMock()
|
||||
|
||||
# Act & Assert
|
||||
with raises(ValueError) as excinfo:
|
||||
await executor.execute(mock_request_context, mock_event_queue)
|
||||
|
||||
# Assert
|
||||
assert "Context ID" in str(excinfo.value)
|
||||
|
||||
async def test_execute_raises_error_when_message_missing(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
) -> None:
|
||||
"""Arrange: Create context without message
|
||||
Act: Call execute method
|
||||
Assert: ValueError is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_request_context.context_id = "ctx-123"
|
||||
mock_request_context.message = None
|
||||
|
||||
# Act & Assert
|
||||
with raises(ValueError) as excinfo:
|
||||
await executor.execute(mock_request_context, mock_event_queue)
|
||||
|
||||
# Assert
|
||||
assert "Message" in str(excinfo.value)
|
||||
|
||||
async def test_execute_handles_cancelled_error(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
mock_task: Task,
|
||||
) -> None:
|
||||
"""Arrange: Create executor that raises CancelledError
|
||||
Act: Call execute method
|
||||
Assert: Error is caught and task is marked as canceled
|
||||
"""
|
||||
# Arrange
|
||||
mock_request_context.get_user_input = MagicMock(return_value="Hello")
|
||||
mock_request_context.current_task = mock_task
|
||||
mock_request_context.context_id = "ctx-123"
|
||||
mock_request_context.message = MagicMock()
|
||||
|
||||
executor._agent.run = AsyncMock(side_effect=CancelledError())
|
||||
executor._agent.create_session = MagicMock()
|
||||
|
||||
with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class:
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.submit = AsyncMock()
|
||||
mock_updater.start_work = AsyncMock()
|
||||
mock_updater.update_status = AsyncMock()
|
||||
mock_updater_class.return_value = mock_updater
|
||||
|
||||
# Act
|
||||
await executor.execute(mock_request_context, mock_event_queue) # type: ignore
|
||||
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called()
|
||||
call_args_list = mock_updater.update_status.call_args_list
|
||||
assert any(
|
||||
call[1].get("state") == TaskState.canceled and call[1].get("final") is True for call in call_args_list
|
||||
)
|
||||
|
||||
async def test_execute_handles_generic_exception(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
mock_task: Task,
|
||||
) -> None:
|
||||
"""Arrange: Create executor that raises generic exception
|
||||
Act: Call execute method
|
||||
Assert: Error is caught and task is marked as failed
|
||||
"""
|
||||
# Arrange
|
||||
mock_request_context.get_user_input = MagicMock(return_value="Hello")
|
||||
mock_request_context.current_task = mock_task
|
||||
mock_request_context.context_id = "ctx-123"
|
||||
mock_request_context.message = MagicMock()
|
||||
|
||||
error_message = "Test error"
|
||||
executor._agent.run = AsyncMock(side_effect=ValueError(error_message))
|
||||
executor._agent.create_session = MagicMock()
|
||||
|
||||
with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class:
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.submit = AsyncMock()
|
||||
mock_updater.start_work = AsyncMock()
|
||||
mock_updater.update_status = AsyncMock()
|
||||
mock_updater.new_agent_message = MagicMock(return_value="error_message_obj")
|
||||
mock_updater_class.return_value = mock_updater
|
||||
|
||||
# Act
|
||||
await executor.execute(mock_request_context, mock_event_queue)
|
||||
|
||||
# Assert
|
||||
mock_updater.new_agent_message.assert_called_once()
|
||||
args, _ = mock_updater.new_agent_message.call_args
|
||||
parts = args[0]
|
||||
assert len(parts) == 1
|
||||
assert isinstance(parts[0].root, TextPart)
|
||||
assert parts[0].root.text == error_message
|
||||
|
||||
call_args_list = mock_updater.update_status.call_args_list
|
||||
assert any(
|
||||
call[1].get("state") == TaskState.failed
|
||||
and call[1].get("final") is True
|
||||
and call[1].get("message") == "error_message_obj"
|
||||
for call in call_args_list
|
||||
)
|
||||
|
||||
async def test_execute_processes_multiple_response_messages(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
mock_task: Task,
|
||||
) -> None:
|
||||
"""Arrange: Create executor that returns multiple response messages
|
||||
Act: Call execute method
|
||||
Assert: All messages are processed through handle_events
|
||||
"""
|
||||
# Arrange
|
||||
mock_request_context.get_user_input = MagicMock(return_value="Hello")
|
||||
mock_request_context.current_task = mock_task
|
||||
mock_request_context.context_id = "ctx-123"
|
||||
mock_request_context.message = MagicMock()
|
||||
|
||||
response_message1 = Message(role="assistant", contents=[Content.from_text(text="First")])
|
||||
response_message2 = Message(role="assistant", contents=[Content.from_text(text="Second")])
|
||||
response = MagicMock(spec=AgentResponse)
|
||||
response.messages = [response_message1, response_message2]
|
||||
executor._agent.run = AsyncMock(return_value=response)
|
||||
executor._agent.create_session = MagicMock()
|
||||
|
||||
# Mock handle_events
|
||||
executor.handle_events = AsyncMock()
|
||||
|
||||
with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class:
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.submit = AsyncMock()
|
||||
mock_updater.start_work = AsyncMock()
|
||||
mock_updater.complete = AsyncMock()
|
||||
mock_updater.update_status = AsyncMock()
|
||||
mock_updater_class.return_value = mock_updater
|
||||
|
||||
# Act
|
||||
await executor.execute(mock_request_context, mock_event_queue)
|
||||
|
||||
# Assert
|
||||
assert executor.handle_events.call_count == 2
|
||||
|
||||
async def test_execute_passes_query_to_run(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
mock_task: Task,
|
||||
) -> None:
|
||||
"""Arrange: Create executor with request
|
||||
Act: Call execute method
|
||||
Assert: Query text is passed to run method with default stream and kwargs
|
||||
"""
|
||||
# Arrange
|
||||
query_text = "Hello agent"
|
||||
mock_request_context.get_user_input = MagicMock(return_value=query_text)
|
||||
mock_request_context.current_task = mock_task
|
||||
mock_request_context.context_id = "ctx-123"
|
||||
mock_request_context.message = MagicMock()
|
||||
|
||||
response_message = Message(role="assistant", contents=[Content.from_text(text="Response")])
|
||||
response = MagicMock(spec=AgentResponse)
|
||||
response.messages = [response_message]
|
||||
executor._agent.run = AsyncMock(return_value=response)
|
||||
executor._agent.create_session = MagicMock()
|
||||
|
||||
with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class:
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.submit = AsyncMock()
|
||||
mock_updater.start_work = AsyncMock()
|
||||
mock_updater.complete = AsyncMock()
|
||||
mock_updater.update_status = AsyncMock()
|
||||
mock_updater.new_agent_message = MagicMock(return_value="message_obj")
|
||||
mock_updater_class.return_value = mock_updater
|
||||
|
||||
# Act
|
||||
await executor.execute(mock_request_context, mock_event_queue)
|
||||
|
||||
# Assert
|
||||
executor._agent.run.assert_called_once_with(
|
||||
query_text, session=executor._agent.create_session(), stream=False
|
||||
)
|
||||
|
||||
async def test_execute_with_stream_enabled(
|
||||
self,
|
||||
mock_agent: MagicMock,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
mock_task: Task,
|
||||
) -> None:
|
||||
"""Arrange: Create executor with stream=True
|
||||
Act: Call execute method
|
||||
Assert: _run_stream is called and passes stream=True to run
|
||||
"""
|
||||
# Arrange
|
||||
executor = A2AExecutor(agent=mock_agent, stream=True)
|
||||
query_text = "Hello agent"
|
||||
mock_request_context.get_user_input = MagicMock(return_value=query_text)
|
||||
mock_request_context.current_task = mock_task
|
||||
mock_request_context.context_id = "ctx-123"
|
||||
mock_request_context.message = MagicMock()
|
||||
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.with_transform_hook = MagicMock(return_value=mock_response_stream)
|
||||
mock_response_stream.get_final_response = AsyncMock()
|
||||
mock_agent.run = MagicMock(return_value=mock_response_stream)
|
||||
mock_agent.create_session = MagicMock()
|
||||
|
||||
with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class:
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.submit = AsyncMock()
|
||||
mock_updater.start_work = AsyncMock()
|
||||
mock_updater.complete = AsyncMock()
|
||||
mock_updater.update_status = AsyncMock()
|
||||
mock_updater_class.return_value = mock_updater
|
||||
|
||||
# Act
|
||||
await executor.execute(mock_request_context, mock_event_queue)
|
||||
|
||||
# Assert
|
||||
mock_agent.run.assert_called_once_with(query_text, session=mock_agent.create_session(), stream=True)
|
||||
mock_response_stream.with_transform_hook.assert_called_once()
|
||||
mock_response_stream.get_final_response.assert_called_once()
|
||||
|
||||
async def test_execute_with_run_kwargs(
|
||||
self,
|
||||
mock_agent: MagicMock,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
mock_task: Task,
|
||||
) -> None:
|
||||
"""Arrange: Create executor with run_kwargs
|
||||
Act: Call execute method
|
||||
Assert: run_kwargs are passed to run method
|
||||
"""
|
||||
# Arrange
|
||||
run_kwargs = {"temperature": 0.5, "max_tokens": 100}
|
||||
executor = A2AExecutor(agent=mock_agent, run_kwargs=run_kwargs)
|
||||
query_text = "Hello agent"
|
||||
mock_request_context.get_user_input = MagicMock(return_value=query_text)
|
||||
mock_request_context.current_task = mock_task
|
||||
mock_request_context.context_id = "ctx-123"
|
||||
mock_request_context.message = MagicMock()
|
||||
|
||||
response_message = Message(role="assistant", contents=[Content.from_text(text="Response")])
|
||||
response = MagicMock(spec=AgentResponse)
|
||||
response.messages = [response_message]
|
||||
mock_agent.run = AsyncMock(return_value=response)
|
||||
mock_agent.create_session = MagicMock()
|
||||
|
||||
with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class:
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.submit = AsyncMock()
|
||||
mock_updater.start_work = AsyncMock()
|
||||
mock_updater.complete = AsyncMock()
|
||||
mock_updater.update_status = AsyncMock()
|
||||
mock_updater_class.return_value = mock_updater
|
||||
|
||||
# Act
|
||||
await executor.execute(mock_request_context, mock_event_queue)
|
||||
|
||||
# Assert
|
||||
mock_agent.run.assert_called_once_with(
|
||||
query_text, session=mock_agent.create_session(), stream=False, **run_kwargs
|
||||
)
|
||||
|
||||
|
||||
class TestA2AExecutorHandleEvents:
|
||||
"""Tests for A2AExecutor.handle_events method."""
|
||||
|
||||
async def test_run_method_with_single_message(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test the private _run method with a single message (not a list)."""
|
||||
# Arrange
|
||||
query = "test query"
|
||||
session = MagicMock()
|
||||
response_message = Message(role="assistant", contents=[Content.from_text(text="Response")])
|
||||
response = MagicMock(spec=AgentResponse)
|
||||
response.messages = response_message # Not a list
|
||||
executor._agent.run = AsyncMock(return_value=response)
|
||||
executor.handle_events = AsyncMock()
|
||||
|
||||
# Act
|
||||
await executor._run(query, session, mock_updater)
|
||||
|
||||
# Assert
|
||||
executor.handle_events.assert_called_once_with(response_message, mock_updater)
|
||||
|
||||
@fixture
|
||||
def mock_updater(self) -> MagicMock:
|
||||
"""Create a mock execution context."""
|
||||
updater = MagicMock()
|
||||
updater.update_status = AsyncMock()
|
||||
updater.new_agent_message = MagicMock(return_value="mock_message")
|
||||
return updater
|
||||
|
||||
async def test_ignore_user_messages(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test that messages from USER role are ignored."""
|
||||
# Arrange
|
||||
message = Message(
|
||||
contents=[Content.from_text(text="User input")],
|
||||
role="user",
|
||||
)
|
||||
|
||||
# Act
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_updater.update_status.assert_not_called()
|
||||
|
||||
async def test_ignore_messages_with_no_contents(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test that messages with no contents are ignored."""
|
||||
# Arrange
|
||||
message = Message(
|
||||
contents=[],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
# Act
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_updater.update_status.assert_not_called()
|
||||
|
||||
async def test_handle_text_content(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test handling messages with text content."""
|
||||
# Arrange
|
||||
text = "Hello, this is a test message"
|
||||
message = Message(
|
||||
contents=[Content.from_text(text=text)],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
# Act
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called_once()
|
||||
call_args = mock_updater.update_status.call_args
|
||||
assert call_args.kwargs["state"] == TaskState.working
|
||||
assert mock_updater.new_agent_message.called
|
||||
|
||||
async def test_handle_multiple_text_contents(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test handling messages with multiple text contents."""
|
||||
# Arrange
|
||||
message = Message(
|
||||
contents=[
|
||||
Content.from_text(text="First message"),
|
||||
Content.from_text(text="Second message"),
|
||||
],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
# Act
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called_once()
|
||||
assert mock_updater.new_agent_message.called
|
||||
|
||||
async def test_handle_data_content(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test handling messages with data content."""
|
||||
# Arrange
|
||||
data = b"test file data"
|
||||
message = Message(
|
||||
contents=[Content.from_data(data=data, media_type="application/octet-stream")],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
# Act
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called_once()
|
||||
call_args = mock_updater.update_status.call_args
|
||||
assert call_args.kwargs["state"] == TaskState.working
|
||||
|
||||
async def test_handle_uri_content(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test handling messages with URI content."""
|
||||
# Arrange
|
||||
uri = "https://example.com/file.pdf"
|
||||
message = Message(
|
||||
contents=[Content.from_uri(uri=uri, media_type="application/pdf")],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
# Act
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called_once()
|
||||
call_args = mock_updater.update_status.call_args
|
||||
assert call_args.kwargs["state"] == TaskState.working
|
||||
|
||||
async def test_handle_mixed_content_types(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test handling messages with mixed content types."""
|
||||
# Arrange
|
||||
data = b"file data"
|
||||
|
||||
message = Message(
|
||||
contents=[
|
||||
Content.from_text(text="Processing file..."),
|
||||
Content.from_data(data=data, media_type="application/octet-stream"),
|
||||
Content.from_uri(uri="https://example.com/reference.pdf", media_type="application/pdf"),
|
||||
],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
# Act
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called_once()
|
||||
call_args = mock_updater.update_status.call_args
|
||||
assert call_args.kwargs["state"] == TaskState.working
|
||||
|
||||
async def test_handle_with_additional_properties(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test handling messages with additional properties metadata."""
|
||||
# Arrange
|
||||
additional_props = {"custom_field": "custom_value", "priority": "high"}
|
||||
message = Message(
|
||||
contents=[Content.from_text(text="Test message")],
|
||||
role="assistant",
|
||||
additional_properties=additional_props,
|
||||
)
|
||||
|
||||
# Act
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called_once()
|
||||
mock_updater.new_agent_message.assert_called_once()
|
||||
call_args = mock_updater.new_agent_message.call_args
|
||||
assert call_args.kwargs["metadata"] == additional_props
|
||||
|
||||
async def test_handle_with_no_additional_properties(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test handling messages without additional properties."""
|
||||
# Arrange
|
||||
message = Message(
|
||||
contents=[Content.from_text(text="Test message")],
|
||||
role="assistant",
|
||||
additional_properties=None,
|
||||
)
|
||||
|
||||
# Act
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_updater.update_status.assert_called_once()
|
||||
mock_updater.new_agent_message.assert_called_once()
|
||||
call_args = mock_updater.new_agent_message.call_args
|
||||
assert call_args.kwargs["metadata"] == {}
|
||||
|
||||
async def test_parts_list_passed_to_new_agent_message(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test that parts list is correctly passed to new_agent_message."""
|
||||
# Arrange
|
||||
message = Message(
|
||||
contents=[
|
||||
Content.from_text(text="Message 1"),
|
||||
Content.from_text(text="Message 2"),
|
||||
],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
# Act
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_updater.new_agent_message.assert_called_once()
|
||||
call_kwargs = mock_updater.new_agent_message.call_args.kwargs
|
||||
assert "parts" in call_kwargs
|
||||
parts_list = call_kwargs["parts"]
|
||||
assert len(parts_list) == 2
|
||||
|
||||
async def test_task_state_always_working(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test that task state is always set to working."""
|
||||
# Arrange
|
||||
message = Message(
|
||||
contents=[Content.from_text(text="Any message")],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
# Act
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_updater.update_status.call_args.kwargs
|
||||
assert call_kwargs["state"] == TaskState.working
|
||||
|
||||
async def test_handle_agent_response_update_no_streamed_set(
|
||||
self, executor: A2AExecutor, mock_updater: MagicMock
|
||||
) -> None:
|
||||
"""Test handling AgentResponseUpdate (streaming) without a tracking set."""
|
||||
# Arrange
|
||||
update = AgentResponseUpdate(
|
||||
contents=[Content.from_text(text="Streaming chunk")],
|
||||
role="assistant",
|
||||
message_id="msg-1",
|
||||
)
|
||||
mock_updater.add_artifact = AsyncMock()
|
||||
|
||||
# Act
|
||||
await executor.handle_events(update, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_updater.add_artifact.assert_called_once()
|
||||
call_kwargs = mock_updater.add_artifact.call_args.kwargs
|
||||
assert call_kwargs["artifact_id"] == "msg-1"
|
||||
assert call_kwargs["append"] is None
|
||||
|
||||
async def test_handle_agent_response_update_first_time(
|
||||
self, executor: A2AExecutor, mock_updater: MagicMock
|
||||
) -> None:
|
||||
"""Test handling AgentResponseUpdate (streaming) for the first time with a tracking set."""
|
||||
# Arrange
|
||||
update = AgentResponseUpdate(
|
||||
contents=[Content.from_text(text="Streaming chunk")],
|
||||
role="assistant",
|
||||
message_id="msg-1",
|
||||
)
|
||||
mock_updater.add_artifact = AsyncMock()
|
||||
streamed_artifact_ids = set()
|
||||
|
||||
# Act
|
||||
await executor.handle_events(update, mock_updater, streamed_artifact_ids=streamed_artifact_ids)
|
||||
|
||||
# Assert
|
||||
mock_updater.add_artifact.assert_called_once()
|
||||
call_kwargs = mock_updater.add_artifact.call_args.kwargs
|
||||
assert call_kwargs["append"] is None
|
||||
assert "msg-1" in streamed_artifact_ids
|
||||
|
||||
async def test_handle_agent_response_update_subsequent_time(
|
||||
self, executor: A2AExecutor, mock_updater: MagicMock
|
||||
) -> None:
|
||||
"""Test handling AgentResponseUpdate (streaming) for subsequent times with a tracking set."""
|
||||
# Arrange
|
||||
update = AgentResponseUpdate(
|
||||
contents=[Content.from_text(text="Next chunk")],
|
||||
role="assistant",
|
||||
message_id="msg-1",
|
||||
)
|
||||
mock_updater.add_artifact = AsyncMock()
|
||||
streamed_artifact_ids = {"msg-1"}
|
||||
|
||||
# Act
|
||||
await executor.handle_events(update, mock_updater, streamed_artifact_ids=streamed_artifact_ids)
|
||||
|
||||
# Assert
|
||||
mock_updater.add_artifact.assert_called_once()
|
||||
call_kwargs = mock_updater.add_artifact.call_args.kwargs
|
||||
assert call_kwargs["append"] is True
|
||||
|
||||
async def test_handle_unsupported_content_type(self, executor: A2AExecutor, mock_updater: MagicMock) -> None:
|
||||
"""Test handling messages with unsupported content types."""
|
||||
# Arrange
|
||||
message = Message(
|
||||
contents=[Content(type="unknown", text="Some text")],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch("agent_framework_a2a._a2a_executor.logger") as mock_logger:
|
||||
await executor.handle_events(message, mock_updater)
|
||||
|
||||
# Assert
|
||||
mock_logger.warning.assert_called_once()
|
||||
mock_updater.update_status.assert_not_called()
|
||||
|
||||
|
||||
class TestA2AExecutorIntegration:
|
||||
"""Integration tests for A2AExecutor."""
|
||||
|
||||
async def test_full_execution_flow_with_responses(
|
||||
self,
|
||||
executor: A2AExecutor,
|
||||
mock_request_context: MagicMock,
|
||||
mock_event_queue: MagicMock,
|
||||
mock_task: Task,
|
||||
) -> None:
|
||||
"""Arrange: Create executor with all mocked dependencies
|
||||
Act: Execute full flow from request to completion
|
||||
Assert: All components interact correctly
|
||||
"""
|
||||
# Arrange
|
||||
mock_request_context.get_user_input = MagicMock(return_value="Hello agent")
|
||||
mock_request_context.current_task = mock_task
|
||||
mock_request_context.context_id = "ctx-123"
|
||||
mock_request_context.message = MagicMock()
|
||||
|
||||
response = MagicMock(spec=AgentResponse)
|
||||
response_message = MagicMock(spec=Message)
|
||||
response.messages = [response_message]
|
||||
response_message.contents = [Content.from_text(text="Hello user")]
|
||||
response_message.role = "assistant"
|
||||
response_message.additional_properties = None
|
||||
|
||||
executor._agent.run = AsyncMock(return_value=response)
|
||||
executor._agent.create_session = MagicMock()
|
||||
executor.handle_events = AsyncMock()
|
||||
|
||||
with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class:
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.submit = AsyncMock()
|
||||
mock_updater.start_work = AsyncMock()
|
||||
mock_updater.complete = AsyncMock()
|
||||
mock_updater.update_status = AsyncMock()
|
||||
mock_updater_class.return_value = mock_updater
|
||||
|
||||
# Act
|
||||
await executor.execute(mock_request_context, mock_event_queue)
|
||||
|
||||
# Assert
|
||||
mock_updater.submit.assert_called_once()
|
||||
mock_updater.start_work.assert_called_once()
|
||||
executor.handle_events.assert_called_once()
|
||||
mock_updater.complete.assert_called_once()
|
||||
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework_a2a._utils import get_uri_data
|
||||
|
||||
|
||||
def test_get_uri_data_valid() -> None:
|
||||
"""Test get_uri_data with valid data URIs."""
|
||||
# Simple text/plain
|
||||
uri = "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ=="
|
||||
assert get_uri_data(uri) == "SGVsbG8sIFdvcmxkIQ=="
|
||||
|
||||
# Image png
|
||||
uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
|
||||
assert get_uri_data(uri) == "iVBORw0KGgoAAAANSUhEUgfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
|
||||
|
||||
# Application octet-stream
|
||||
uri = "data:application/octet-stream;base64,AQIDBA=="
|
||||
assert get_uri_data(uri) == "AQIDBA=="
|
||||
|
||||
|
||||
def test_get_uri_data_invalid_format() -> None:
|
||||
"""Test get_uri_data with invalid URI formats."""
|
||||
invalid_uris = [
|
||||
"not-a-uri",
|
||||
"http://example.com",
|
||||
"data:text/plain;SGVsbG8sIFdvcmxkIQ==", # Missing base64 marker
|
||||
"data:base64,SGVsbG8sIFdvcmxkIQ==", # Missing media type
|
||||
"data:text/plain;charset=utf-8;base64,SGVsbG8sIFdvcmxkIQ==", # Extra parameters (current regex doesn't support)
|
||||
"data:text/plain;base64,SGVsbG8sIFdvcmxkIQ== extra",
|
||||
]
|
||||
for uri in invalid_uris:
|
||||
with pytest.raises(ValueError, match="Invalid data URI format"):
|
||||
get_uri_data(uri)
|
||||
|
||||
|
||||
def test_get_uri_data_empty() -> None:
|
||||
"""Test get_uri_data with empty string."""
|
||||
with pytest.raises(ValueError, match="Invalid data URI format"):
|
||||
get_uri_data("")
|
||||
@@ -7,6 +7,7 @@ This module lazily re-exports objects from:
|
||||
|
||||
Supported classes:
|
||||
- A2AAgent
|
||||
- A2AExecutor
|
||||
"""
|
||||
|
||||
import importlib
|
||||
@@ -14,7 +15,7 @@ from typing import Any
|
||||
|
||||
IMPORT_PATH = "agent_framework_a2a"
|
||||
PACKAGE_NAME = "agent-framework-a2a"
|
||||
_IMPORTS = ["A2AAgent"]
|
||||
_IMPORTS = ["A2AAgent", "A2AExecutor"]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from agent_framework_a2a import (
|
||||
A2AAgent,
|
||||
)
|
||||
from agent_framework_a2a import A2AAgent, A2AExecutor
|
||||
|
||||
__all__ = [
|
||||
"A2AAgent",
|
||||
]
|
||||
__all__ = ["A2AAgent", "A2AExecutor"]
|
||||
|
||||
@@ -12,6 +12,7 @@ The remaining files are supporting modules used by the server:
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| [`agent_framework_to_a2a.py`](agent_framework_to_a2a.py) | Exposes an agent_framework agent as an A2A-compliant server. Demonstrates how to wrap an agent_framework agent and expose it as an A2A service that other A2A clients can discover and communicate with. |
|
||||
| [`agent_definitions.py`](agent_definitions.py) | Agent and AgentCard factory definitions for invoice, policy, and logistics agents. |
|
||||
| [`agent_executor.py`](agent_executor.py) | Bridges the a2a-sdk `AgentExecutor` interface to Agent Framework agents. |
|
||||
| [`invoice_data.py`](invoice_data.py) | Mock invoice data and tool functions for the invoice agent. |
|
||||
@@ -60,6 +61,9 @@ In a separate terminal (from the same directory), point the client at a running
|
||||
```powershell
|
||||
$env:A2A_AGENT_HOST = "http://localhost:5001/"
|
||||
uv run python agent_with_a2a.py
|
||||
|
||||
# A2A server exposing an agent_framework agent
|
||||
uv run python agent_framework_to_a2a.py
|
||||
```
|
||||
|
||||
### 3. Run the Function Tools Sample
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import uvicorn
|
||||
from a2a.server.apps import A2AStarletteApplication
|
||||
from a2a.server.request_handlers import DefaultRequestHandler
|
||||
from a2a.server.tasks import InMemoryTaskStore
|
||||
from a2a.types import (
|
||||
AgentCapabilities,
|
||||
AgentCard,
|
||||
AgentSkill,
|
||||
)
|
||||
from agent_framework import Agent
|
||||
from agent_framework.a2a import A2AExecutor
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# --8<-- [start:AgentSkill]
|
||||
flight_skill = AgentSkill(
|
||||
id="Flight_Booking",
|
||||
name="Flight Booking",
|
||||
description="Search and book flights across Europe.",
|
||||
tags=["flights", "travel", "europe"],
|
||||
examples=[],
|
||||
)
|
||||
hotel_skill = AgentSkill(
|
||||
id="Hotel_Booking",
|
||||
name="Hotel Booking",
|
||||
description="Search and book hotels across Europe.",
|
||||
tags=["hotels", "travel", "accommodation"],
|
||||
examples=[],
|
||||
)
|
||||
# --8<-- [end:AgentSkill]
|
||||
|
||||
# --8<-- [start:AgentCard]
|
||||
# This will be the public-facing agent card
|
||||
public_agent_card = AgentCard(
|
||||
name="Europe Travel Agent",
|
||||
description="A helpful Europe Travel Agent that can help users search and book flights and hotels across Europe.",
|
||||
url="http://localhost:9999/",
|
||||
version="1.0.0",
|
||||
defaultInputModes=["text"],
|
||||
defaultOutputModes=["text"],
|
||||
capabilities=AgentCapabilities(streaming=True),
|
||||
skills=[flight_skill, hotel_skill],
|
||||
)
|
||||
# --8<-- [end:AgentCard]
|
||||
|
||||
agent = Agent(
|
||||
client=OpenAIChatClient(),
|
||||
name="Europe Travel Agent",
|
||||
instructions="You are a helpful Europe Travel Agent. You can help users search and book flights and hotels across Europe."
|
||||
)
|
||||
|
||||
request_handler = DefaultRequestHandler(
|
||||
agent_executor=A2AExecutor(agent),
|
||||
task_store=InMemoryTaskStore(),
|
||||
)
|
||||
|
||||
server = A2AStarletteApplication(
|
||||
agent_card=public_agent_card,
|
||||
http_handler=request_handler,
|
||||
)
|
||||
|
||||
server = server.build()
|
||||
# print(schemas.get_schema(server.routes))
|
||||
|
||||
uvicorn.run(server, host="0.0.0.0", port=9999)
|
||||
Reference in New Issue
Block a user