mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add long-running agents and background responses support (#3808)
* Python: Add long-running agents and background responses support - Add ContinuationToken TypedDict to core types - Add continuation_token field to ChatResponse, ChatResponseUpdate, AgentResponse, and AgentResponseUpdate - Add background and continuation_token options to OpenAIResponsesOptions - Implement polling via responses.retrieve() and streaming resumption in RawOpenAIResponsesClient - Propagate continuation tokens through agent run() and map_chat_to_agent_update - Fix streaming telemetry 'Failed to detach context' error in both ChatTelemetryLayer and AgentTelemetryLayer by avoiding trace.use_span() context attachment for async-managed spans - Add 14 unit tests for continuation token types and background flows - Add background_responses sample showing polling and stream resumption Fixes #2478 * Python: Add A2A long-running task support via ContinuationToken - Make ContinuationToken provider-agnostic (total=False, optional task_id/context_id fields) - Add background param to A2AAgent.run() controlling token emission - Add poll_task() for single-request task state retrieval - Add resubscribe support via continuation_token param on run() - Extract _updates_from_task() and _map_a2a_stream() for cleaner code - Streamline run()/streaming by removing intermediate _stream_updates wrapper - Update A2A sample to show background=False (default) with link to background_responses sample - Remove stale BareAgent from __all__ - Add 12 new A2A continuation token tests * fix logic for overriding continuation token when done * refactored ContinuationToken setup
This commit is contained in:
committed by
GitHub
Unverified
parent
32ba81e990
commit
35097d8c75
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- **agent-framework-core**: Add long-running agents and background responses support with `ContinuationToken` TypedDict, `background` option in `OpenAIResponsesOptions`, and continuation token propagation through response types ([#2478](https://github.com/microsoft/agent-framework/issues/2478))
|
||||
|
||||
## [1.0.0b260130] - 2026-01-30
|
||||
|
||||
### Added
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from ._agent import A2AAgent
|
||||
from ._agent import A2AAgent, A2AContinuationToken
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
@@ -11,5 +11,6 @@ except importlib.metadata.PackageNotFoundError:
|
||||
|
||||
__all__ = [
|
||||
"A2AAgent",
|
||||
"A2AContinuationToken",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
@@ -20,6 +20,8 @@ from a2a.types import (
|
||||
FileWithUri,
|
||||
Message,
|
||||
Task,
|
||||
TaskIdParams,
|
||||
TaskQueryParams,
|
||||
TaskState,
|
||||
TextPart,
|
||||
TransportProtocol,
|
||||
@@ -34,21 +36,39 @@ from agent_framework import (
|
||||
BaseAgent,
|
||||
ChatMessage,
|
||||
Content,
|
||||
ContinuationToken,
|
||||
ResponseStream,
|
||||
normalize_messages,
|
||||
prepend_agent_framework_to_user_agent,
|
||||
)
|
||||
from agent_framework.observability import AgentTelemetryLayer
|
||||
|
||||
__all__ = ["A2AAgent"]
|
||||
__all__ = ["A2AAgent", "A2AContinuationToken"]
|
||||
|
||||
URI_PATTERN = re.compile(r"^data:(?P<media_type>[^;]+);base64,(?P<base64_data>[A-Za-z0-9+/=]+)$")
|
||||
|
||||
|
||||
class A2AContinuationToken(ContinuationToken):
|
||||
"""Continuation token for A2A protocol long-running tasks."""
|
||||
|
||||
task_id: str
|
||||
"""A2A protocol task ID."""
|
||||
context_id: str
|
||||
"""A2A protocol context ID."""
|
||||
|
||||
|
||||
TERMINAL_TASK_STATES = [
|
||||
TaskState.completed,
|
||||
TaskState.failed,
|
||||
TaskState.canceled,
|
||||
TaskState.rejected,
|
||||
]
|
||||
IN_PROGRESS_TASK_STATES = [
|
||||
TaskState.submitted,
|
||||
TaskState.working,
|
||||
TaskState.input_required,
|
||||
TaskState.auth_required,
|
||||
]
|
||||
|
||||
|
||||
def _get_uri_data(uri: str) -> str:
|
||||
@@ -193,6 +213,8 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
thread: AgentThread | None = None,
|
||||
continuation_token: A2AContinuationToken | None = None,
|
||||
background: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@@ -203,6 +225,8 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
*,
|
||||
stream: Literal[True],
|
||||
thread: AgentThread | None = None,
|
||||
continuation_token: A2AContinuationToken | None = None,
|
||||
background: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
@@ -212,85 +236,62 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
*,
|
||||
stream: bool = False,
|
||||
thread: AgentThread | None = None,
|
||||
continuation_token: A2AContinuationToken | None = None,
|
||||
background: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Get a response from the agent.
|
||||
|
||||
This method returns the final result of the agent's execution
|
||||
as a single AgentResponse object when stream=False. When stream=True,
|
||||
it returns a ResponseStream that yields AgentResponseUpdate objects.
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the agent.
|
||||
|
||||
Keyword Args:
|
||||
stream: Whether to stream the response. Defaults to False.
|
||||
thread: The conversation thread associated with the message(s).
|
||||
continuation_token: Optional token to resume a long-running task
|
||||
instead of starting a new one.
|
||||
background: When True, in-progress task updates surface continuation
|
||||
tokens so the caller can poll or resubscribe later. When False
|
||||
(default), the agent internally waits for the task to complete.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
When stream=False: An Awaitable[AgentResponse].
|
||||
When stream=True: A ResponseStream of AgentResponseUpdate items.
|
||||
"""
|
||||
if continuation_token is not None:
|
||||
a2a_stream: AsyncIterable[Any] = self.client.resubscribe(TaskIdParams(id=continuation_token["task_id"]))
|
||||
else:
|
||||
normalized_messages = normalize_messages(messages)
|
||||
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1])
|
||||
a2a_stream = self.client.send_message(a2a_message)
|
||||
|
||||
response = ResponseStream(
|
||||
self._map_a2a_stream(a2a_stream, background=background),
|
||||
finalizer=lambda updates: AgentResponse.from_updates(list(updates)),
|
||||
)
|
||||
if stream:
|
||||
return self._run_stream_impl(messages=messages, thread=thread, **kwargs)
|
||||
return self._run_impl(messages=messages, thread=thread, **kwargs)
|
||||
return response
|
||||
return response.get_final_response()
|
||||
|
||||
async def _run_impl(
|
||||
async def _map_a2a_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
|
||||
a2a_stream: AsyncIterable[Any],
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentResponse[Any]:
|
||||
"""Non-streaming implementation of run."""
|
||||
# Collect all updates and use framework to consolidate updates into response
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in self._stream_updates(messages, thread=thread, **kwargs):
|
||||
updates.append(update)
|
||||
return AgentResponse.from_updates(updates)
|
||||
|
||||
def _run_stream_impl(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Streaming implementation of run."""
|
||||
|
||||
def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]:
|
||||
return AgentResponse.from_updates(list(updates))
|
||||
|
||||
return ResponseStream(self._stream_updates(messages, thread=thread, **kwargs), finalizer=_finalize)
|
||||
|
||||
async def _stream_updates(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
background: bool = False,
|
||||
) -> AsyncIterable[AgentResponseUpdate]:
|
||||
"""Internal method to stream updates from the A2A agent.
|
||||
"""Map raw A2A protocol items to AgentResponseUpdates.
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the agent.
|
||||
a2a_stream: The raw A2A event stream.
|
||||
|
||||
Keyword Args:
|
||||
thread: The conversation thread associated with the message(s).
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Yields:
|
||||
AgentResponseUpdate items from the A2A agent.
|
||||
background: When False, in-progress task updates are silently
|
||||
consumed (the stream keeps iterating until a terminal state).
|
||||
When True, they are yielded with a continuation token.
|
||||
"""
|
||||
normalized_messages = normalize_messages(messages)
|
||||
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1])
|
||||
|
||||
response_stream = self.client.send_message(a2a_message)
|
||||
|
||||
async for item in response_stream:
|
||||
async for item in a2a_stream:
|
||||
if isinstance(item, Message):
|
||||
# Process A2A Message
|
||||
contents = self._parse_contents_from_a2a(item.parts)
|
||||
yield AgentResponseUpdate(
|
||||
contents=contents,
|
||||
@@ -300,33 +301,82 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
)
|
||||
elif isinstance(item, tuple) and len(item) == 2: # ClientEvent = (Task, UpdateEvent)
|
||||
task, _update_event = item
|
||||
if isinstance(task, Task) and task.status.state in TERMINAL_TASK_STATES:
|
||||
# Convert Task artifacts to ChatMessages and yield as separate updates
|
||||
task_messages = self._parse_messages_from_task(task)
|
||||
if task_messages:
|
||||
for message in task_messages:
|
||||
# Use the artifact's ID from raw_representation as message_id for unique identification
|
||||
artifact_id = getattr(message.raw_representation, "artifact_id", None)
|
||||
yield AgentResponseUpdate(
|
||||
contents=message.contents,
|
||||
role=message.role,
|
||||
response_id=task.id,
|
||||
message_id=artifact_id,
|
||||
raw_representation=task,
|
||||
)
|
||||
else:
|
||||
# Empty task
|
||||
yield AgentResponseUpdate(
|
||||
contents=[],
|
||||
role="assistant",
|
||||
response_id=task.id,
|
||||
raw_representation=task,
|
||||
)
|
||||
if isinstance(task, Task):
|
||||
for update in self._updates_from_task(task, background=background):
|
||||
yield update
|
||||
else:
|
||||
# Unknown response type
|
||||
msg = f"Only Message and Task responses are supported from A2A agents. Received: {type(item)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _updates_from_task(self, task: Task, *, background: bool = False) -> list[AgentResponseUpdate]:
|
||||
"""Convert an A2A Task into AgentResponseUpdate(s).
|
||||
|
||||
Terminal tasks produce updates from their artifacts/history.
|
||||
In-progress tasks produce a continuation token update only when
|
||||
``background=True``; otherwise they are silently skipped so the
|
||||
caller keeps consuming the stream until completion.
|
||||
"""
|
||||
if task.status.state in TERMINAL_TASK_STATES:
|
||||
task_messages = self._parse_messages_from_task(task)
|
||||
if task_messages:
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=message.contents,
|
||||
role=message.role,
|
||||
response_id=task.id,
|
||||
message_id=getattr(message.raw_representation, "artifact_id", None),
|
||||
raw_representation=task,
|
||||
)
|
||||
for message in task_messages
|
||||
]
|
||||
return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)]
|
||||
|
||||
if background and task.status.state in IN_PROGRESS_TASK_STATES:
|
||||
token = self._build_continuation_token(task)
|
||||
return [
|
||||
AgentResponseUpdate(
|
||||
contents=[],
|
||||
role="assistant",
|
||||
response_id=task.id,
|
||||
continuation_token=token,
|
||||
raw_representation=task,
|
||||
)
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _build_continuation_token(task: Task) -> A2AContinuationToken | None:
|
||||
"""Build an A2AContinuationToken from an A2A Task if it is still in progress."""
|
||||
if task.status.state in IN_PROGRESS_TASK_STATES:
|
||||
return A2AContinuationToken(task_id=task.id, context_id=task.context_id)
|
||||
return None
|
||||
|
||||
async def poll_task(self, continuation_token: A2AContinuationToken) -> AgentResponse[Any]:
|
||||
"""Poll for the current state of a long-running A2A task.
|
||||
|
||||
Unlike ``run(continuation_token=...)``, which resubscribes to the SSE
|
||||
stream, this performs a single request to retrieve the task state.
|
||||
|
||||
Args:
|
||||
continuation_token: A token previously obtained from a response's
|
||||
``continuation_token`` field.
|
||||
|
||||
Returns:
|
||||
An AgentResponse whose ``continuation_token`` is set when the task
|
||||
is still in progress, or ``None`` when it has reached a terminal state.
|
||||
"""
|
||||
task_id = continuation_token["task_id"]
|
||||
task = await self.client.get_task(TaskQueryParams(id=task_id))
|
||||
updates = self._updates_from_task(task, background=True)
|
||||
if updates:
|
||||
return AgentResponse.from_updates(updates)
|
||||
return AgentResponse(messages=[], response_id=task.id, raw_representation=task)
|
||||
|
||||
def _prepare_message_for_a2a(self, message: ChatMessage) -> A2AMessage:
|
||||
"""Prepare a ChatMessage for the A2A protocol.
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from agent_framework import (
|
||||
from agent_framework.a2a import A2AAgent
|
||||
from pytest import fixture, raises
|
||||
|
||||
from agent_framework_a2a import A2AContinuationToken
|
||||
from agent_framework_a2a._agent import _get_uri_data # type: ignore
|
||||
|
||||
|
||||
@@ -38,6 +39,8 @@ class MockA2AClient:
|
||||
def __init__(self) -> None:
|
||||
self.call_count: int = 0
|
||||
self.responses: list[Any] = []
|
||||
self.resubscribe_responses: list[Any] = []
|
||||
self.get_task_response: Task | None = None
|
||||
|
||||
def add_message_response(self, message_id: str, text: str, role: str = "agent") -> None:
|
||||
"""Add a mock Message response."""
|
||||
@@ -80,6 +83,18 @@ class MockA2AClient:
|
||||
client_event = (task, update_event)
|
||||
self.responses.append(client_event)
|
||||
|
||||
def add_in_progress_task_response(
|
||||
self,
|
||||
task_id: str,
|
||||
context_id: str = "test-context",
|
||||
state: TaskState = TaskState.working,
|
||||
) -> None:
|
||||
"""Add a mock in-progress Task response (non-terminal)."""
|
||||
status = TaskStatus(state=state, message=None)
|
||||
task = Task(id=task_id, context_id=context_id, status=status)
|
||||
client_event = (task, None)
|
||||
self.responses.append(client_event)
|
||||
|
||||
async def send_message(self, message: Any) -> AsyncIterator[Any]:
|
||||
"""Mock send_message method that yields responses."""
|
||||
self.call_count += 1
|
||||
@@ -88,6 +103,22 @@ class MockA2AClient:
|
||||
response = self.responses.pop(0)
|
||||
yield response
|
||||
|
||||
async def resubscribe(self, request: Any) -> AsyncIterator[Any]:
|
||||
"""Mock resubscribe method that yields responses."""
|
||||
self.call_count += 1
|
||||
|
||||
for response in self.resubscribe_responses:
|
||||
yield response
|
||||
self.resubscribe_responses.clear()
|
||||
|
||||
async def get_task(self, request: Any) -> Task:
|
||||
"""Mock get_task method that returns a task."""
|
||||
self.call_count += 1
|
||||
if self.get_task_response is not None:
|
||||
return self.get_task_response
|
||||
msg = "No get_task response configured"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@fixture
|
||||
def mock_a2a_client() -> MockA2AClient:
|
||||
@@ -598,3 +629,158 @@ def test_a2a_agent_initialization_with_timeout_parameter() -> None:
|
||||
|
||||
# Verify it's an httpx.Timeout object with our custom timeout applied to all components
|
||||
assert isinstance(timeout_arg, httpx.Timeout)
|
||||
|
||||
|
||||
# region Continuation Token Tests
|
||||
|
||||
|
||||
async def test_working_task_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that a working (non-terminal) task yields an update with a continuation token when background=True."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-wip", context_id="ctx-1", state=TaskState.working)
|
||||
|
||||
response = await a2a_agent.run("Start long task", background=True)
|
||||
|
||||
assert isinstance(response, AgentResponse)
|
||||
assert response.continuation_token is not None
|
||||
assert response.continuation_token["task_id"] == "task-wip"
|
||||
assert response.continuation_token["context_id"] == "ctx-1"
|
||||
|
||||
|
||||
async def test_submitted_task_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that a submitted task yields a continuation token when background=True."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-sub", state=TaskState.submitted)
|
||||
|
||||
response = await a2a_agent.run("Submit task", background=True)
|
||||
|
||||
assert response.continuation_token is not None
|
||||
assert response.continuation_token["task_id"] == "task-sub"
|
||||
|
||||
|
||||
async def test_input_required_task_emits_continuation_token(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Test that an input_required task yields a continuation token when background=True."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-input", state=TaskState.input_required)
|
||||
|
||||
response = await a2a_agent.run("Need input", background=True)
|
||||
|
||||
assert response.continuation_token is not None
|
||||
assert response.continuation_token["task_id"] == "task-input"
|
||||
|
||||
|
||||
async def test_working_task_no_token_without_background(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that background=False (default) does not emit continuation tokens for in-progress tasks."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-fg", context_id="ctx-fg", state=TaskState.working)
|
||||
|
||||
response = await a2a_agent.run("Foreground task")
|
||||
|
||||
assert response.continuation_token is None
|
||||
|
||||
|
||||
async def test_completed_task_has_no_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that a completed task does not set a continuation token."""
|
||||
mock_a2a_client.add_task_response("task-done", [{"id": "art-1", "content": "Result"}])
|
||||
|
||||
response = await a2a_agent.run("Quick task")
|
||||
|
||||
assert response.continuation_token is None
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Result"
|
||||
|
||||
|
||||
async def test_streaming_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that streaming with background=True yields updates with continuation tokens."""
|
||||
mock_a2a_client.add_in_progress_task_response("task-stream", context_id="ctx-s", state=TaskState.working)
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in a2a_agent.run("Stream task", stream=True, background=True):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 1
|
||||
assert updates[0].continuation_token is not None
|
||||
assert updates[0].continuation_token["task_id"] == "task-stream"
|
||||
assert updates[0].continuation_token["context_id"] == "ctx-s"
|
||||
|
||||
|
||||
async def test_resume_via_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that run() with continuation_token uses resubscribe instead of send_message."""
|
||||
# Set up the resubscribe response (completed task)
|
||||
status = TaskStatus(state=TaskState.completed, message=None)
|
||||
artifact = Artifact(
|
||||
artifact_id="art-resume",
|
||||
name="result",
|
||||
parts=[Part(root=TextPart(text="Resumed result"))],
|
||||
)
|
||||
task = Task(id="task-resume", context_id="ctx-r", status=status, artifacts=[artifact])
|
||||
mock_a2a_client.resubscribe_responses.append((task, None))
|
||||
|
||||
token = A2AContinuationToken(task_id="task-resume", context_id="ctx-r")
|
||||
response = await a2a_agent.run(continuation_token=token)
|
||||
|
||||
assert isinstance(response, AgentResponse)
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Resumed result"
|
||||
assert response.continuation_token is None
|
||||
|
||||
|
||||
async def test_resume_streaming_via_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test that streaming run() with continuation_token and background=True uses resubscribe."""
|
||||
# Still working
|
||||
status_wip = TaskStatus(state=TaskState.working, message=None)
|
||||
task_wip = Task(id="task-rs", context_id="ctx-rs", status=status_wip)
|
||||
# Then completed
|
||||
status_done = TaskStatus(state=TaskState.completed, message=None)
|
||||
artifact = Artifact(
|
||||
artifact_id="art-rs",
|
||||
name="result",
|
||||
parts=[Part(root=TextPart(text="Stream resumed"))],
|
||||
)
|
||||
task_done = Task(id="task-rs", context_id="ctx-rs", status=status_done, artifacts=[artifact])
|
||||
mock_a2a_client.resubscribe_responses.extend([(task_wip, None), (task_done, None)])
|
||||
|
||||
token = A2AContinuationToken(task_id="task-rs", context_id="ctx-rs")
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
async for update in a2a_agent.run(stream=True, continuation_token=token, background=True):
|
||||
updates.append(update)
|
||||
|
||||
# First update: in-progress with token, second: completed with content
|
||||
assert len(updates) == 2
|
||||
assert updates[0].continuation_token is not None
|
||||
assert updates[0].continuation_token["task_id"] == "task-rs"
|
||||
assert updates[1].continuation_token is None
|
||||
assert updates[1].contents[0].text == "Stream resumed"
|
||||
|
||||
|
||||
async def test_poll_task_in_progress(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test poll_task returns continuation token when task is still in progress."""
|
||||
status = TaskStatus(state=TaskState.working, message=None)
|
||||
mock_a2a_client.get_task_response = Task(id="task-poll", context_id="ctx-p", status=status)
|
||||
|
||||
token = A2AContinuationToken(task_id="task-poll", context_id="ctx-p")
|
||||
response = await a2a_agent.poll_task(token)
|
||||
|
||||
assert response.continuation_token is not None
|
||||
assert response.continuation_token["task_id"] == "task-poll"
|
||||
|
||||
|
||||
async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None:
|
||||
"""Test poll_task returns result with no continuation token when task is complete."""
|
||||
status = TaskStatus(state=TaskState.completed, message=None)
|
||||
artifact = Artifact(
|
||||
artifact_id="art-poll",
|
||||
name="result",
|
||||
parts=[Part(root=TextPart(text="Poll result"))],
|
||||
)
|
||||
mock_a2a_client.get_task_response = Task(
|
||||
id="task-poll-done", context_id="ctx-pd", status=status, artifacts=[artifact]
|
||||
)
|
||||
|
||||
token = A2AContinuationToken(task_id="task-poll-done", context_id="ctx-pd")
|
||||
response = await a2a_agent.poll_task(token)
|
||||
|
||||
assert response.continuation_token is None
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Poll result"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -165,7 +165,7 @@ class _RunContext(TypedDict):
|
||||
finalize_kwargs: dict[str, Any]
|
||||
|
||||
|
||||
__all__ = ["BareAgent", "BaseAgent", "ChatAgent", "RawChatAgent", "SupportsAgentRun"]
|
||||
__all__ = ["BaseAgent", "ChatAgent", "RawChatAgent", "SupportsAgentRun"]
|
||||
|
||||
|
||||
# region Agent Protocol
|
||||
@@ -523,10 +523,6 @@ class BaseAgent(SerializationMixin):
|
||||
return agent_tool
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
BareAgent = BaseAgent
|
||||
|
||||
|
||||
# region ChatAgent
|
||||
|
||||
|
||||
@@ -908,6 +904,7 @@ class RawChatAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
usage_details=response.usage_details,
|
||||
value=response.value,
|
||||
response_format=response_format,
|
||||
continuation_token=response.continuation_token,
|
||||
raw_representation=response,
|
||||
additional_properties=response.additional_properties,
|
||||
)
|
||||
|
||||
@@ -36,6 +36,7 @@ __all__ = [
|
||||
"ChatResponse",
|
||||
"ChatResponseUpdate",
|
||||
"Content",
|
||||
"ContinuationToken",
|
||||
"FinalT",
|
||||
"FinishReason",
|
||||
"FinishReasonLiteral",
|
||||
@@ -1760,6 +1761,7 @@ def _process_update(response: ChatResponse | AgentResponse, update: ChatResponse
|
||||
response.finish_reason = update.finish_reason
|
||||
if update.model_id is not None:
|
||||
response.model_id = update.model_id
|
||||
response.continuation_token = update.continuation_token
|
||||
|
||||
|
||||
def _coalesce_text_content(contents: list[Content], type_str: Literal["text", "text_reasoning"]) -> None:
|
||||
@@ -1796,6 +1798,39 @@ def _finalize_response(response: ChatResponse | AgentResponse) -> None:
|
||||
_coalesce_text_content(msg.contents, "text_reasoning")
|
||||
|
||||
|
||||
# region ContinuationToken
|
||||
|
||||
|
||||
class ContinuationToken(TypedDict):
|
||||
"""Opaque token for resuming long-running agent operations.
|
||||
|
||||
A JSON-serializable dict used to poll for completion or resume a
|
||||
streaming response. Presence on a response indicates the operation
|
||||
is still in progress; ``None`` means the operation is complete.
|
||||
|
||||
Each provider subclasses this with its own fields; consumers should
|
||||
treat the token as opaque and simply pass it back to the same agent.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
import json
|
||||
|
||||
# Persist token across restarts
|
||||
token_json = json.dumps(response.continuation_token)
|
||||
|
||||
# Restore and resume
|
||||
token = json.loads(token_json)
|
||||
response = await agent.run(
|
||||
thread=thread,
|
||||
options={"continuation_token": token},
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
class ChatResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
"""Represents the response to a chat request.
|
||||
|
||||
@@ -1861,6 +1896,7 @@ class ChatResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
usage_details: UsageDetails | None = None,
|
||||
value: ResponseModelT | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
continuation_token: ContinuationToken | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
raw_representation: Any | None = None,
|
||||
) -> None:
|
||||
@@ -1876,6 +1912,8 @@ class ChatResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
usage_details: Optional usage details for the chat response.
|
||||
value: Optional value of the structured output.
|
||||
response_format: Optional response format for the chat response.
|
||||
continuation_token: Optional token for resuming a long-running background operation.
|
||||
When present, indicates the operation is still in progress.
|
||||
additional_properties: Optional additional properties associated with the chat response.
|
||||
raw_representation: Optional raw representation of the chat response from an underlying implementation.
|
||||
"""
|
||||
@@ -1907,6 +1945,7 @@ class ChatResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
self._response_format: type[BaseModel] | None = response_format
|
||||
self._value_parsed: bool = value is not None
|
||||
self.additional_properties = additional_properties or {}
|
||||
self.continuation_token = continuation_token
|
||||
self.raw_representation: Any | list[Any] | None = raw_representation
|
||||
|
||||
@overload
|
||||
@@ -2109,6 +2148,7 @@ class ChatResponseUpdate(SerializationMixin):
|
||||
model_id: str | None = None,
|
||||
created_at: CreatedAtT | None = None,
|
||||
finish_reason: FinishReasonLiteral | FinishReason | None = None,
|
||||
continuation_token: ContinuationToken | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
raw_representation: Any | None = None,
|
||||
) -> None:
|
||||
@@ -2124,6 +2164,8 @@ class ChatResponseUpdate(SerializationMixin):
|
||||
model_id: Optional model ID associated with this response update.
|
||||
created_at: Optional timestamp for the chat response update.
|
||||
finish_reason: Optional finish reason for the operation.
|
||||
continuation_token: Optional token for resuming a long-running background operation.
|
||||
When present, indicates the operation is still in progress.
|
||||
additional_properties: Optional additional properties associated with the chat response update.
|
||||
raw_representation: Optional raw representation of the chat response update
|
||||
from an underlying implementation.
|
||||
@@ -2151,6 +2193,7 @@ class ChatResponseUpdate(SerializationMixin):
|
||||
self.model_id = model_id
|
||||
self.created_at = created_at
|
||||
self.finish_reason = finish_reason
|
||||
self.continuation_token = continuation_token
|
||||
self.additional_properties = additional_properties
|
||||
self.raw_representation = raw_representation
|
||||
|
||||
@@ -2222,6 +2265,7 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
usage_details: UsageDetails | None = None,
|
||||
value: ResponseModelT | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
continuation_token: ContinuationToken | None = None,
|
||||
raw_representation: Any | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
@@ -2236,6 +2280,8 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
usage_details: The usage details for the chat response.
|
||||
value: The structured output of the agent run response, if applicable.
|
||||
response_format: Optional response format for the agent response.
|
||||
continuation_token: Optional token for resuming a long-running background operation.
|
||||
When present, indicates the operation is still in progress.
|
||||
additional_properties: Any additional properties associated with the chat response.
|
||||
raw_representation: The raw representation of the chat response from an underlying implementation.
|
||||
"""
|
||||
@@ -2262,6 +2308,7 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
self._response_format: type[BaseModel] | None = response_format
|
||||
self._value_parsed: bool = value is not None
|
||||
self.additional_properties = additional_properties or {}
|
||||
self.continuation_token = continuation_token
|
||||
self.raw_representation = raw_representation
|
||||
|
||||
@property
|
||||
@@ -2444,6 +2491,7 @@ class AgentResponseUpdate(SerializationMixin):
|
||||
response_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
created_at: CreatedAtT | None = None,
|
||||
continuation_token: ContinuationToken | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
raw_representation: Any | None = None,
|
||||
) -> None:
|
||||
@@ -2458,6 +2506,8 @@ class AgentResponseUpdate(SerializationMixin):
|
||||
response_id: Optional ID of the response of which this update is a part.
|
||||
message_id: Optional ID of the message of which this update is a part.
|
||||
created_at: Optional timestamp for the chat response update.
|
||||
continuation_token: Optional token for resuming a long-running background operation.
|
||||
When present, indicates the operation is still in progress.
|
||||
additional_properties: Optional additional properties associated with the chat response update.
|
||||
raw_representation: Optional raw representation of the chat response update.
|
||||
|
||||
@@ -2486,6 +2536,7 @@ class AgentResponseUpdate(SerializationMixin):
|
||||
self.response_id = response_id
|
||||
self.message_id = message_id
|
||||
self.created_at = created_at
|
||||
self.continuation_token = continuation_token
|
||||
self.additional_properties = additional_properties
|
||||
self.raw_representation: Any | list[Any] | None = raw_representation
|
||||
|
||||
@@ -2514,6 +2565,7 @@ def map_chat_to_agent_update(update: ChatResponseUpdate, agent_name: str | None)
|
||||
response_id=update.response_id,
|
||||
message_id=update.message_id,
|
||||
created_at=update.created_at,
|
||||
continuation_token=update.continuation_token,
|
||||
additional_properties=update.additional_properties,
|
||||
raw_representation=update,
|
||||
)
|
||||
|
||||
@@ -1139,8 +1139,14 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
else:
|
||||
raise RuntimeError("Streaming telemetry requires a ResponseStream result.")
|
||||
|
||||
span_cm = _get_span(attributes=attributes, span_name_attribute=SpanAttributes.LLM_REQUEST_MODEL)
|
||||
span = span_cm.__enter__()
|
||||
# Create span directly without trace.use_span() context attachment.
|
||||
# Streaming spans are closed asynchronously in cleanup hooks, which run
|
||||
# in a different async context than creation — using use_span() would
|
||||
# cause "Failed to detach context" errors from OpenTelemetry.
|
||||
operation = attributes.get(OtelAttr.OPERATION, "operation")
|
||||
span_name = attributes.get(SpanAttributes.LLM_REQUEST_MODEL, "unknown")
|
||||
span = get_tracer().start_span(f"{operation} {span_name}")
|
||||
span.set_attributes(attributes)
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
|
||||
_capture_messages(
|
||||
span=span,
|
||||
@@ -1157,7 +1163,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
if span_state["closed"]:
|
||||
return
|
||||
span_state["closed"] = True
|
||||
span_cm.__exit__(None, None, None)
|
||||
span.end()
|
||||
|
||||
def _record_duration() -> None:
|
||||
duration_state["duration"] = perf_counter() - start_time
|
||||
@@ -1326,8 +1332,14 @@ class AgentTelemetryLayer:
|
||||
else:
|
||||
raise RuntimeError("Streaming telemetry requires a ResponseStream result.")
|
||||
|
||||
span_cm = _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME)
|
||||
span = span_cm.__enter__()
|
||||
# Create span directly without trace.use_span() context attachment.
|
||||
# Streaming spans are closed asynchronously in cleanup hooks, which run
|
||||
# in a different async context than creation — using use_span() would
|
||||
# cause "Failed to detach context" errors from OpenTelemetry.
|
||||
operation = attributes.get(OtelAttr.OPERATION, "operation")
|
||||
span_name = attributes.get(OtelAttr.AGENT_NAME, "unknown")
|
||||
span = get_tracer().start_span(f"{operation} {span_name}")
|
||||
span.set_attributes(attributes)
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
|
||||
_capture_messages(
|
||||
span=span,
|
||||
@@ -1344,7 +1356,7 @@ class AgentTelemetryLayer:
|
||||
if span_state["closed"]:
|
||||
return
|
||||
span_state["closed"] = True
|
||||
span_cm.__exit__(None, None, None)
|
||||
span.end()
|
||||
|
||||
def _record_duration() -> None:
|
||||
duration_state["duration"] = perf_counter() - start_time
|
||||
|
||||
@@ -56,6 +56,7 @@ from .._types import (
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
ContinuationToken,
|
||||
ResponseStream,
|
||||
Role,
|
||||
TextSpanRegion,
|
||||
@@ -98,7 +99,14 @@ if TYPE_CHECKING:
|
||||
logger = get_logger("agent_framework.openai")
|
||||
|
||||
|
||||
__all__ = ["OpenAIResponsesClient", "OpenAIResponsesOptions", "RawOpenAIResponsesClient"]
|
||||
__all__ = ["OpenAIContinuationToken", "OpenAIResponsesClient", "OpenAIResponsesOptions", "RawOpenAIResponsesClient"]
|
||||
|
||||
|
||||
class OpenAIContinuationToken(ContinuationToken):
|
||||
"""Continuation token for OpenAI Responses API background operations."""
|
||||
|
||||
response_id: str
|
||||
"""OpenAI Responses API response ID."""
|
||||
|
||||
|
||||
# region OpenAI Responses Options TypedDict
|
||||
@@ -190,6 +198,17 @@ class OpenAIResponsesOptions(ChatOptions[ResponseFormatT], Generic[ResponseForma
|
||||
- 'auto': Truncate from beginning if exceeds context
|
||||
- 'disabled': Fail with 400 error if exceeds context"""
|
||||
|
||||
background: bool
|
||||
"""Whether to run the model response in the background.
|
||||
When True, the response returns immediately with a continuation token
|
||||
that can be used to poll for the result.
|
||||
See: https://platform.openai.com/docs/guides/background"""
|
||||
|
||||
continuation_token: OpenAIContinuationToken
|
||||
"""Token for resuming or polling a long-running background operation.
|
||||
Pass the ``continuation_token`` from a previous response to poll for
|
||||
completion or resume a streaming response."""
|
||||
|
||||
|
||||
OpenAIResponsesOptionsT = TypeVar(
|
||||
"OpenAIResponsesOptionsT",
|
||||
@@ -266,33 +285,60 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
stream: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
|
||||
continuation_token: OpenAIContinuationToken | None = options.get("continuation_token") # type: ignore[assignment]
|
||||
|
||||
if stream:
|
||||
function_call_ids: dict[int, tuple[str, str]] = {}
|
||||
validated_options: dict[str, Any] | None = None
|
||||
|
||||
async def _stream() -> AsyncIterable[ChatResponseUpdate]:
|
||||
nonlocal validated_options
|
||||
client, run_options, validated_options = await self._prepare_request(messages, options, **kwargs)
|
||||
try:
|
||||
if "text_format" in run_options:
|
||||
async with client.responses.stream(**run_options) as response:
|
||||
async for chunk in response:
|
||||
yield self._parse_chunk_from_openai(
|
||||
chunk, options=validated_options, function_call_ids=function_call_ids
|
||||
)
|
||||
else:
|
||||
async for chunk in await client.responses.create(stream=True, **run_options):
|
||||
if continuation_token is not None:
|
||||
# Resume a background streaming response by retrieving with stream=True
|
||||
client = await self._ensure_client()
|
||||
validated_options = await self._validate_options(options)
|
||||
try:
|
||||
stream_response = await client.responses.retrieve(
|
||||
continuation_token["response_id"],
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in stream_response:
|
||||
yield self._parse_chunk_from_openai(
|
||||
chunk, options=validated_options, function_call_ids=function_call_ids
|
||||
)
|
||||
except Exception as ex:
|
||||
self._handle_request_error(ex)
|
||||
except Exception as ex:
|
||||
self._handle_request_error(ex)
|
||||
else:
|
||||
client, run_options, validated_options = await self._prepare_request(messages, options, **kwargs)
|
||||
try:
|
||||
if "text_format" in run_options:
|
||||
async with client.responses.stream(**run_options) as response:
|
||||
async for chunk in response:
|
||||
yield self._parse_chunk_from_openai(
|
||||
chunk, options=validated_options, function_call_ids=function_call_ids
|
||||
)
|
||||
else:
|
||||
async for chunk in await client.responses.create(stream=True, **run_options):
|
||||
yield self._parse_chunk_from_openai(
|
||||
chunk, options=validated_options, function_call_ids=function_call_ids
|
||||
)
|
||||
except Exception as ex:
|
||||
self._handle_request_error(ex)
|
||||
|
||||
response_format = validated_options.get("response_format") if validated_options else None
|
||||
return self._build_response_stream(_stream(), response_format=response_format)
|
||||
|
||||
# Non-streaming
|
||||
async def _get_response() -> ChatResponse:
|
||||
if continuation_token is not None:
|
||||
# Poll a background response by retrieving without stream
|
||||
client = await self._ensure_client()
|
||||
validated_options = await self._validate_options(options)
|
||||
try:
|
||||
response = await client.responses.retrieve(continuation_token["response_id"])
|
||||
except Exception as ex:
|
||||
self._handle_request_error(ex)
|
||||
return self._parse_response_from_openai(response, options=validated_options)
|
||||
client, run_options, validated_options = await self._prepare_request(messages, options, **kwargs)
|
||||
try:
|
||||
if "text_format" in run_options:
|
||||
@@ -538,6 +584,7 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
"response_format", # handled separately
|
||||
"conversation_id", # handled separately
|
||||
"tool_choice", # handled separately
|
||||
"continuation_token", # handled separately in _inner_get_response
|
||||
}
|
||||
run_options: dict[str, Any] = {k: v for k, v in options.items() if k not in exclude_keys and v is not None}
|
||||
|
||||
@@ -1070,6 +1117,9 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
# Only pass response_format to ChatResponse if it's a Pydantic model type,
|
||||
# not a runtime JSON schema dict
|
||||
args["response_format"] = response_format
|
||||
# Set continuation_token when background operation is still in progress
|
||||
if response.status and response.status in ("in_progress", "queued"):
|
||||
args["continuation_token"] = OpenAIContinuationToken(response_id=response.id)
|
||||
return ChatResponse(**args)
|
||||
|
||||
def _parse_chunk_from_openai(
|
||||
@@ -1083,6 +1133,7 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
contents: list[Content] = []
|
||||
conversation_id: str | None = None
|
||||
response_id: str | None = None
|
||||
continuation_token: OpenAIContinuationToken | None = None
|
||||
model = self.model_id
|
||||
match event.type:
|
||||
# types:
|
||||
@@ -1211,9 +1262,12 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
case "response.created":
|
||||
response_id = event.response.id
|
||||
conversation_id = self._get_conversation_id(event.response, options.get("store"))
|
||||
if event.response.status and event.response.status in ("in_progress", "queued"):
|
||||
continuation_token = OpenAIContinuationToken(response_id=event.response.id)
|
||||
case "response.in_progress":
|
||||
response_id = event.response.id
|
||||
conversation_id = self._get_conversation_id(event.response, options.get("store"))
|
||||
continuation_token = OpenAIContinuationToken(response_id=event.response.id)
|
||||
case "response.completed":
|
||||
response_id = event.response.id
|
||||
conversation_id = self._get_conversation_id(event.response, options.get("store"))
|
||||
@@ -1454,6 +1508,7 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
response_id=response_id,
|
||||
role="assistant",
|
||||
model_id=model,
|
||||
continuation_token=continuation_token,
|
||||
additional_properties=metadata,
|
||||
raw_representation=event,
|
||||
)
|
||||
|
||||
@@ -2434,3 +2434,263 @@ async def test_integration_streaming_file_search() -> None:
|
||||
|
||||
assert "sunny" in full_message.lower()
|
||||
assert "75" in full_message
|
||||
|
||||
|
||||
# region Background Response / ContinuationToken Tests
|
||||
|
||||
|
||||
def test_continuation_token_json_serializable() -> None:
|
||||
"""Test that OpenAIContinuationToken is a plain dict and JSON-serializable."""
|
||||
from agent_framework.openai import OpenAIContinuationToken
|
||||
|
||||
token = OpenAIContinuationToken(response_id="resp_abc123")
|
||||
assert token["response_id"] == "resp_abc123"
|
||||
|
||||
# JSON round-trip
|
||||
serialized = json.dumps(token)
|
||||
restored = json.loads(serialized)
|
||||
assert restored["response_id"] == "resp_abc123"
|
||||
|
||||
|
||||
def test_chat_response_with_continuation_token() -> None:
|
||||
"""Test that ChatResponse accepts and stores continuation_token."""
|
||||
from agent_framework.openai import OpenAIContinuationToken
|
||||
|
||||
token = OpenAIContinuationToken(response_id="resp_123")
|
||||
response = ChatResponse(
|
||||
messages=ChatMessage(role="assistant", contents=[Content.from_text(text="Hello")]),
|
||||
response_id="resp_123",
|
||||
continuation_token=token,
|
||||
)
|
||||
assert response.continuation_token is not None
|
||||
assert response.continuation_token["response_id"] == "resp_123"
|
||||
|
||||
|
||||
def test_chat_response_without_continuation_token() -> None:
|
||||
"""Test that ChatResponse defaults continuation_token to None."""
|
||||
response = ChatResponse(
|
||||
messages=ChatMessage(role="assistant", contents=[Content.from_text(text="Hello")]),
|
||||
)
|
||||
assert response.continuation_token is None
|
||||
|
||||
|
||||
def test_chat_response_update_with_continuation_token() -> None:
|
||||
"""Test that ChatResponseUpdate accepts and stores continuation_token."""
|
||||
from agent_framework.openai import OpenAIContinuationToken
|
||||
|
||||
token = OpenAIContinuationToken(response_id="resp_456")
|
||||
update = ChatResponseUpdate(
|
||||
contents=[Content.from_text(text="chunk")],
|
||||
role="assistant",
|
||||
continuation_token=token,
|
||||
)
|
||||
assert update.continuation_token is not None
|
||||
assert update.continuation_token["response_id"] == "resp_456"
|
||||
|
||||
|
||||
def test_agent_response_with_continuation_token() -> None:
|
||||
"""Test that AgentResponse accepts and stores continuation_token."""
|
||||
from agent_framework import AgentResponse
|
||||
from agent_framework.openai import OpenAIContinuationToken
|
||||
|
||||
token = OpenAIContinuationToken(response_id="resp_789")
|
||||
response = AgentResponse(
|
||||
messages=ChatMessage(role="assistant", contents=[Content.from_text(text="done")]),
|
||||
continuation_token=token,
|
||||
)
|
||||
assert response.continuation_token is not None
|
||||
assert response.continuation_token["response_id"] == "resp_789"
|
||||
|
||||
|
||||
def test_agent_response_update_with_continuation_token() -> None:
|
||||
"""Test that AgentResponseUpdate accepts and stores continuation_token."""
|
||||
from agent_framework import AgentResponseUpdate
|
||||
from agent_framework.openai import OpenAIContinuationToken
|
||||
|
||||
token = OpenAIContinuationToken(response_id="resp_012")
|
||||
update = AgentResponseUpdate(
|
||||
contents=[Content.from_text(text="streaming")],
|
||||
role="assistant",
|
||||
continuation_token=token,
|
||||
)
|
||||
assert update.continuation_token is not None
|
||||
assert update.continuation_token["response_id"] == "resp_012"
|
||||
|
||||
|
||||
def test_parse_response_from_openai_with_background_in_progress() -> None:
|
||||
"""Test that _parse_response_from_openai sets continuation_token when status is in_progress."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.output_parsed = None
|
||||
mock_response.metadata = {}
|
||||
mock_response.usage = None
|
||||
mock_response.id = "resp_bg_123"
|
||||
mock_response.model = "test-model"
|
||||
mock_response.created_at = 1000000000
|
||||
mock_response.status = "in_progress"
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.type = "message"
|
||||
mock_message.content = []
|
||||
mock_response.output = [mock_message]
|
||||
|
||||
options: dict[str, Any] = {"store": False}
|
||||
result = client._parse_response_from_openai(mock_response, options=options)
|
||||
|
||||
assert result.continuation_token is not None
|
||||
assert result.continuation_token["response_id"] == "resp_bg_123"
|
||||
|
||||
|
||||
def test_parse_response_from_openai_with_background_queued() -> None:
|
||||
"""Test that _parse_response_from_openai sets continuation_token when status is queued."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.output_parsed = None
|
||||
mock_response.metadata = {}
|
||||
mock_response.usage = None
|
||||
mock_response.id = "resp_bg_456"
|
||||
mock_response.model = "test-model"
|
||||
mock_response.created_at = 1000000000
|
||||
mock_response.status = "queued"
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.type = "message"
|
||||
mock_message.content = []
|
||||
mock_response.output = [mock_message]
|
||||
|
||||
options: dict[str, Any] = {"store": False}
|
||||
result = client._parse_response_from_openai(mock_response, options=options)
|
||||
|
||||
assert result.continuation_token is not None
|
||||
assert result.continuation_token["response_id"] == "resp_bg_456"
|
||||
|
||||
|
||||
def test_parse_response_from_openai_with_background_completed() -> None:
|
||||
"""Test that _parse_response_from_openai does NOT set continuation_token when status is completed."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.output_parsed = None
|
||||
mock_response.metadata = {}
|
||||
mock_response.usage = None
|
||||
mock_response.id = "resp_bg_789"
|
||||
mock_response.model = "test-model"
|
||||
mock_response.created_at = 1000000000
|
||||
mock_response.status = "completed"
|
||||
|
||||
mock_text_content = MagicMock()
|
||||
mock_text_content.type = "output_text"
|
||||
mock_text_content.text = "Final answer"
|
||||
mock_text_content.annotations = []
|
||||
mock_text_content.logprobs = None
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.type = "message"
|
||||
mock_message.content = [mock_text_content]
|
||||
mock_response.output = [mock_message]
|
||||
|
||||
options: dict[str, Any] = {"store": False}
|
||||
result = client._parse_response_from_openai(mock_response, options=options)
|
||||
|
||||
assert result.continuation_token is None
|
||||
|
||||
|
||||
def test_streaming_response_in_progress_sets_continuation_token() -> None:
|
||||
"""Test that _parse_chunk_from_openai sets continuation_token for in_progress events."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
chat_options: dict[str, Any] = {}
|
||||
function_call_ids: dict[int, tuple[str, str]] = {}
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.type = "response.in_progress"
|
||||
mock_event.response = MagicMock()
|
||||
mock_event.response.id = "resp_stream_123"
|
||||
mock_event.response.conversation = MagicMock()
|
||||
mock_event.response.conversation.id = "conv_456"
|
||||
mock_event.response.status = "in_progress"
|
||||
|
||||
update = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids)
|
||||
|
||||
assert update.continuation_token is not None
|
||||
assert update.continuation_token["response_id"] == "resp_stream_123"
|
||||
|
||||
|
||||
def test_streaming_response_created_with_in_progress_status_sets_continuation_token() -> None:
|
||||
"""Test that response.created with in_progress status sets continuation_token."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
chat_options: dict[str, Any] = {}
|
||||
function_call_ids: dict[int, tuple[str, str]] = {}
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.type = "response.created"
|
||||
mock_event.response = MagicMock()
|
||||
mock_event.response.id = "resp_created_123"
|
||||
mock_event.response.conversation = MagicMock()
|
||||
mock_event.response.conversation.id = "conv_789"
|
||||
mock_event.response.status = "in_progress"
|
||||
|
||||
update = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids)
|
||||
|
||||
assert update.continuation_token is not None
|
||||
assert update.continuation_token["response_id"] == "resp_created_123"
|
||||
|
||||
|
||||
def test_streaming_response_completed_no_continuation_token() -> None:
|
||||
"""Test that response.completed does NOT set continuation_token."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
chat_options: dict[str, Any] = {}
|
||||
function_call_ids: dict[int, tuple[str, str]] = {}
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.type = "response.completed"
|
||||
mock_event.response = MagicMock()
|
||||
mock_event.response.id = "resp_done_123"
|
||||
mock_event.response.conversation = MagicMock()
|
||||
mock_event.response.conversation.id = "conv_done"
|
||||
mock_event.response.model = "test-model"
|
||||
mock_event.response.usage = None
|
||||
|
||||
update = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids)
|
||||
|
||||
assert update.continuation_token is None
|
||||
|
||||
|
||||
def test_map_chat_to_agent_update_preserves_continuation_token() -> None:
|
||||
"""Test that map_chat_to_agent_update propagates continuation_token."""
|
||||
from agent_framework._types import map_chat_to_agent_update
|
||||
|
||||
token = {"response_id": "resp_map_123"}
|
||||
chat_update = ChatResponseUpdate(
|
||||
contents=[Content.from_text(text="chunk")],
|
||||
role="assistant",
|
||||
response_id="resp_map_123",
|
||||
continuation_token=token,
|
||||
)
|
||||
|
||||
agent_update = map_chat_to_agent_update(chat_update, agent_name="test-agent")
|
||||
|
||||
assert agent_update.continuation_token is not None
|
||||
assert agent_update.continuation_token["response_id"] == "resp_map_123"
|
||||
|
||||
|
||||
async def test_prepare_options_excludes_continuation_token() -> None:
|
||||
"""Test that _prepare_options does not pass continuation_token to OpenAI API."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
|
||||
messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])]
|
||||
options: dict[str, Any] = {
|
||||
"model_id": "test-model",
|
||||
"continuation_token": {"response_id": "resp_123"},
|
||||
"background": True,
|
||||
}
|
||||
|
||||
run_options = await client._prepare_options(messages, options)
|
||||
|
||||
assert "continuation_token" not in run_options
|
||||
assert "background" in run_options
|
||||
assert run_options["background"] is True
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
|
||||
"""Background Responses Sample.
|
||||
|
||||
This sample demonstrates long-running agent operations using the OpenAI
|
||||
Responses API ``background`` option. Two patterns are shown:
|
||||
|
||||
1. **Non-streaming polling** – start a background run, then poll with the
|
||||
``continuation_token`` until the operation completes.
|
||||
2. **Streaming with resumption** – start a background streaming run, simulate
|
||||
an interruption, and resume from the last ``continuation_token``.
|
||||
|
||||
Prerequisites:
|
||||
- Set the ``OPENAI_API_KEY`` environment variable.
|
||||
- A model that benefits from background execution (e.g. ``o3``).
|
||||
"""
|
||||
|
||||
|
||||
# 1. Create the agent with an OpenAI Responses client.
|
||||
agent = ChatAgent(
|
||||
name="researcher",
|
||||
instructions="You are a helpful research assistant. Be concise.",
|
||||
chat_client=OpenAIResponsesClient(model_id="o3"),
|
||||
)
|
||||
|
||||
|
||||
async def non_streaming_polling() -> None:
|
||||
"""Demonstrate non-streaming background run with polling."""
|
||||
print("=== Non-Streaming Polling ===\n")
|
||||
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# 2. Start a background run — returns immediately.
|
||||
response = await agent.run(
|
||||
messages="Briefly explain the theory of relativity in two sentences.",
|
||||
thread=thread,
|
||||
options={"background": True},
|
||||
)
|
||||
|
||||
print(f"Initial status: continuation_token={'set' if response.continuation_token else 'None'}")
|
||||
|
||||
# 3. Poll until the operation completes.
|
||||
poll_count = 0
|
||||
while response.continuation_token is not None:
|
||||
poll_count += 1
|
||||
await asyncio.sleep(2)
|
||||
response = await agent.run(
|
||||
thread=thread,
|
||||
options={"continuation_token": response.continuation_token},
|
||||
)
|
||||
print(f" Poll {poll_count}: continuation_token={'set' if response.continuation_token else 'None'}")
|
||||
|
||||
# 4. Done — print the final result.
|
||||
print(f"\nResult ({poll_count} poll(s)):\n{response.text}\n")
|
||||
|
||||
|
||||
async def streaming_with_resumption() -> None:
|
||||
"""Demonstrate streaming background run with simulated interruption and resumption."""
|
||||
print("=== Streaming with Resumption ===\n")
|
||||
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# 2. Start a streaming background run.
|
||||
last_token = None
|
||||
stream = agent.run(
|
||||
messages="Briefly list three benefits of exercise.",
|
||||
stream=True,
|
||||
thread=thread,
|
||||
options={"background": True},
|
||||
)
|
||||
|
||||
# 3. Read some chunks, then simulate an interruption.
|
||||
chunk_count = 0
|
||||
print("First stream (before interruption):")
|
||||
async for update in stream:
|
||||
last_token = update.continuation_token
|
||||
if update.text:
|
||||
print(update.text, end="", flush=True)
|
||||
chunk_count += 1
|
||||
if chunk_count >= 3:
|
||||
print("\n [simulated interruption]")
|
||||
break
|
||||
|
||||
# 4. Resume from the last continuation token.
|
||||
if last_token is not None:
|
||||
print("Resumed stream:")
|
||||
stream = agent.run(
|
||||
stream=True,
|
||||
thread=thread,
|
||||
options={"continuation_token": last_token},
|
||||
)
|
||||
async for update in stream:
|
||||
if update.text:
|
||||
print(update.text, end="", flush=True)
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
await non_streaming_polling()
|
||||
await streaming_with_resumption()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
|
||||
=== Non-Streaming Polling ===
|
||||
|
||||
Initial status: continuation_token=set
|
||||
Poll 1: continuation_token=set
|
||||
Poll 2: continuation_token=None
|
||||
|
||||
Result (2 poll(s)):
|
||||
The theory of relativity, developed by Albert Einstein, consists of special
|
||||
relativity (1905), which shows that the laws of physics are the same for all
|
||||
non-accelerating observers and that the speed of light is constant, and general
|
||||
relativity (1915), which describes gravity as the curvature of spacetime caused
|
||||
by mass and energy.
|
||||
|
||||
=== Streaming with Resumption ===
|
||||
|
||||
First stream (before interruption):
|
||||
Here are three
|
||||
[simulated interruption]
|
||||
Resumed stream:
|
||||
key benefits of regular exercise:
|
||||
|
||||
1. **Improved cardiovascular health** ...
|
||||
2. **Better mental health** ...
|
||||
3. **Stronger muscles and bones** ...
|
||||
"""
|
||||
@@ -2,12 +2,15 @@
|
||||
|
||||
This folder contains examples demonstrating how to create and use agents with the A2A (Agent2Agent) protocol from the `agent_framework` package to communicate with remote A2A agents.
|
||||
|
||||
By default the A2AAgent waits for the remote agent to finish before returning (`background=False`), so long-running A2A tasks are handled transparently. For advanced scenarios where you need to poll or resubscribe to in-progress tasks using continuation tokens, see the [background responses sample](../../../concepts/background_responses.py).
|
||||
|
||||
For more information about the A2A protocol specification, visit: https://a2a-protocol.org/latest/
|
||||
|
||||
## Examples
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| [`agent_with_a2a.py`](agent_with_a2a.py) | The simplest way to connect to and use a single A2A agent. Demonstrates agent discovery via agent cards and basic message exchange using the A2A protocol. |
|
||||
| [`agent_with_a2a.py`](agent_with_a2a.py) | Demonstrates agent discovery, non-streaming and streaming responses using the A2A protocol. |
|
||||
|
||||
## Environment Variables
|
||||
|
||||
|
||||
@@ -15,13 +15,18 @@ the A2A protocol. A2A is a standardized communication protocol that enables inte
|
||||
between different agent systems, allowing agents built with different frameworks and
|
||||
technologies to communicate seamlessly.
|
||||
|
||||
By default the A2AAgent waits for the remote agent to finish before returning (background=False).
|
||||
This means long-running A2A tasks are handled transparently — the caller simply awaits the result.
|
||||
For advanced scenarios where you need to poll or resubscribe to in-progress tasks, see the
|
||||
background_responses sample: samples/concepts/background_responses.py
|
||||
|
||||
For more information about the A2A protocol specification, visit: https://a2a-protocol.org/latest/
|
||||
|
||||
Key concepts demonstrated:
|
||||
- Discovering A2A-compliant agents using AgentCard resolution
|
||||
- Creating A2AAgent instances to wrap external A2A endpoints
|
||||
- Converting Agent Framework messages to A2A protocol format
|
||||
- Handling A2A responses (Messages and Tasks) back to framework types
|
||||
- Non-streaming request/response
|
||||
- Streaming responses to receive incremental updates via SSE
|
||||
|
||||
To run this sample:
|
||||
1. Set the A2A_AGENT_HOST environment variable to point to an A2A-compliant agent endpoint
|
||||
@@ -29,50 +34,75 @@ To run this sample:
|
||||
2. Ensure the target agent exposes its AgentCard at /.well-known/agent.json
|
||||
3. Run: uv run python agent_with_a2a.py
|
||||
|
||||
The sample will:
|
||||
- Connect to the specified A2A agent endpoint
|
||||
- Retrieve and parse the agent's capabilities via its AgentCard
|
||||
- Send a message using the A2A protocol
|
||||
- Display the agent's response
|
||||
|
||||
Visit the README.md for more details on setting up and running A2A agents.
|
||||
"""
|
||||
|
||||
|
||||
async def main():
|
||||
"""Demonstrates connecting to and communicating with an A2A-compliant agent."""
|
||||
# Get A2A agent host from environment
|
||||
# 1. Get A2A agent host from environment.
|
||||
a2a_agent_host = os.getenv("A2A_AGENT_HOST")
|
||||
if not a2a_agent_host:
|
||||
raise ValueError("A2A_AGENT_HOST environment variable is not set")
|
||||
|
||||
print(f"Connecting to A2A agent at: {a2a_agent_host}")
|
||||
|
||||
# Initialize A2ACardResolver
|
||||
# 2. Resolve the agent card to discover capabilities.
|
||||
async with httpx.AsyncClient(timeout=60.0) as http_client:
|
||||
resolver = A2ACardResolver(httpx_client=http_client, base_url=a2a_agent_host)
|
||||
|
||||
# Get agent card
|
||||
agent_card = await resolver.get_agent_card()
|
||||
print(f"Found agent: {agent_card.name} - {agent_card.description}")
|
||||
|
||||
# Create A2A agent instance
|
||||
agent = A2AAgent(
|
||||
name=agent_card.name,
|
||||
description=agent_card.description,
|
||||
agent_card=agent_card,
|
||||
url=a2a_agent_host,
|
||||
)
|
||||
|
||||
# Invoke the agent and output the result
|
||||
print("\nSending message to A2A agent...")
|
||||
# 3. Create A2A agent instance.
|
||||
async with A2AAgent(
|
||||
name=agent_card.name,
|
||||
description=agent_card.description,
|
||||
agent_card=agent_card,
|
||||
url=a2a_agent_host,
|
||||
) as agent:
|
||||
# 4. Simple request/response — the agent waits for completion internally.
|
||||
# Even if the remote agent takes a while, background=False (the default)
|
||||
# means the call blocks until a terminal state is reached.
|
||||
print("\n--- Non-streaming response ---")
|
||||
response = await agent.run("What are your capabilities?")
|
||||
|
||||
# Print the response
|
||||
print("\nAgent Response:")
|
||||
print("Agent Response:")
|
||||
for message in response.messages:
|
||||
print(message.text)
|
||||
print(f" {message.text}")
|
||||
|
||||
# 5. Stream a response — the natural model for A2A.
|
||||
# Updates arrive as Server-Sent Events, letting you observe
|
||||
# progress in real time as the remote agent works.
|
||||
print("\n--- Streaming response ---")
|
||||
async with agent.run("Tell me about yourself", stream=True) as stream:
|
||||
async for update in stream:
|
||||
for content in update.contents:
|
||||
if content.text:
|
||||
print(f" {content.text}")
|
||||
|
||||
response = await stream.get_final_response()
|
||||
print(f"\nFinal response ({len(response.messages)} message(s)):")
|
||||
for message in response.messages:
|
||||
print(f" {message.text}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
|
||||
Connecting to A2A agent at: http://localhost:5001/
|
||||
Found agent: MyAgent - A helpful AI assistant
|
||||
|
||||
--- Non-streaming response ---
|
||||
Agent Response:
|
||||
I can help with code generation, analysis, and general Q&A.
|
||||
|
||||
--- Streaming response ---
|
||||
I am an AI assistant built to help with various tasks.
|
||||
|
||||
Final response (1 message(s)):
|
||||
I am an AI assistant built to help with various tasks.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user