mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Persist hosted MCP call/results as canonical mcp_call output (#6070)
* Persist hosted MCP call/results as canonical mcp_call output
  - Preserve hosted MCP call/result pairs as canonical mcp_call output items
  - Coalesce MCP call + result in the non-streaming conversion path
  - Keep call-id alignment for MCP tool call tracking and output mapping
  - Update tests and package metadata
* Fix missing Mapping import in hosted responses adapter
* Fix pyright unknown type in MCP output stringification
* Fix typing for MCP output sequence iteration
* Improve MCP output robustness and avoid eager flattening
* Bump foundry_hosting to b7 and update responses dependency to b7
* Restore foundry_hosting package version to 1.0.0a260521
* Refactor hosted MCP output parsing
This commit is contained in:
committed by
GitHub
Unverified
parent
05ebb966cf
commit
043208241a
@@ -9,7 +9,7 @@ import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Generator, Sequence
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from pathlib import Path
|
||||
@@ -472,14 +472,12 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
# Run the agent in non-streaming mode
|
||||
response = await self._agent.run(stream=False, **run_kwargs) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
for message in response.messages:
|
||||
for content in message.contents:
|
||||
async for item in _to_outputs(
|
||||
response_event_stream,
|
||||
content,
|
||||
approval_storage=self._approval_storage,
|
||||
):
|
||||
yield item
|
||||
async for item in _to_outputs_for_messages(
|
||||
response_event_stream,
|
||||
response.messages,
|
||||
approval_storage=self._approval_storage,
|
||||
):
|
||||
yield item
|
||||
yield response_event_stream.emit_completed()
|
||||
else:
|
||||
if tracker is None: # pragma: no cover - defensive, set above
|
||||
@@ -620,10 +618,8 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
checkpoint_storage=write_storage,
|
||||
)
|
||||
|
||||
for message in response.messages:
|
||||
for content in message.contents:
|
||||
async for item in _to_outputs(response_event_stream, content):
|
||||
yield item
|
||||
async for item in _to_outputs_for_messages(response_event_stream, response.messages):
|
||||
yield item
|
||||
|
||||
await self._delete_not_latest_checkpoints(write_storage, self._agent.workflow.name)
|
||||
yield response_event_stream.emit_completed()
|
||||
@@ -729,7 +725,7 @@ class _OutputItemTracker:
|
||||
yield self._fc_builder.emit_arguments_delta(args_str)
|
||||
|
||||
elif content.type == "mcp_server_tool_call" and content.tool_name:
|
||||
key = f"{content.server_name or 'default'}::{content.tool_name}"
|
||||
key = content.call_id or f"{content.server_name or 'default'}::{content.tool_name}"
|
||||
if self._active_type != "mcp_server_tool_call" or self._active_id != key:
|
||||
yield from self._close()
|
||||
yield from self._open_mcp_call(content)
|
||||
@@ -738,6 +734,24 @@ class _OutputItemTracker:
|
||||
if self._mcp_builder is not None:
|
||||
yield self._mcp_builder.emit_arguments_delta(args_str)
|
||||
|
||||
elif (
|
||||
content.type == "mcp_server_tool_result"
|
||||
and self._active_type == "mcp_server_tool_call"
|
||||
and self._mcp_builder is not None
|
||||
and content.call_id is not None
|
||||
and content.call_id == self._mcp_builder.item_id
|
||||
):
|
||||
accumulated = "".join(self._accumulated)
|
||||
yield self._mcp_builder.emit_arguments_done(accumulated)
|
||||
yield self._mcp_builder.emit_completed()
|
||||
yield self._mcp_builder.emit_done(output=_stringify_mcp_output(content.output))
|
||||
self._mcp_builder = None
|
||||
self._active_type = None
|
||||
self._active_id = None
|
||||
self._accumulated.clear()
|
||||
self.needs_async = False
|
||||
return
|
||||
|
||||
else:
|
||||
yield from self._close()
|
||||
self.needs_async = True
|
||||
@@ -777,9 +791,10 @@ class _OutputItemTracker:
|
||||
self._mcp_builder = self._stream.add_output_item_mcp_call(
|
||||
server_label=content.server_name or "default",
|
||||
name=content.tool_name or "",
|
||||
item_id=content.call_id,
|
||||
)
|
||||
self._active_type = "mcp_server_tool_call"
|
||||
self._active_id = f"{content.server_name or 'default'}::{content.tool_name}"
|
||||
self._active_id = content.call_id or f"{content.server_name or 'default'}::{content.tool_name}"
|
||||
yield self._mcp_builder.emit_added()
|
||||
|
||||
def _close(self) -> Generator[ResponseStreamEvent]:
|
||||
@@ -927,16 +942,19 @@ async def _item_to_message(item: Item, *, approval_storage: ApprovalStorage | No
|
||||
|
||||
if item.type == "mcp_call":
|
||||
mcp = cast(ItemMcpToolCall, item)
|
||||
contents = [
|
||||
Content.from_mcp_server_tool_call(
|
||||
mcp.id,
|
||||
mcp.name,
|
||||
server_name=mcp.server_label,
|
||||
arguments=mcp.arguments,
|
||||
)
|
||||
]
|
||||
if getattr(mcp, "output", None) is not None:
|
||||
contents.append(Content.from_mcp_server_tool_result(call_id=mcp.id, output=mcp.output))
|
||||
return Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_mcp_server_tool_call(
|
||||
mcp.id,
|
||||
mcp.name,
|
||||
server_name=mcp.server_label,
|
||||
arguments=mcp.arguments,
|
||||
)
|
||||
],
|
||||
contents=contents,
|
||||
)
|
||||
|
||||
if item.type == "mcp_approval_request":
|
||||
@@ -1197,16 +1215,19 @@ async def _output_item_to_message(item: OutputItem, *, approval_storage: Approva
|
||||
|
||||
if item.type == "mcp_call":
|
||||
mcp = cast(OutputItemMcpToolCall, item)
|
||||
contents = [
|
||||
Content.from_mcp_server_tool_call(
|
||||
mcp.id,
|
||||
mcp.name,
|
||||
server_name=mcp.server_label,
|
||||
arguments=mcp.arguments,
|
||||
)
|
||||
]
|
||||
if getattr(mcp, "output", None) is not None:
|
||||
contents.append(Content.from_mcp_server_tool_result(call_id=mcp.id, output=mcp.output))
|
||||
return Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_mcp_server_tool_call(
|
||||
mcp.id,
|
||||
mcp.name,
|
||||
server_name=mcp.server_label,
|
||||
arguments=mcp.arguments,
|
||||
)
|
||||
],
|
||||
contents=contents,
|
||||
)
|
||||
|
||||
if item.type == "mcp_approval_request":
|
||||
@@ -1583,6 +1604,7 @@ async def _to_outputs(
|
||||
mcp_call = stream.add_output_item_mcp_call(
|
||||
server_label=content.server_name or "default",
|
||||
name=content.tool_name or "",
|
||||
item_id=content.call_id,
|
||||
)
|
||||
yield mcp_call.emit_added()
|
||||
async for event in mcp_call.aarguments(_arguments_to_str(content.arguments)):
|
||||
@@ -1657,4 +1679,91 @@ async def _to_outputs(
|
||||
logger.warning(f"Content type '{content.type}' is not supported yet. This is usually safe to ignore.")
|
||||
|
||||
|
||||
def _stringify_mcp_output(output: Any) -> str:
    """Convert hosted MCP output payloads into the string shape expected by mcp_call.output.

    Conversion rules, applied in order:

    - ``None`` becomes the empty string.
    - ``str`` is returned unchanged.
    - ``bytes`` / ``bytearray`` are decoded as UTF-8, replacing undecodable
      bytes, so the caller gets the text rather than a ``b'...'`` repr.
    - A mapping whose ``"text"`` entry is a string yields that text; any other
      mapping is JSON-encoded, with non-JSON values falling back to ``str()``.
    - Any other sequence is converted entry by entry and the pieces are
      concatenated without a separator; ``Content`` entries of type ``"text"``
      contribute their ``text`` (or ``""`` when it is empty/``None``).
    - Everything else falls back to ``str(output)``.

    Args:
        output: The raw MCP tool result payload.

    Returns:
        The payload flattened to a single string.
    """
    if output is None:
        return ""
    if isinstance(output, str):
        return output
    if isinstance(output, (bytes, bytearray)):
        # str(b"...") would leak the Python repr (b'...') into the persisted output.
        return bytes(output).decode("utf-8", errors="replace")
    if isinstance(output, Mapping):
        mapping = cast(Mapping[Any, Any], output)
        text = mapping.get("text")
        if isinstance(text, str):
            return text
        # json.dumps only encodes real dicts as JSON objects; any other Mapping
        # would be routed through default=str and come out as a quoted repr.
        return json.dumps(dict(mapping), default=str)
    if isinstance(output, Sequence):
        # str/bytes/bytearray were already handled above, so this is a real container.
        parts: list[str] = []
        for entry in cast(Sequence[object], output):
            if isinstance(entry, Content) and entry.type == "text":
                parts.append(entry.text or "")
            else:
                parts.append(_stringify_mcp_output(entry))
        return "".join(parts)
    return str(output)
|
||||
|
||||
|
||||
def _emit_completed_mcp_call(
    stream: ResponseEventStream,
    call_content: Content,
    *,
    arguments: str,
    output: str,
) -> Generator[ResponseStreamEvent]:
    """Yield the event sequence for one MCP call item that already has its result.

    A new ``mcp_call`` output item is registered on ``stream`` using the call's
    server name (``"default"`` when unset), tool name (``""`` when unset) and
    call id as the item id. The events are then produced in this order: added,
    arguments done, completed, done (carrying ``output``).

    Args:
        stream: Event stream the output item is added to.
        call_content: The MCP tool-call content describing the call.
        arguments: Final, fully accumulated argument string.
        output: Stringified tool result to attach to the finished item.
    """
    label = call_content.server_name or "default"
    tool = call_content.tool_name or ""
    builder = stream.add_output_item_mcp_call(
        server_label=label,
        name=tool,
        item_id=call_content.call_id,
    )

    yield builder.emit_added()
    yield builder.emit_arguments_done(arguments)
    yield builder.emit_completed()
    yield builder.emit_done(output=output)
|
||||
|
||||
|
||||
async def _to_outputs_for_messages(
    stream: ResponseEventStream,
    messages: Sequence[Message],
    *,
    approval_storage: ApprovalStorage | None = None,
) -> AsyncIterator[ResponseStreamEvent]:
    """Turn messages into output events, merging hosted MCP call/result pairs.

    Contents are walked once, in message and content order. An MCP tool call
    that has a call id is held back for one step:

    - if the very next content is the MCP tool result with the same call id,
      the pair is emitted as a single completed ``mcp_call`` item;
    - otherwise the held call is emitted through the regular ``_to_outputs``
      path before the next content is handled.

    All other content goes straight through ``_to_outputs``. A call still held
    when the input is exhausted is emitted through ``_to_outputs`` as well.
    """
    held_call: Content | None = None

    for message in messages:
        for part in message.contents:
            if held_call is not None:
                call, held_call = held_call, None
                is_matching_result = part.type == "mcp_server_tool_result" and part.call_id == call.call_id
                if is_matching_result:
                    # The result is folded into the call's item, so nothing is emitted for it separately.
                    for event in _emit_completed_mcp_call(
                        stream,
                        call,
                        arguments=_arguments_to_str(call.arguments),
                        output=_stringify_mcp_output(part.output),
                    ):
                        yield event
                    continue

                # No adjacent result: emit the held call on its own first.
                async for event in _to_outputs(stream, call, approval_storage=approval_storage):
                    yield event

            if part.type == "mcp_server_tool_call" and part.call_id:
                held_call = part
                continue

            async for event in _to_outputs(stream, part, approval_storage=approval_storage):
                yield event

    if held_call is not None:
        async for event in _to_outputs(stream, held_call, approval_storage=approval_storage):
            yield event
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -25,7 +25,7 @@ classifiers = [
|
||||
dependencies = [
|
||||
"agent-framework-core>=1.7.0,<2",
|
||||
"azure-ai-agentserver-core>=2.0.0b3,<3",
|
||||
"azure-ai-agentserver-responses>=1.0.0b5,<2",
|
||||
"azure-ai-agentserver-responses>=1.0.0b7,<2",
|
||||
"azure-ai-agentserver-invocations>=1.0.0b3,<2",
|
||||
]
|
||||
|
||||
|
||||
@@ -260,6 +260,50 @@ class TestNonStreaming:
|
||||
assert "function_call_output" in types
|
||||
assert "message" in types
|
||||
|
||||
async def test_hosted_mcp_call_and_result_persist_as_single_mcp_call(self) -> None:
    """A hosted MCP call plus its result is returned as exactly one ``mcp_call`` output item."""
    call_message = Message(
        role="assistant",
        contents=[
            Content.from_mcp_server_tool_call(
                call_id="mcp_abc123",
                tool_name="search",
                server_name="api_specs",
                arguments='{"q": "cats"}',
            )
        ],
    )
    result_message = Message(
        role="tool",
        contents=[
            Content.from_mcp_server_tool_result(
                call_id="mcp_abc123",
                output=[Content.from_text(text="found 10 cats")],
            )
        ],
    )
    reply_message = Message(role="assistant", contents=[Content.from_text("I found 10 cats!")])
    agent = _make_agent(response=AgentResponse(messages=[call_message, result_message, reply_message]))
    server = _make_server(agent)

    resp = await _post(server, stream=False)

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["status"] == "completed"

    output_items = payload["output"]
    item_types = [entry["type"] for entry in output_items]
    assert "mcp_call" in item_types
    assert "custom_tool_call_output" not in item_types

    mcp_items = [entry for entry in output_items if entry["type"] == "mcp_call"]
    assert len(mcp_items) == 1
    only_item = mcp_items[0]
    assert only_item["id"] == "mcp_abc123"
    assert only_item["output"] == "found 10 cats"
|
||||
|
||||
async def test_reasoning_content(self) -> None:
|
||||
agent = _make_agent(
|
||||
response=AgentResponse(
|
||||
@@ -617,6 +661,53 @@ class TestStreaming:
|
||||
assert "response.output_item.added" in types
|
||||
assert "response.output_item.done" in types
|
||||
|
||||
async def test_mcp_tool_call_and_result_streaming_emit_single_completed_mcp_call(self) -> None:
    """Two streamed argument chunks plus the result produce one finished ``mcp_call`` item."""

    def call_update(arguments_chunk: str) -> AgentResponseUpdate:
        # Both chunks belong to the same call; only the argument fragment differs.
        return AgentResponseUpdate(
            contents=[
                Content.from_mcp_server_tool_call(
                    call_id="mcp_abc123",
                    tool_name="search",
                    server_name="api_specs",
                    arguments=arguments_chunk,
                )
            ],
            role="assistant",
        )

    result_update = AgentResponseUpdate(
        contents=[
            Content.from_mcp_server_tool_result(
                call_id="mcp_abc123",
                output=[Content.from_text(text="found 10 cats")],
            )
        ],
        role="tool",
    )
    agent = _make_agent(stream_updates=[call_update('{"q":'), call_update(' "cats"}'), result_update])
    server = _make_server(agent)

    resp = await _post(server, stream=True)

    assert resp.status_code == 200
    all_events = _parse_sse_events(resp.text)
    done_events = [evt for evt in all_events if evt["event"] == "response.output_item.done"]
    assert len(done_events) == 1
    finished_item = done_events[0]["data"]["item"]
    assert finished_item["type"] == "mcp_call"
    assert finished_item["id"] == "mcp_abc123"
    assert finished_item["output"] == "found 10 cats"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -720,6 +811,24 @@ class TestOutputItemToMessage:
|
||||
assert msg.contents[0].server_name == "my_server"
|
||||
assert msg.contents[0].tool_name == "search"
|
||||
|
||||
async def test_mcp_call_with_output_reconstructs_mcp_result_content(self) -> None:
    """An output ``mcp_call`` item carrying ``output`` maps to call content followed by result content."""
    from azure.ai.agentserver.responses.models import OutputItemMcpToolCall

    raw_item = {
        "type": "mcp_call",
        "id": "mcp-1",
        "server_label": "my_server",
        "name": "search",
        "arguments": '{"q": "test"}',
        "output": "found 10 cats",
    }

    message = await _output_item_to_message(OutputItemMcpToolCall(raw_item))

    assert message.role == "assistant"
    assert len(message.contents) == 2
    call_part, result_part = message.contents
    assert call_part.type == "mcp_server_tool_call"
    assert result_part.type == "mcp_server_tool_result"
    assert result_part.output == "found 10 cats"
|
||||
|
||||
async def test_mcp_approval_request(self) -> None:
|
||||
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalRequest
|
||||
|
||||
@@ -1189,6 +1298,25 @@ class TestItemToMessage:
|
||||
assert msg.contents[0].server_name == "my_server"
|
||||
assert msg.contents[0].tool_name == "search"
|
||||
|
||||
async def test_mcp_call_with_output_reconstructs_mcp_result_content(self) -> None:
    """An input ``mcp_call`` item carrying ``output`` maps to call content followed by result content."""
    from azure.ai.agentserver.responses.models import ItemMcpToolCall

    raw_item = {
        "type": "mcp_call",
        "id": "mcp-1",
        "server_label": "my_server",
        "name": "search",
        "arguments": '{"q": "test"}',
        "output": "found 10 cats",
    }

    message = await _item_to_message(ItemMcpToolCall(raw_item))

    assert message is not None
    assert message.role == "assistant"
    assert len(message.contents) == 2
    call_part, result_part = message.contents
    assert call_part.type == "mcp_server_tool_call"
    assert result_part.type == "mcp_server_tool_result"
    assert result_part.output == "found 10 cats"
|
||||
|
||||
async def test_mcp_approval_request(self) -> None:
|
||||
from azure.ai.agentserver.responses.models import ItemMcpApprovalRequest
|
||||
|
||||
@@ -1937,6 +2065,75 @@ class TestMultiTurnMixedContent:
|
||||
assert len(fc_contents) >= 1
|
||||
assert fc_contents[0].name == "search"
|
||||
|
||||
async def test_hosted_mcp_call_round_trip_does_not_orphan_function_call_output(self) -> None:
    """Turn 1 produces hosted MCP call + result, turn 2 must replay both without orphaning output."""
    first_turn = AgentResponse(
        messages=[
            Message(
                role="assistant",
                contents=[
                    Content.from_mcp_server_tool_call(
                        call_id="mcp_abc123",
                        tool_name="search",
                        server_name="api_specs",
                        arguments='{"q": "cats"}',
                    )
                ],
            ),
            Message(
                role="tool",
                contents=[
                    Content.from_mcp_server_tool_result(
                        call_id="mcp_abc123",
                        output=[Content.from_text(text="found 10 cats")],
                    )
                ],
            ),
            Message(role="assistant", contents=[Content.from_text("I found 10 cats!")]),
        ]
    )
    second_turn = AgentResponse(
        messages=[Message(role="assistant", contents=[Content.from_text("Here are more details")])]
    )
    agent = _make_multi_response_agent([first_turn, second_turn])
    server = _make_server(agent)

    # Turn 1: the MCP call and its result must surface as an mcp_call item only.
    resp1 = await _post(server, input_text="Search for cats", stream=False)
    assert resp1.status_code == 200
    first_body = resp1.json()
    response_id = first_body["id"]

    first_types = [entry["type"] for entry in first_body["output"]]
    assert "mcp_call" in first_types
    assert "custom_tool_call_output" not in first_types

    # Turn 2: continue from the stored response so its history is replayed to the agent.
    follow_up = {
        "model": "test-model",
        "input": "Tell me more",
        "stream": False,
        "previous_response_id": response_id,
    }
    resp2 = await _post_json(server, follow_up)
    assert resp2.status_code == 200
    assert resp2.json()["status"] == "completed"

    second_call_messages = agent.run.call_args_list[1].kwargs["messages"]
    replayed = [part for msg in second_call_messages for part in msg.contents]
    mcp_calls = [part for part in replayed if part.type == "mcp_server_tool_call"]
    mcp_results = [part for part in replayed if part.type == "mcp_server_tool_result"]
    function_results = [part for part in replayed if part.type == "function_result"]

    assert len(mcp_calls) >= 1
    assert len(mcp_results) >= 1
    assert all((part.call_id or "") != "mcp_abc123" for part in function_results)
    assert any((part.call_id or "") == "mcp_abc123" for part in mcp_calls)
    assert any((part.call_id or "") == "mcp_abc123" for part in mcp_results)
|
||||
|
||||
async def test_multi_turn_reasoning_in_history(self) -> None:
|
||||
"""Turn 1 produces reasoning + text, turn 2 sees them in history."""
|
||||
agent = _make_multi_response_agent([
|
||||
|
||||
Reference in New Issue
Block a user