mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
9d8e5ca4f5
* Python: Allow hosted checkpoints to restore MessageRole Allow Responses hosting checkpoint storage to deserialize the Azure Responses MessageRole enum that hosted workflows can persist inside Agent Framework Message objects. Add regression coverage for both direct load() and the hosted get_latest() restore path, including the plain-storage failure mode where list_checkpoints logs the blocked type and get_latest() returns None. Ruff also normalizes a duplicate contextlib import in the touched hosting module. * Address MessageRole checkpoint review comments * Cover hosted MessageRole checkpoint restore path
3251 lines
130 KiB
Python
3251 lines
130 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""HTTP round-trip tests for ResponsesHostServer.
|
|
|
|
These tests exercise the full HTTP pipeline using httpx.AsyncClient with
|
|
ASGITransport — no real server process is started. Requests go through
|
|
the Starlette routing stack, the Responses API middleware, and arrive at
|
|
the registered _handle_create handler.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from collections.abc import AsyncIterator, Callable
|
|
from dataclasses import dataclass
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
from agent_framework import (
|
|
AgentResponse,
|
|
AgentResponseUpdate,
|
|
Content,
|
|
FileCheckpointStorage,
|
|
HistoryProvider,
|
|
Message,
|
|
RawAgent,
|
|
ResponseStream,
|
|
WorkflowCheckpoint,
|
|
WorkflowCheckpointException,
|
|
WorkflowMessage,
|
|
)
|
|
from azure.ai.agentserver.responses import InMemoryResponseProvider
|
|
from mcp import McpError
|
|
from mcp.types import ErrorData
|
|
from typing_extensions import Any
|
|
|
|
from agent_framework_foundry_hosting import ResponsesHostServer
|
|
from agent_framework_foundry_hosting._responses import (
|
|
_AZURE_RESPONSES_MESSAGE_ROLE_TYPE, # pyright: ignore[reportPrivateUsage]
|
|
CONSENT_ERROR_CODE,
|
|
FileBasedFunctionApprovalStorage, # pyright: ignore[reportPrivateUsage]
|
|
InMemoryFunctionApprovalStorage, # pyright: ignore[reportPrivateUsage]
|
|
_item_to_message, # pyright: ignore[reportPrivateUsage]
|
|
_output_item_to_message, # pyright: ignore[reportPrivateUsage]
|
|
consent_url_from_error,
|
|
)
|
|
|
|
|
|
def _make_function_approval_request_content(
|
|
*,
|
|
request_id: str = "apr_test",
|
|
call_id: str = "call_1",
|
|
name: str = "delete_file",
|
|
arguments: str = '{"path": "/foo"}',
|
|
server_label: str = "my_server",
|
|
) -> Content:
|
|
"""Build a function_approval_request Content with an embedded function_call."""
|
|
function_call = Content.from_function_call(
|
|
call_id, name, arguments=arguments, additional_properties={"server_label": server_label}
|
|
)
|
|
return Content.from_function_approval_request(request_id, function_call)
|
|
|
|
|
|
# region Helpers
|
|
|
|
|
|
def _make_agent(
|
|
*,
|
|
response: AgentResponse | None = None,
|
|
stream_updates: list[AgentResponseUpdate] | None = None,
|
|
raw_agent: bool = True,
|
|
) -> MagicMock:
|
|
"""Create a mock agent implementing SupportsAgentRun."""
|
|
agent = MagicMock(spec=RawAgent) if raw_agent else MagicMock()
|
|
agent.id = "test-agent"
|
|
agent.name = "Test Agent"
|
|
agent.description = "A mock agent for testing"
|
|
agent.context_providers = []
|
|
|
|
if response is not None:
|
|
|
|
async def run_non_streaming(*args: Any, **kwargs: Any) -> AgentResponse:
|
|
return response
|
|
|
|
agent.run = AsyncMock(side_effect=run_non_streaming)
|
|
|
|
if stream_updates is not None:
|
|
|
|
async def _stream_gen() -> AsyncIterator[AgentResponseUpdate]:
|
|
for update in stream_updates:
|
|
yield update
|
|
|
|
def run_streaming(*args: Any, **kwargs: Any) -> Any:
|
|
if kwargs.get("stream"):
|
|
return ResponseStream(_stream_gen()) # type: ignore
|
|
raise NotImplementedError("Only streaming is configured on this mock")
|
|
|
|
agent.run = MagicMock(side_effect=run_streaming)
|
|
|
|
return agent
|
|
|
|
|
|
def _make_server(agent: MagicMock, **kwargs: Any) -> ResponsesHostServer:
|
|
"""Create a ResponsesHostServer with an in-memory store."""
|
|
return ResponsesHostServer(agent, store=InMemoryResponseProvider(), **kwargs)
|
|
|
|
|
|
async def _post(
|
|
server: ResponsesHostServer,
|
|
*,
|
|
input_text: str = "Hello",
|
|
model: str = "test-model",
|
|
stream: bool = False,
|
|
temperature: float | None = None,
|
|
top_p: float | None = None,
|
|
max_output_tokens: int | None = None,
|
|
parallel_tool_calls: bool | None = None,
|
|
) -> httpx.Response:
|
|
"""Send a POST /responses request through the ASGI transport."""
|
|
payload: dict[str, Any] = {"model": model, "input": input_text, "stream": stream}
|
|
if temperature is not None:
|
|
payload["temperature"] = temperature
|
|
if top_p is not None:
|
|
payload["top_p"] = top_p
|
|
if max_output_tokens is not None:
|
|
payload["max_output_tokens"] = max_output_tokens
|
|
if parallel_tool_calls is not None:
|
|
payload["parallel_tool_calls"] = parallel_tool_calls
|
|
|
|
transport = httpx.ASGITransport(app=server)
|
|
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
|
return await client.post("/responses", json=payload)
|
|
|
|
|
|
def _parse_sse_events(body: str) -> list[dict[str, Any]]:
|
|
"""Parse SSE text into a list of event dicts with 'event' and 'data' keys."""
|
|
events: list[dict[str, Any]] = []
|
|
current_event: str | None = None
|
|
current_data_lines: list[str] = []
|
|
|
|
for line in body.split("\n"):
|
|
if line.startswith("event: "):
|
|
current_event = line[len("event: ") :]
|
|
elif line.startswith("data: "):
|
|
current_data_lines.append(line[len("data: ") :])
|
|
elif line.strip() == "" and current_event is not None:
|
|
data_str = "\n".join(current_data_lines)
|
|
try:
|
|
data = json.loads(data_str)
|
|
except json.JSONDecodeError:
|
|
data = data_str
|
|
events.append({"event": current_event, "data": data})
|
|
current_event = None
|
|
current_data_lines = []
|
|
|
|
return events
|
|
|
|
|
|
def _sse_event_types(events: list[dict[str, Any]]) -> list[str]:
|
|
"""Extract event type strings from parsed SSE events."""
|
|
return [e["event"] for e in events]
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Initialization
|
|
|
|
|
|
class TestResponsesHostServerInit:
|
|
def test_init_basic(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
server = _make_server(agent)
|
|
assert server is not None
|
|
|
|
def test_init_rejects_history_provider_with_load_messages(self) -> None:
|
|
hp = HistoryProvider(source_id="test", load_messages=True)
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
agent.context_providers = [hp]
|
|
with pytest.raises(RuntimeError, match="history provider"):
|
|
ResponsesHostServer(agent)
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Health Check
|
|
|
|
|
|
class TestHealthCheck:
|
|
async def test_readiness(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
server = _make_server(agent)
|
|
transport = httpx.ASGITransport(app=server)
|
|
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
|
resp = await client.get("/readiness")
|
|
assert resp.status_code == 200
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Non-streaming
|
|
|
|
|
|
class TestNonStreaming:
|
|
async def test_basic_text_response(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Hello!")])])
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, input_text="Hi", stream=False)
|
|
|
|
assert resp.status_code == 200
|
|
assert "application/json" in resp.headers["content-type"]
|
|
|
|
body = resp.json()
|
|
assert body["object"] == "response"
|
|
assert body["status"] == "completed"
|
|
assert len(body["output"]) > 0
|
|
|
|
# Find the message output item with our text
|
|
text_found = False
|
|
for item in body["output"]:
|
|
assert item["type"] == "message"
|
|
for part in item.get("content", []):
|
|
if part.get("type") == "output_text" and part.get("text") == "Hello!":
|
|
text_found = True
|
|
assert text_found, f"Expected 'Hello!' in output, got: {body['output']}"
|
|
|
|
async def test_function_call_and_result(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[Content.from_function_call("call_1", "get_weather", arguments='{"loc": "NYC"}')],
|
|
),
|
|
Message(role="tool", contents=[Content.from_function_result("call_1", result="sunny")]),
|
|
Message(role="assistant", contents=[Content.from_text("The weather is sunny!")]),
|
|
]
|
|
)
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=False)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
types = [item["type"] for item in body["output"]]
|
|
assert "function_call" in types
|
|
assert "function_call_output" in types
|
|
assert "message" in types
|
|
|
|
async def test_reasoning_content(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[
|
|
Content.from_text_reasoning(text="Let me think..."),
|
|
Content.from_text("The answer is 42"),
|
|
],
|
|
),
|
|
]
|
|
)
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=False)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
types = [item["type"] for item in body["output"]]
|
|
assert "reasoning" in types
|
|
assert "message" in types
|
|
|
|
async def test_empty_response(self) -> None:
|
|
agent = _make_agent(response=AgentResponse(messages=[]))
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=False)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
async def test_chat_options_forwarded(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]),
|
|
raw_agent=True,
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(
|
|
server,
|
|
stream=False,
|
|
temperature=0.5,
|
|
top_p=0.9,
|
|
max_output_tokens=1024,
|
|
parallel_tool_calls=True,
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
agent.run.assert_awaited_once()
|
|
call_kwargs = agent.run.call_args.kwargs
|
|
assert call_kwargs["stream"] is False
|
|
options = call_kwargs["options"]
|
|
assert options["temperature"] == 0.5
|
|
assert options["top_p"] == 0.9
|
|
assert options["max_tokens"] == 1024
|
|
assert options["allow_multiple_tool_calls"] is True
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Streaming
|
|
|
|
|
|
class TestStreaming:
|
|
async def test_chat_options_forwarded(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[AgentResponseUpdate(contents=[Content.from_text("ok")], role="assistant")],
|
|
raw_agent=True,
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(
|
|
server,
|
|
stream=True,
|
|
temperature=0.5,
|
|
top_p=0.9,
|
|
max_output_tokens=1024,
|
|
parallel_tool_calls=True,
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
agent.run.assert_called_once()
|
|
call_kwargs = agent.run.call_args.kwargs
|
|
assert call_kwargs["stream"] is True
|
|
options = call_kwargs["options"]
|
|
assert options["temperature"] == 0.5
|
|
assert options["top_p"] == 0.9
|
|
assert options["max_tokens"] == 1024
|
|
assert options["allow_multiple_tool_calls"] is True
|
|
|
|
async def test_basic_text_streaming(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(contents=[Content.from_text("Hello ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text("world!")], role="assistant"),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
assert "text/event-stream" in resp.headers["content-type"]
|
|
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[1] == "response.in_progress"
|
|
assert types[-1] == "response.completed"
|
|
assert "response.output_text.delta" in types
|
|
assert types.count("response.output_text.delta") == 2
|
|
assert "response.output_text.done" in types
|
|
|
|
# Verify the accumulated text in the done event
|
|
done_events = [e for e in events if e["event"] == "response.output_text.done"]
|
|
assert len(done_events) == 1
|
|
assert done_events[0]["data"]["text"] == "Hello world!"
|
|
|
|
async def test_function_call_streaming(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "search", arguments='{"q":')],
|
|
role="assistant",
|
|
),
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "search", arguments=' "hello"}')],
|
|
role="assistant",
|
|
),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
assert types.count("response.function_call_arguments.delta") == 2
|
|
assert "response.function_call_arguments.done" in types
|
|
|
|
# Verify accumulated arguments
|
|
args_done = [e for e in events if e["event"] == "response.function_call_arguments.done"]
|
|
assert len(args_done) == 1
|
|
assert args_done[0]["data"]["arguments"] == '{"q": "hello"}'
|
|
|
|
async def test_function_call_streaming_serializes_dataclass_arguments(self) -> None:
|
|
@dataclass
|
|
class HandoffLikeRequest:
|
|
agent_response: AgentResponse
|
|
|
|
request = HandoffLikeRequest(
|
|
agent_response=AgentResponse(
|
|
messages=[Message(role="assistant", contents=[Content.from_text("Need more details")])]
|
|
)
|
|
)
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "handoff_to_refund", arguments=request)],
|
|
role="assistant",
|
|
),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
args_done = [e for e in events if e["event"] == "response.function_call_arguments.done"]
|
|
assert len(args_done) == 1
|
|
|
|
payload = json.loads(args_done[0]["data"]["arguments"])
|
|
assert payload["agent_response"]["type"] == "agent_response"
|
|
assert payload["agent_response"]["messages"][0]["contents"][0]["text"] == "Need more details"
|
|
|
|
async def test_alternating_text_and_function_call(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
# Text deltas
|
|
AgentResponseUpdate(contents=[Content.from_text("Let me ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text("search...")], role="assistant"),
|
|
# Function call argument deltas
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "search", arguments='{"q":')],
|
|
role="assistant",
|
|
),
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "search", arguments=' "x"}')],
|
|
role="assistant",
|
|
),
|
|
# More text deltas
|
|
AgentResponseUpdate(contents=[Content.from_text("Found ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text("it!")], role="assistant"),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
|
|
# 4 text deltas + 2 function call argument deltas
|
|
assert types.count("response.output_text.delta") == 4
|
|
assert types.count("response.function_call_arguments.delta") == 2
|
|
|
|
# 3 distinct output items (text, fc, text)
|
|
assert types.count("response.output_item.added") == 3
|
|
assert types.count("response.output_item.done") == 3
|
|
|
|
# Verify accumulated content
|
|
text_done = [e for e in events if e["event"] == "response.output_text.done"]
|
|
assert len(text_done) == 2
|
|
assert text_done[0]["data"]["text"] == "Let me search..."
|
|
assert text_done[1]["data"]["text"] == "Found it!"
|
|
|
|
args_done = [e for e in events if e["event"] == "response.function_call_arguments.done"]
|
|
assert len(args_done) == 1
|
|
assert args_done[0]["data"]["arguments"] == '{"q": "x"}'
|
|
|
|
async def test_reasoning_then_text_streaming(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
# Reasoning deltas
|
|
AgentResponseUpdate(contents=[Content.from_text_reasoning(text="Let me ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text_reasoning(text="think...")], role="assistant"),
|
|
# Text deltas
|
|
AgentResponseUpdate(contents=[Content.from_text("The answer ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text("is 42")], role="assistant"),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
# Reasoning + text = 2 output items
|
|
assert types.count("response.output_item.added") == 2
|
|
assert types.count("response.output_item.done") == 2
|
|
assert types.count("response.output_text.delta") == 2
|
|
|
|
# Verify accumulated text
|
|
text_done = [e for e in events if e["event"] == "response.output_text.done"]
|
|
assert len(text_done) == 1
|
|
assert text_done[0]["data"]["text"] == "The answer is 42"
|
|
|
|
async def test_empty_streaming(self) -> None:
|
|
agent = _make_agent(stream_updates=[])
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types == ["response.created", "response.in_progress", "response.completed"]
|
|
|
|
async def test_mixed_contents_in_single_update(self) -> None:
|
|
"""Text and function call in one update switches builder mid-update."""
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(
|
|
contents=[
|
|
Content.from_text("Let me search"),
|
|
Content.from_function_call("call_1", "search", arguments='{"q": "test"}'),
|
|
],
|
|
role="assistant",
|
|
),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert "response.output_text.delta" in types
|
|
assert "response.output_text.done" in types
|
|
assert "response.function_call_arguments.delta" in types
|
|
assert "response.function_call_arguments.done" in types
|
|
|
|
async def test_different_function_call_ids_produce_separate_items(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "func_a", arguments='{"x":1}')],
|
|
role="assistant",
|
|
),
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_2", "func_b", arguments='{"y":2}')],
|
|
role="assistant",
|
|
),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
# Two separate function call items
|
|
assert types.count("response.output_item.added") == 2
|
|
assert types.count("response.function_call_arguments.done") == 2
|
|
|
|
async def test_mcp_tool_call_streaming(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(
|
|
contents=[
|
|
Content(
|
|
type="mcp_server_tool_call",
|
|
server_name="my_server",
|
|
tool_name="search",
|
|
arguments='{"query":',
|
|
)
|
|
],
|
|
role="assistant",
|
|
),
|
|
AgentResponseUpdate(
|
|
contents=[
|
|
Content(
|
|
type="mcp_server_tool_call",
|
|
server_name="my_server",
|
|
tool_name="search",
|
|
arguments=' "test"}',
|
|
)
|
|
],
|
|
role="assistant",
|
|
),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
assert "response.output_item.added" in types
|
|
assert "response.output_item.done" in types
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region _output_item_to_message conversion
|
|
|
|
|
|
class TestOutputItemToMessage:
|
|
"""Tests for _output_item_to_message covering all supported OutputItem types."""
|
|
|
|
async def test_output_message(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemOutputMessage, OutputMessageContentOutputTextContent
|
|
|
|
item = OutputItemOutputMessage({
|
|
"type": "output_message",
|
|
"role": "assistant",
|
|
"content": [OutputMessageContentOutputTextContent({"type": "output_text", "text": "hello"})],
|
|
"status": "completed",
|
|
"id": "msg-1",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].type == "text"
|
|
assert msg.contents[0].text == "hello"
|
|
|
|
async def test_message(self) -> None:
|
|
from azure.ai.agentserver.responses.models import MessageContentInputTextContent, OutputItemMessage
|
|
|
|
item = OutputItemMessage({
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [MessageContentInputTextContent({"type": "input_text", "text": "hi"})],
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "user"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].text == "hi"
|
|
|
|
async def test_function_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemFunctionToolCall
|
|
|
|
item = OutputItemFunctionToolCall({
|
|
"type": "function_call",
|
|
"call_id": "call_1",
|
|
"name": "get_weather",
|
|
"arguments": '{"city": "NYC"}',
|
|
"status": "completed",
|
|
"id": "fc-1",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].call_id == "call_1"
|
|
assert msg.contents[0].name == "get_weather"
|
|
|
|
async def test_function_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import FunctionCallOutputItemParam
|
|
|
|
item = FunctionCallOutputItemParam({"type": "function_call_output", "call_id": "call_1", "output": "sunny"})
|
|
msg = await _output_item_to_message(item) # type: ignore[arg-type]
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].call_id == "call_1"
|
|
assert msg.contents[0].result == "sunny"
|
|
|
|
async def test_reasoning(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemReasoningItem, SummaryTextContent
|
|
|
|
item = OutputItemReasoningItem({
|
|
"type": "reasoning",
|
|
"id": "r-1",
|
|
"summary": [SummaryTextContent({"type": "summary_text", "text": "thinking hard"})],
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].text == "thinking hard"
|
|
|
|
async def test_reasoning_no_summary(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemReasoningItem
|
|
|
|
item = OutputItemReasoningItem({"type": "reasoning", "id": "r-2"})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents == []
|
|
|
|
async def test_mcp_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpToolCall
|
|
|
|
item = OutputItemMcpToolCall({
|
|
"type": "mcp_call",
|
|
"id": "mcp-1",
|
|
"server_label": "my_server",
|
|
"name": "search",
|
|
"arguments": '{"q": "test"}',
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "mcp_server_tool_call"
|
|
assert msg.contents[0].server_name == "my_server"
|
|
assert msg.contents[0].tool_name == "search"
|
|
|
|
async def test_mcp_approval_request(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalRequest
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = OutputItemMcpApprovalRequest({
|
|
"type": "mcp_approval_request",
|
|
"id": "apr-1",
|
|
"server_label": "srv",
|
|
"name": "dangerous_tool",
|
|
"arguments": "{}",
|
|
})
|
|
msg = await _output_item_to_message(item, approval_storage=storage)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_approval_request"
|
|
|
|
async def test_mcp_approval_response(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalResponseResource
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = OutputItemMcpApprovalResponseResource({
|
|
"type": "mcp_approval_response",
|
|
"id": "resp-1",
|
|
"approval_request_id": "apr-1",
|
|
"approve": True,
|
|
})
|
|
msg = await _output_item_to_message(item, approval_storage=storage)
|
|
assert msg.role == "user"
|
|
assert msg.contents[0].type == "function_approval_response"
|
|
assert msg.contents[0].approved is True
|
|
|
|
async def test_code_interpreter_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemCodeInterpreterToolCall
|
|
|
|
item = OutputItemCodeInterpreterToolCall({
|
|
"type": "code_interpreter_call",
|
|
"id": "ci-1",
|
|
"status": "completed",
|
|
"container_id": "c-1",
|
|
"code": "print('hi')",
|
|
"outputs": [],
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "code_interpreter_tool_call"
|
|
|
|
async def test_image_generation_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemImageGenToolCall
|
|
|
|
item = OutputItemImageGenToolCall({"type": "image_generation_call", "id": "ig-1", "status": "completed"})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "image_generation_tool_call"
|
|
|
|
async def test_shell_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import (
|
|
FunctionShellAction,
|
|
FunctionShellCallEnvironment,
|
|
OutputItemFunctionShellCall,
|
|
)
|
|
|
|
item = OutputItemFunctionShellCall({
|
|
"type": "shell_call",
|
|
"id": "sc-1",
|
|
"call_id": "call_sc",
|
|
"action": FunctionShellAction({"commands": ["ls", "-la"], "timeout_ms": 5000, "max_output_length": 1024}),
|
|
"status": "completed",
|
|
"environment": FunctionShellCallEnvironment({"type": "local"}),
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "shell_tool_call"
|
|
assert msg.contents[0].commands == ["ls", "-la"]
|
|
assert msg.contents[0].call_id == "call_sc"
|
|
|
|
async def test_shell_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import (
|
|
FunctionShellCallOutputContent,
|
|
FunctionShellCallOutputExitOutcome,
|
|
OutputItemFunctionShellCallOutput,
|
|
)
|
|
|
|
item = OutputItemFunctionShellCallOutput({
|
|
"type": "shell_call_output",
|
|
"id": "sco-1",
|
|
"call_id": "call_sc",
|
|
"status": "completed",
|
|
"output": [
|
|
FunctionShellCallOutputContent({
|
|
"stdout": "file.txt",
|
|
"stderr": "",
|
|
"outcome": FunctionShellCallOutputExitOutcome({"exit_code": 0}),
|
|
})
|
|
],
|
|
"max_output_length": 1024,
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "shell_tool_result"
|
|
assert msg.contents[0].call_id == "call_sc"
|
|
|
|
async def test_local_shell_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import LocalShellExecAction, OutputItemLocalShellToolCall
|
|
|
|
item = OutputItemLocalShellToolCall({
|
|
"type": "local_shell_call",
|
|
"id": "lsc-1",
|
|
"call_id": "call_lsc",
|
|
"action": LocalShellExecAction({"type": "exec", "command": ["echo", "hello"], "env": {}}),
|
|
"status": "completed",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "shell_tool_call"
|
|
assert msg.contents[0].commands == ["echo", "hello"]
|
|
|
|
async def test_local_shell_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemLocalShellToolCallOutput
|
|
|
|
item = OutputItemLocalShellToolCallOutput({
|
|
"type": "local_shell_call_output",
|
|
"id": "lsco-1",
|
|
"output": "hello\n",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "shell_tool_result"
|
|
|
|
async def test_file_search_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemFileSearchToolCall
|
|
|
|
item = OutputItemFileSearchToolCall({
|
|
"type": "file_search_call",
|
|
"id": "fs-1",
|
|
"status": "completed",
|
|
"queries": ["what is AI"],
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "file_search"
|
|
assert '"what is AI"' in (msg.contents[0].arguments or "")
|
|
|
|
async def test_web_search_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemWebSearchToolCall, WebSearchActionSearch
|
|
|
|
item = OutputItemWebSearchToolCall({
|
|
"type": "web_search_call",
|
|
"id": "ws-1",
|
|
"status": "completed",
|
|
"action": WebSearchActionSearch({"type": "search", "query": "test"}),
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "web_search"
|
|
|
|
async def test_computer_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ComputerAction, OutputItemComputerToolCall
|
|
|
|
item = OutputItemComputerToolCall({
|
|
"type": "computer_call",
|
|
"id": "cc-1",
|
|
"call_id": "call_cc",
|
|
"action": ComputerAction({"type": "click"}),
|
|
"pending_safety_checks": [],
|
|
"status": "completed",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "computer_use"
|
|
|
|
async def test_computer_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import (
|
|
ComputerScreenshotImage,
|
|
OutputItemComputerToolCallOutputResource,
|
|
)
|
|
|
|
item = OutputItemComputerToolCallOutputResource({
|
|
"type": "computer_call_output",
|
|
"call_id": "call_cc",
|
|
"output": ComputerScreenshotImage({
|
|
"type": "computer_screenshot",
|
|
"image_url": "data:image/png;base64,abc",
|
|
}),
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].call_id == "call_cc"
|
|
|
|
async def test_custom_tool_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemCustomToolCall
|
|
|
|
item = OutputItemCustomToolCall({
|
|
"type": "custom_tool_call",
|
|
"call_id": "call_ct",
|
|
"name": "my_tool",
|
|
"input": '{"key": "value"}',
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "my_tool"
|
|
assert msg.contents[0].arguments == '{"key": "value"}'
|
|
|
|
async def test_custom_tool_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemCustomToolCallOutput
|
|
|
|
item = OutputItemCustomToolCallOutput({
|
|
"type": "custom_tool_call_output",
|
|
"call_id": "call_ct",
|
|
"output": "result text",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].result == "result text"
|
|
|
|
async def test_custom_tool_call_output_with_mcp_call_id_routes_to_mcp_server_tool_result(self) -> None:
|
|
"""When the host wrote a hosted-MCP result via
|
|
`aoutput_item_custom_tool_call_output`, the persisted call_id keeps
|
|
its `mcp_*` prefix. On read, that result must reconstruct as a
|
|
`mcp_server_tool_result` Content (not `function_result`), so the
|
|
chat-client serialize layer treats it as a hosted-MCP result and
|
|
does not produce an orphan `function_call_output`.
|
|
"""
|
|
from azure.ai.agentserver.responses.models import OutputItemCustomToolCallOutput
|
|
|
|
item = OutputItemCustomToolCallOutput({
|
|
"type": "custom_tool_call_output",
|
|
"call_id": "mcp_06b686e11f118cf40169f0e5badb3081979842929d5cf04920",
|
|
"output": "found 10 cats",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert len(msg.contents) == 1
|
|
c = msg.contents[0]
|
|
assert c.type == "mcp_server_tool_result", (
|
|
f"expected mcp_server_tool_result for mcp_-prefixed call_id; got {c.type}"
|
|
)
|
|
assert c.call_id == "mcp_06b686e11f118cf40169f0e5badb3081979842929d5cf04920"
|
|
|
|
async def test_apply_patch_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ApplyPatchUpdateFileOperation, OutputItemApplyPatchToolCall
|
|
|
|
item = OutputItemApplyPatchToolCall({
|
|
"type": "apply_patch_call",
|
|
"id": "ap-1",
|
|
"call_id": "call_ap",
|
|
"status": "completed",
|
|
"operation": ApplyPatchUpdateFileOperation({
|
|
"type": "update_file",
|
|
"path": "file.py",
|
|
"diff": "+ new line",
|
|
}),
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "apply_patch"
|
|
|
|
async def test_apply_patch_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemApplyPatchToolCallOutput
|
|
|
|
item = OutputItemApplyPatchToolCallOutput({
|
|
"type": "apply_patch_call_output",
|
|
"id": "apo-1",
|
|
"call_id": "call_ap",
|
|
"status": "completed",
|
|
"output": "patch applied",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].result == "patch applied"
|
|
|
|
async def test_oauth_consent_request(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OAuthConsentRequestOutputItem
|
|
|
|
item = OAuthConsentRequestOutputItem({
|
|
"type": "oauth_consent_request",
|
|
"id": "oauth-1",
|
|
"consent_link": "https://example.com/consent",
|
|
"server_label": "my_server",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "oauth_consent_request"
|
|
assert msg.contents[0].consent_link == "https://example.com/consent"
|
|
|
|
async def test_structured_outputs_dict(self) -> None:
|
|
from azure.ai.agentserver.responses.models import StructuredOutputsOutputItem
|
|
|
|
item = StructuredOutputsOutputItem({"type": "structured_outputs", "id": "so-1", "output": {"answer": 42}})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "text"
|
|
assert json.loads(msg.contents[0].text or "") == {"answer": 42}
|
|
|
|
async def test_structured_outputs_string(self) -> None:
|
|
from azure.ai.agentserver.responses.models import StructuredOutputsOutputItem
|
|
|
|
item = StructuredOutputsOutputItem({"type": "structured_outputs", "id": "so-2", "output": "plain text"})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].text == "plain text"
|
|
|
|
async def test_unsupported_type_raises(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItem
|
|
|
|
item = OutputItem({"type": "some_unknown_type"})
|
|
with pytest.raises(ValueError, match="Unsupported OutputItem type: some_unknown_type"):
|
|
await _output_item_to_message(item)
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region _item_to_message conversion
|
|
|
|
|
|
class TestItemToMessage:
|
|
"""Tests for _item_to_message covering all supported Item types."""
|
|
|
|
async def test_message_with_string_content(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMessage
|
|
|
|
item = ItemMessage({"type": "message", "role": "user", "content": "hello"})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "user"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].type == "text"
|
|
assert msg.contents[0].text == "hello"
|
|
|
|
async def test_message_with_input_text_content(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMessage, MessageContentInputTextContent
|
|
|
|
item = ItemMessage({
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [MessageContentInputTextContent({"type": "input_text", "text": "hi there"})],
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "user"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].text == "hi there"
|
|
|
|
async def test_message_with_multiple_contents(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMessage, MessageContentInputTextContent
|
|
|
|
item = ItemMessage({
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
MessageContentInputTextContent({"type": "input_text", "text": "first"}),
|
|
MessageContentInputTextContent({"type": "input_text", "text": "second"}),
|
|
],
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert len(msg.contents) == 2
|
|
assert msg.contents[0].text == "first"
|
|
assert msg.contents[1].text == "second"
|
|
|
|
async def test_output_message(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemOutputMessage, OutputMessageContentOutputTextContent
|
|
|
|
item = ItemOutputMessage({
|
|
"type": "output_message",
|
|
"role": "assistant",
|
|
"content": [OutputMessageContentOutputTextContent({"type": "output_text", "text": "response"})],
|
|
"status": "completed",
|
|
"id": "msg-1",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].type == "text"
|
|
assert msg.contents[0].text == "response"
|
|
|
|
async def test_function_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemFunctionToolCall
|
|
|
|
item = ItemFunctionToolCall({
|
|
"type": "function_call",
|
|
"call_id": "call_1",
|
|
"name": "get_weather",
|
|
"arguments": '{"city": "NYC"}',
|
|
"status": "completed",
|
|
"id": "fc-1",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].call_id == "call_1"
|
|
assert msg.contents[0].name == "get_weather"
|
|
assert msg.contents[0].arguments == '{"city": "NYC"}'
|
|
|
|
async def test_function_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import FunctionCallOutputItemParam
|
|
|
|
item = FunctionCallOutputItemParam({"type": "function_call_output", "call_id": "call_1", "output": "sunny"})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].call_id == "call_1"
|
|
assert msg.contents[0].result == "sunny"
|
|
|
|
async def test_function_call_output_non_string(self) -> None:
|
|
from azure.ai.agentserver.responses.models import FunctionCallOutputItemParam
|
|
|
|
item = FunctionCallOutputItemParam({"type": "function_call_output", "call_id": "call_2", "output": 42})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].result == "42"
|
|
|
|
async def test_reasoning_with_summary(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemReasoningItem, SummaryTextContent
|
|
|
|
item = ItemReasoningItem({
|
|
"type": "reasoning",
|
|
"id": "r-1",
|
|
"summary": [SummaryTextContent({"type": "summary_text", "text": "thinking hard"})],
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].text == "thinking hard"
|
|
|
|
async def test_reasoning_no_summary(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemReasoningItem
|
|
|
|
item = ItemReasoningItem({"type": "reasoning", "id": "r-2"})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents == []
|
|
|
|
async def test_mcp_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMcpToolCall
|
|
|
|
item = ItemMcpToolCall({
|
|
"type": "mcp_call",
|
|
"id": "mcp-1",
|
|
"server_label": "my_server",
|
|
"name": "search",
|
|
"arguments": '{"q": "test"}',
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "mcp_server_tool_call"
|
|
assert msg.contents[0].server_name == "my_server"
|
|
assert msg.contents[0].tool_name == "search"
|
|
|
|
async def test_mcp_approval_request(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMcpApprovalRequest
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = ItemMcpApprovalRequest({
|
|
"type": "mcp_approval_request",
|
|
"id": "apr-1",
|
|
"server_label": "srv",
|
|
"name": "dangerous_tool",
|
|
"arguments": "{}",
|
|
})
|
|
msg = await _item_to_message(item, approval_storage=storage)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_approval_request"
|
|
|
|
async def test_mcp_approval_response(self) -> None:
|
|
from azure.ai.agentserver.responses.models import MCPApprovalResponse
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = MCPApprovalResponse({
|
|
"type": "mcp_approval_response",
|
|
"approval_request_id": "apr-1",
|
|
"approve": True,
|
|
})
|
|
msg = await _item_to_message(item, approval_storage=storage) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "user"
|
|
assert msg.contents[0].type == "function_approval_response"
|
|
assert msg.contents[0].approved is True
|
|
|
|
async def test_code_interpreter_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemCodeInterpreterToolCall
|
|
|
|
item = ItemCodeInterpreterToolCall({
|
|
"type": "code_interpreter_call",
|
|
"id": "ci-1",
|
|
"status": "completed",
|
|
"container_id": "c-1",
|
|
"code": "print('hi')",
|
|
"outputs": [],
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "code_interpreter_tool_call"
|
|
|
|
async def test_image_generation_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemImageGenToolCall
|
|
|
|
item = ItemImageGenToolCall({"type": "image_generation_call", "id": "ig-1", "status": "completed"})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "image_generation_tool_call"
|
|
|
|
async def test_shell_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import FunctionShellAction, FunctionShellCallItemParam
|
|
|
|
item = FunctionShellCallItemParam({
|
|
"type": "shell_call",
|
|
"call_id": "call_sc",
|
|
"action": FunctionShellAction({"commands": ["ls", "-la"], "timeout_ms": 5000, "max_output_length": 1024}),
|
|
"status": "in_progress",
|
|
})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "shell_tool_call"
|
|
assert msg.contents[0].commands == ["ls", "-la"]
|
|
assert msg.contents[0].call_id == "call_sc"
|
|
|
|
async def test_shell_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import (
|
|
FunctionShellCallOutputContent,
|
|
FunctionShellCallOutputExitOutcome,
|
|
FunctionShellCallOutputItemParam,
|
|
)
|
|
|
|
item = FunctionShellCallOutputItemParam({
|
|
"type": "shell_call_output",
|
|
"call_id": "call_sc",
|
|
"output": [
|
|
FunctionShellCallOutputContent({
|
|
"stdout": "file.txt",
|
|
"stderr": "",
|
|
"outcome": FunctionShellCallOutputExitOutcome({"exit_code": 0}),
|
|
})
|
|
],
|
|
"max_output_length": 1024,
|
|
})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "shell_tool_result"
|
|
assert msg.contents[0].call_id == "call_sc"
|
|
|
|
async def test_local_shell_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemLocalShellToolCall, LocalShellExecAction
|
|
|
|
item = ItemLocalShellToolCall({
|
|
"type": "local_shell_call",
|
|
"id": "lsc-1",
|
|
"call_id": "call_lsc",
|
|
"action": LocalShellExecAction({"type": "exec", "command": ["echo", "hello"], "env": {}}),
|
|
"status": "completed",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "shell_tool_call"
|
|
assert msg.contents[0].commands == ["echo", "hello"]
|
|
|
|
async def test_local_shell_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemLocalShellToolCallOutput
|
|
|
|
item = ItemLocalShellToolCallOutput({
|
|
"type": "local_shell_call_output",
|
|
"id": "lsco-1",
|
|
"output": "hello\n",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "shell_tool_result"
|
|
|
|
async def test_file_search_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemFileSearchToolCall
|
|
|
|
item = ItemFileSearchToolCall({
|
|
"type": "file_search_call",
|
|
"id": "fs-1",
|
|
"status": "completed",
|
|
"queries": ["what is AI"],
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "file_search"
|
|
assert '"what is AI"' in (msg.contents[0].arguments or "")
|
|
|
|
async def test_web_search_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemWebSearchToolCall
|
|
|
|
item = ItemWebSearchToolCall({
|
|
"type": "web_search_call",
|
|
"id": "ws-1",
|
|
"status": "completed",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "web_search"
|
|
|
|
async def test_computer_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ComputerAction, ItemComputerToolCall
|
|
|
|
item = ItemComputerToolCall({
|
|
"type": "computer_call",
|
|
"id": "cc-1",
|
|
"call_id": "call_cc",
|
|
"action": ComputerAction({"type": "click"}),
|
|
"pending_safety_checks": [],
|
|
"status": "completed",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "computer_use"
|
|
|
|
async def test_computer_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ComputerCallOutputItemParam, ComputerScreenshotImage
|
|
|
|
item = ComputerCallOutputItemParam({
|
|
"type": "computer_call_output",
|
|
"call_id": "call_cc",
|
|
"output": ComputerScreenshotImage({
|
|
"type": "computer_screenshot",
|
|
"image_url": "data:image/png;base64,abc",
|
|
}),
|
|
})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].call_id == "call_cc"
|
|
|
|
async def test_custom_tool_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemCustomToolCall
|
|
|
|
item = ItemCustomToolCall({
|
|
"type": "custom_tool_call",
|
|
"call_id": "call_ct",
|
|
"name": "my_tool",
|
|
"input": '{"key": "value"}',
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "my_tool"
|
|
assert msg.contents[0].arguments == '{"key": "value"}'
|
|
|
|
async def test_custom_tool_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemCustomToolCallOutput
|
|
|
|
item = ItemCustomToolCallOutput({
|
|
"type": "custom_tool_call_output",
|
|
"call_id": "call_ct",
|
|
"output": "result text",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].result == "result text"
|
|
|
|
async def test_custom_tool_call_output_non_string(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemCustomToolCallOutput
|
|
|
|
item = ItemCustomToolCallOutput({
|
|
"type": "custom_tool_call_output",
|
|
"call_id": "call_ct2",
|
|
"output": 123,
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.contents[0].result == "123"
|
|
|
|
async def test_custom_tool_call_output_with_mcp_call_id_routes_to_mcp_server_tool_result(self) -> None:
|
|
"""Issue #5546: input items carrying a hosted-MCP result (from a
|
|
prior turn that the framework wrote via
|
|
`aoutput_item_custom_tool_call_output`) must reconstruct as a
|
|
`mcp_server_tool_result` Content, not `function_result`. Otherwise
|
|
the chat-client serialize layer turns it into an orphan
|
|
`function_call_output` with `mcp_*` call_id and the Responses API
|
|
rejects the next turn.
|
|
"""
|
|
from azure.ai.agentserver.responses.models import ItemCustomToolCallOutput
|
|
|
|
item = ItemCustomToolCallOutput({
|
|
"type": "custom_tool_call_output",
|
|
"call_id": "mcp_06b686e11f118cf40169f0e5badb3081979842929d5cf04920",
|
|
"output": "found 10 cats",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert len(msg.contents) == 1
|
|
c = msg.contents[0]
|
|
assert c.type == "mcp_server_tool_result", (
|
|
f"expected mcp_server_tool_result for mcp_-prefixed call_id; got {c.type}"
|
|
)
|
|
assert c.call_id == "mcp_06b686e11f118cf40169f0e5badb3081979842929d5cf04920"
|
|
|
|
async def test_apply_patch_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ApplyPatchToolCallItemParam, ApplyPatchUpdateFileOperation
|
|
|
|
item = ApplyPatchToolCallItemParam({
|
|
"type": "apply_patch_call",
|
|
"call_id": "call_ap",
|
|
"operation": ApplyPatchUpdateFileOperation({
|
|
"type": "update_file",
|
|
"path": "file.py",
|
|
"diff": "+ new line",
|
|
}),
|
|
})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "apply_patch"
|
|
|
|
async def test_apply_patch_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ApplyPatchToolCallOutputItemParam
|
|
|
|
item = ApplyPatchToolCallOutputItemParam({
|
|
"type": "apply_patch_call_output",
|
|
"call_id": "call_ap",
|
|
"output": "patch applied",
|
|
})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].result == "patch applied"
|
|
|
|
async def test_unsupported_type_raises(self) -> None:
|
|
from azure.ai.agentserver.responses.models import Item
|
|
|
|
item = Item({"type": "some_unknown_type"})
|
|
with pytest.raises(ValueError, match="Unsupported Item type: some_unknown_type"):
|
|
await _item_to_message(item)
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Multi-turn with mixed content
|
|
|
|
|
|
async def _post_json(
|
|
server: ResponsesHostServer,
|
|
payload: dict[str, Any],
|
|
) -> httpx.Response:
|
|
"""Send a POST /responses request with a raw JSON payload."""
|
|
transport = httpx.ASGITransport(app=server)
|
|
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
|
return await client.post("/responses", json=payload)
|
|
|
|
|
|
def _make_multi_response_agent(
|
|
responses: list[AgentResponse],
|
|
stream_updates_list: list[list[AgentResponseUpdate]] | None = None,
|
|
) -> MagicMock:
|
|
"""Create a mock agent that returns different responses on successive calls."""
|
|
agent = MagicMock(spec=RawAgent)
|
|
agent.id = "test-agent"
|
|
agent.name = "Test Agent"
|
|
agent.description = "A mock agent for testing"
|
|
agent.context_providers = []
|
|
|
|
call_index = [0]
|
|
|
|
async def run_non_streaming(*args: Any, **kwargs: Any) -> AgentResponse:
|
|
idx = call_index[0]
|
|
call_index[0] += 1
|
|
return responses[idx]
|
|
|
|
async def _stream_gen(updates: list[AgentResponseUpdate]) -> AsyncIterator[AgentResponseUpdate]:
|
|
for update in updates:
|
|
yield update
|
|
|
|
def run_dispatch(*args: Any, **kwargs: Any) -> Any:
|
|
idx = call_index[0]
|
|
call_index[0] += 1
|
|
if kwargs.get("stream") and stream_updates_list is not None:
|
|
return ResponseStream(_stream_gen(stream_updates_list[idx])) # type: ignore
|
|
if not kwargs.get("stream"):
|
|
# Need to return a coroutine for non-streaming
|
|
async def _ret() -> AgentResponse:
|
|
return responses[idx]
|
|
|
|
return _ret()
|
|
raise NotImplementedError("Streaming not configured for this call index")
|
|
|
|
if stream_updates_list is not None:
|
|
agent.run = MagicMock(side_effect=run_dispatch)
|
|
else:
|
|
agent.run = AsyncMock(side_effect=run_non_streaming)
|
|
|
|
return agent
|
|
|
|
|
|
class TestMultiTurnMixedContent:
|
|
"""End-to-end multi-turn tests with mixed text and non-text content types."""
|
|
|
|
async def test_text_and_image_input_single_turn(self) -> None:
|
|
"""Agent receives a message with text and image content via URL."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("I see a cat!")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Describe this animal"},
|
|
{"type": "input_image", "image_url": "https://example.com/cat.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
# Verify agent received text + image
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert messages[0].role == "user"
|
|
assert len(messages[0].contents) == 2
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "Describe this animal"
|
|
assert messages[0].contents[1].type == "uri"
|
|
assert messages[0].contents[1].uri == "https://example.com/cat.jpg"
|
|
|
|
async def test_text_and_file_input_single_turn(self) -> None:
|
|
"""Agent receives a message with text and file content via URL."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("File received")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Summarize this document"},
|
|
{"type": "input_file", "file_url": "https://example.com/doc.pdf", "filename": "doc.pdf"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert len(messages[0].contents) == 2
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "Summarize this document"
|
|
assert messages[0].contents[1].type == "uri"
|
|
assert messages[0].contents[1].uri == "https://example.com/doc.pdf"
|
|
|
|
async def test_text_and_file_data_input_single_turn(self) -> None:
|
|
"""Agent receives a message with text and file content via inline file_data."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("File received")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Summarize this document"},
|
|
{
|
|
"type": "input_file",
|
|
"file_data": "data:application/pdf;base64,JVBERi0xLjQ=",
|
|
"filename": "doc.pdf",
|
|
},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert len(messages[0].contents) == 2
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "Summarize this document"
|
|
assert messages[0].contents[1].type == "data"
|
|
assert messages[0].contents[1].uri == "data:application/pdf;base64,JVBERi0xLjQ="
|
|
|
|
async def test_text_mime_file_data_decoded(self) -> None:
|
|
"""Agent receives a text/* file_data that is base64-decoded to plain text."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Got it")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
import base64
|
|
|
|
encoded = base64.b64encode(b"Hello, world!").decode()
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{
|
|
"type": "input_file",
|
|
"file_data": f"data:text/plain;base64,{encoded}",
|
|
"filename": "greeting.txt",
|
|
},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "[File: greeting.txt]\nHello, world!"
|
|
|
|
async def test_text_mime_file_data_invalid_base64_falls_through(self) -> None:
|
|
"""Invalid base64 in a text/* file_data falls through to URI passthrough."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Got it")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{
|
|
"type": "input_file",
|
|
"file_data": "data:text/plain;base64,!!!invalid!!!",
|
|
"filename": "bad.txt",
|
|
},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert messages[0].contents[0].type == "data"
|
|
assert messages[0].contents[0].uri == "data:text/plain;base64,!!!invalid!!!"
|
|
|
|
async def test_mixed_text_and_image_input(self) -> None:
|
|
"""Agent receives a single message with both text and image content."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Got it!")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "What's in this image?"},
|
|
{"type": "input_image", "image_url": "https://example.com/photo.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert len(messages[0].contents) == 2
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "What's in this image?"
|
|
assert messages[0].contents[1].type == "uri"
|
|
assert messages[0].contents[1].uri == "https://example.com/photo.jpg"
|
|
|
|
async def test_function_call_items_in_input(self) -> None:
|
|
"""Input contains function_call and function_call_output items."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(
|
|
messages=[Message(role="assistant", contents=[Content.from_text("Weather is sunny!")])]
|
|
)
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{"type": "message", "role": "user", "content": "What's the weather?"},
|
|
{
|
|
"type": "function_call",
|
|
"id": "fc-1",
|
|
"call_id": "call_1",
|
|
"name": "get_weather",
|
|
"arguments": '{"city": "NYC"}',
|
|
"status": "completed",
|
|
},
|
|
{"type": "function_call_output", "call_id": "call_1", "output": "sunny, 72F"},
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 3
|
|
assert messages[0].role == "user"
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[1].role == "assistant"
|
|
assert messages[1].contents[0].type == "function_call"
|
|
assert messages[1].contents[0].name == "get_weather"
|
|
assert messages[2].role == "tool"
|
|
assert messages[2].contents[0].type == "function_result"
|
|
assert messages[2].contents[0].result == "sunny, 72F"
|
|
|
|
async def test_multi_turn_text_then_text_with_image(self) -> None:
|
|
"""First turn sends text, second turn sends text + image with previous_response_id."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Send me an image")])]),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Nice cat!")])]),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1: simple text
|
|
resp1 = await _post(server, input_text="Hello", stream=False)
|
|
assert resp1.status_code == 200
|
|
response_id = resp1.json()["id"]
|
|
|
|
# Turn 2: text + image input referencing turn 1
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Here is my cat photo"},
|
|
{"type": "input_image", "image_url": "https://example.com/cat.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
"previous_response_id": response_id,
|
|
},
|
|
)
|
|
|
|
assert resp2.status_code == 200
|
|
body2 = resp2.json()
|
|
assert body2["status"] == "completed"
|
|
|
|
# Verify second call receives history from turn 1 + text+image input
|
|
second_call_messages = agent.run.call_args_list[1].kwargs["messages"]
|
|
# History: output message from turn 1 ("Send me an image")
|
|
# Input: message with text + image
|
|
assert len(second_call_messages) >= 2
|
|
# Last message should be the text+image input
|
|
last_msg = second_call_messages[-1]
|
|
assert last_msg.role == "user"
|
|
assert len(last_msg.contents) == 2
|
|
assert last_msg.contents[0].type == "text"
|
|
assert last_msg.contents[0].text == "Here is my cat photo"
|
|
assert last_msg.contents[1].type == "uri"
|
|
assert last_msg.contents[1].uri == "https://example.com/cat.jpg"
|
|
# History should include the assistant response from turn 1
|
|
history_msgs = second_call_messages[:-1]
|
|
assistant_texts = [
|
|
c.text for m in history_msgs if m.role == "assistant" for c in m.contents if c.type == "text"
|
|
]
|
|
assert "Send me an image" in assistant_texts
|
|
|
|
async def test_multi_turn_function_call_in_history(self) -> None:
|
|
"""Turn 1 produces function call + result, turn 2 sees them in history."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[Content.from_function_call("call_1", "search", arguments='{"q": "cats"}')],
|
|
),
|
|
Message(role="tool", contents=[Content.from_function_result("call_1", result="found 10 cats")]),
|
|
Message(role="assistant", contents=[Content.from_text("I found 10 cats!")]),
|
|
]
|
|
),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Here are more details")])]),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1
|
|
resp1 = await _post(server, input_text="Search for cats", stream=False)
|
|
assert resp1.status_code == 200
|
|
response_id = resp1.json()["id"]
|
|
|
|
# Verify turn 1 output has function_call, function_call_output, and message
|
|
types1 = [item["type"] for item in resp1.json()["output"]]
|
|
assert "function_call" in types1
|
|
assert "function_call_output" in types1
|
|
assert "message" in types1
|
|
|
|
# Turn 2
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": "Tell me more",
|
|
"stream": False,
|
|
"previous_response_id": response_id,
|
|
},
|
|
)
|
|
assert resp2.status_code == 200
|
|
assert resp2.json()["status"] == "completed"
|
|
|
|
# Verify turn 2 received history including function call/result
|
|
second_call_messages = agent.run.call_args_list[1].kwargs["messages"]
|
|
roles = [m.role for m in second_call_messages]
|
|
assert "assistant" in roles
|
|
assert "tool" in roles
|
|
# The function call should be in the history
|
|
fc_contents = [
|
|
c for m in second_call_messages if m.role == "assistant" for c in m.contents if c.type == "function_call"
|
|
]
|
|
assert len(fc_contents) >= 1
|
|
assert fc_contents[0].name == "search"
|
|
|
|
async def test_multi_turn_reasoning_in_history(self) -> None:
|
|
"""Turn 1 produces reasoning + text, turn 2 sees them in history."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[
|
|
Content.from_text_reasoning(text="Let me think about this..."),
|
|
Content.from_text("The answer is 42"),
|
|
],
|
|
),
|
|
]
|
|
),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Indeed, it is 42")])]),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1
|
|
resp1 = await _post(server, input_text="What is the answer?", stream=False)
|
|
assert resp1.status_code == 200
|
|
response_id = resp1.json()["id"]
|
|
types1 = [item["type"] for item in resp1.json()["output"]]
|
|
assert "reasoning" in types1
|
|
assert "message" in types1
|
|
|
|
# Turn 2
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": "Are you sure?",
|
|
"stream": False,
|
|
"previous_response_id": response_id,
|
|
},
|
|
)
|
|
assert resp2.status_code == 200
|
|
assert resp2.json()["status"] == "completed"
|
|
|
|
# Verify history includes the reasoning and text from turn 1
|
|
second_call_messages = agent.run.call_args_list[1].kwargs["messages"]
|
|
assert len(second_call_messages) >= 2 # history + new input
|
|
|
|
async def test_multi_turn_with_mixed_content_and_streaming(self) -> None:
|
|
"""Turn 1 non-streaming, turn 2 streaming with image input."""
|
|
turn2_updates = [
|
|
AgentResponseUpdate(contents=[Content.from_text("I see ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text("a cat!")], role="assistant"),
|
|
]
|
|
|
|
agent = _make_multi_response_agent(
|
|
responses=[
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Send me an image")])]),
|
|
AgentResponse(messages=[]), # placeholder, not used for streaming
|
|
],
|
|
stream_updates_list=[
|
|
[], # placeholder for turn 1 (non-streaming)
|
|
turn2_updates,
|
|
],
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1: non-streaming text
|
|
resp1 = await _post(server, input_text="Hello", stream=False)
|
|
assert resp1.status_code == 200
|
|
response_id = resp1.json()["id"]
|
|
|
|
# Turn 2: streaming with image input
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Describe this:"},
|
|
{"type": "input_image", "image_url": "https://example.com/cat.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": True,
|
|
"previous_response_id": response_id,
|
|
},
|
|
)
|
|
|
|
assert resp2.status_code == 200
|
|
assert "text/event-stream" in resp2.headers["content-type"]
|
|
|
|
events = _parse_sse_events(resp2.text)
|
|
types = _sse_event_types(events)
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
assert "response.output_text.delta" in types
|
|
|
|
# Verify accumulated text
|
|
text_done = [e for e in events if e["event"] == "response.output_text.done"]
|
|
assert len(text_done) == 1
|
|
assert text_done[0]["data"]["text"] == "I see a cat!"
|
|
|
|
async def test_text_with_mcp_call_items(self) -> None:
|
|
"""Input contains text message + mcp_call item and the agent processes it."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(
|
|
messages=[Message(role="assistant", contents=[Content.from_text("MCP result received")])]
|
|
)
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{"type": "message", "role": "user", "content": "Search using MCP"},
|
|
{
|
|
"type": "mcp_call",
|
|
"id": "mcp-1",
|
|
"server_label": "my_server",
|
|
"name": "search",
|
|
"arguments": '{"query": "test"}',
|
|
},
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 2
|
|
assert messages[0].role == "user"
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "Search using MCP"
|
|
assert messages[1].role == "assistant"
|
|
assert messages[1].contents[0].type == "mcp_server_tool_call"
|
|
assert messages[1].contents[0].server_name == "my_server"
|
|
assert messages[1].contents[0].tool_name == "search"
|
|
|
|
async def test_three_turn_conversation_with_mixed_content(self) -> None:
|
|
"""Three-turn conversation: text → function call → image input."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Hello! How can I help?")])]),
|
|
AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[Content.from_function_call("call_1", "analyze", arguments='{"mode": "deep"}')],
|
|
),
|
|
Message(role="tool", contents=[Content.from_function_result("call_1", result="analysis complete")]),
|
|
Message(role="assistant", contents=[Content.from_text("Analysis done!")]),
|
|
]
|
|
),
|
|
AgentResponse(
|
|
messages=[Message(role="assistant", contents=[Content.from_text("The image shows a chart")])]
|
|
),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1: text
|
|
resp1 = await _post(server, input_text="Hi", stream=False)
|
|
assert resp1.status_code == 200
|
|
id1 = resp1.json()["id"]
|
|
|
|
# Turn 2: text, referencing turn 1
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": "Analyze something",
|
|
"stream": False,
|
|
"previous_response_id": id1,
|
|
},
|
|
)
|
|
assert resp2.status_code == 200
|
|
id2 = resp2.json()["id"]
|
|
|
|
# Turn 3: image input, referencing turn 2
|
|
resp3 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "What about this image?"},
|
|
{"type": "input_image", "image_url": "https://example.com/chart.png"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
"previous_response_id": id2,
|
|
},
|
|
)
|
|
|
|
assert resp3.status_code == 200
|
|
assert resp3.json()["status"] == "completed"
|
|
|
|
# Verify turn 3 received full history from turns 1+2 plus new image input
|
|
third_call_messages = agent.run.call_args_list[2].kwargs["messages"]
|
|
# Should have: history from turn 1 (assistant text) + history from turn 2
|
|
# (function_call, function_call_output, text) + new input (text + image)
|
|
assert len(third_call_messages) >= 5
|
|
|
|
# Last message should contain the image
|
|
last_msg = third_call_messages[-1]
|
|
assert last_msg.role == "user"
|
|
image_contents = [c for c in last_msg.contents if c.type == "uri"]
|
|
assert len(image_contents) == 1
|
|
assert image_contents[0].uri == "https://example.com/chart.png"
|
|
|
|
# History should include function call from turn 2
|
|
fc_contents = [
|
|
c
|
|
for m in third_call_messages[:-1]
|
|
if m.role == "assistant"
|
|
for c in m.contents
|
|
if c.type == "function_call"
|
|
]
|
|
assert any(c.name == "analyze" for c in fc_contents)
|
|
|
|
async def test_input_with_hosted_file_image(self) -> None:
|
|
"""Input contains an image referenced by file_id (hosted file)."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Image analyzed")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Analyze this image"},
|
|
{"type": "input_image", "file_id": "file-abc123"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert len(messages[0].contents) == 2
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "Analyze this image"
|
|
assert messages[0].contents[1].type == "hosted_file"
|
|
assert messages[0].contents[1].file_id == "file-abc123"
|
|
|
|
async def test_multi_turn_text_and_image_then_text_and_file(self) -> None:
|
|
"""Turn 1 sends text+image, turn 2 sends text+file, both in history."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("I see a landscape")])]),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Document summarized")])]),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1: text + image
|
|
resp1 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "What is in this photo?"},
|
|
{"type": "input_image", "image_url": "https://example.com/landscape.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
assert resp1.status_code == 200
|
|
id1 = resp1.json()["id"]
|
|
|
|
# Turn 2: text + file, referencing turn 1
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Now summarize this report"},
|
|
{
|
|
"type": "input_file",
|
|
"file_url": "https://example.com/report.pdf",
|
|
"filename": "report.pdf",
|
|
},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
"previous_response_id": id1,
|
|
},
|
|
)
|
|
assert resp2.status_code == 200
|
|
assert resp2.json()["status"] == "completed"
|
|
|
|
# Verify turn 2 received history from turn 1 + new text+file input
|
|
second_call_messages = agent.run.call_args_list[1].kwargs["messages"]
|
|
assert len(second_call_messages) >= 2
|
|
|
|
# History should include the assistant response from turn 1
|
|
assistant_texts = [
|
|
c.text for m in second_call_messages if m.role == "assistant" for c in m.contents if c.type == "text"
|
|
]
|
|
assert "I see a landscape" in assistant_texts
|
|
|
|
# Last message should be text + file
|
|
last_msg = second_call_messages[-1]
|
|
assert last_msg.role == "user"
|
|
assert len(last_msg.contents) == 2
|
|
assert last_msg.contents[0].type == "text"
|
|
assert last_msg.contents[0].text == "Now summarize this report"
|
|
assert last_msg.contents[1].type == "uri"
|
|
assert last_msg.contents[1].uri == "https://example.com/report.pdf"
|
|
|
|
async def test_multi_turn_function_call_then_text_and_image(self) -> None:
|
|
"""Turn 1: text + function call + result, turn 2: text + image."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[Content.from_function_call("call_1", "get_info", arguments='{"id": 1}')],
|
|
),
|
|
Message(role="tool", contents=[Content.from_function_result("call_1", result="info data")]),
|
|
Message(role="assistant", contents=[Content.from_text("Here is the info")]),
|
|
]
|
|
),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Image matches the data")])]),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1: text triggers function call
|
|
resp1 = await _post(server, input_text="Get info for item 1", stream=False)
|
|
assert resp1.status_code == 200
|
|
id1 = resp1.json()["id"]
|
|
|
|
types1 = [item["type"] for item in resp1.json()["output"]]
|
|
assert "function_call" in types1
|
|
assert "function_call_output" in types1
|
|
assert "message" in types1
|
|
|
|
# Turn 2: text + image referencing turn 1
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Does this image match?"},
|
|
{"type": "input_image", "image_url": "https://example.com/item1.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
"previous_response_id": id1,
|
|
},
|
|
)
|
|
assert resp2.status_code == 200
|
|
assert resp2.json()["status"] == "completed"
|
|
|
|
# Verify turn 2 received history with function call + new text+image
|
|
second_call_messages = agent.run.call_args_list[1].kwargs["messages"]
|
|
# History should contain function_call and function_result from turn 1
|
|
fc_contents = [
|
|
c for m in second_call_messages if m.role == "assistant" for c in m.contents if c.type == "function_call"
|
|
]
|
|
assert any(c.name == "get_info" for c in fc_contents)
|
|
tool_contents = [
|
|
c for m in second_call_messages if m.role == "tool" for c in m.contents if c.type == "function_result"
|
|
]
|
|
assert any(c.result == "info data" for c in tool_contents)
|
|
|
|
# Last message should be text + image
|
|
last_msg = second_call_messages[-1]
|
|
assert last_msg.role == "user"
|
|
assert len(last_msg.contents) == 2
|
|
assert last_msg.contents[0].type == "text"
|
|
assert last_msg.contents[0].text == "Does this image match?"
|
|
assert last_msg.contents[1].type == "uri"
|
|
assert last_msg.contents[1].uri == "https://example.com/item1.jpg"
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Function approval round-trip
|
|
|
|
|
|
class TestFunctionApprovalStorage:
|
|
"""Unit tests for the function approval storage classes."""
|
|
|
|
async def test_in_memory_save_and_load(self) -> None:
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
request = _make_function_approval_request_content(request_id="apr_1")
|
|
await storage.save_approval_request("apr_1", request)
|
|
loaded = await storage.load_approval_request("apr_1")
|
|
assert loaded.type == "function_approval_request"
|
|
assert loaded.id == "apr_1" # type: ignore[attr-defined]
|
|
|
|
async def test_in_memory_duplicate_save_raises(self) -> None:
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
request = _make_function_approval_request_content(request_id="apr_1")
|
|
await storage.save_approval_request("apr_1", request)
|
|
with pytest.raises(ValueError, match="already exists"):
|
|
await storage.save_approval_request("apr_1", request)
|
|
|
|
async def test_in_memory_missing_load_raises(self) -> None:
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
with pytest.raises(KeyError):
|
|
await storage.load_approval_request("missing")
|
|
|
|
async def test_file_based_save_and_load_persists_across_instances(self, tmp_path: Any) -> None:
|
|
path = tmp_path / "subdir" / "approvals.json"
|
|
storage = FileBasedFunctionApprovalStorage(str(path))
|
|
request = _make_function_approval_request_content(request_id="apr_1")
|
|
await storage.save_approval_request("apr_1", request)
|
|
|
|
# Directory + file should now exist.
|
|
assert path.exists()
|
|
|
|
# A new instance pointing at the same path can load the saved entry.
|
|
storage2 = FileBasedFunctionApprovalStorage(str(path))
|
|
loaded = await storage2.load_approval_request("apr_1")
|
|
assert loaded.type == "function_approval_request"
|
|
assert loaded.id == "apr_1" # type: ignore[attr-defined]
|
|
# The embedded function_call survives the round trip.
|
|
assert loaded.function_call.name == "delete_file" # type: ignore[attr-defined]
|
|
|
|
async def test_file_based_duplicate_save_raises(self, tmp_path: Any) -> None:
|
|
path = tmp_path / "approvals.json"
|
|
storage = FileBasedFunctionApprovalStorage(str(path))
|
|
request = _make_function_approval_request_content(request_id="apr_1")
|
|
await storage.save_approval_request("apr_1", request)
|
|
with pytest.raises(ValueError, match="already exists"):
|
|
await storage.save_approval_request("apr_1", request)
|
|
|
|
async def test_file_based_missing_load_raises(self, tmp_path: Any) -> None:
|
|
path = tmp_path / "approvals.json"
|
|
storage = FileBasedFunctionApprovalStorage(str(path))
|
|
with pytest.raises(KeyError):
|
|
await storage.load_approval_request("missing")
|
|
|
|
|
|
class TestFunctionApprovalConversion:
|
|
"""Tests for the approval-aware paths in `_item_to_message` / `_output_item_to_message`."""
|
|
|
|
async def test_output_item_mcp_approval_request_loads_from_storage(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalRequest
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = OutputItemMcpApprovalRequest({
|
|
"type": "mcp_approval_request",
|
|
"id": "apr-1",
|
|
"server_label": "srv",
|
|
"name": "dangerous_tool",
|
|
"arguments": "{}",
|
|
})
|
|
msg = await _output_item_to_message(item, approval_storage=storage)
|
|
assert msg.role == "assistant"
|
|
c = msg.contents[0]
|
|
assert c.type == "function_approval_request"
|
|
assert c.id == "apr-1" # type: ignore[attr-defined]
|
|
# The full saved Content (incl. function_call) is restored.
|
|
assert c.function_call.name == "delete_file" # type: ignore[attr-defined]
|
|
|
|
async def test_output_item_mcp_approval_request_without_storage_raises(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalRequest
|
|
|
|
item = OutputItemMcpApprovalRequest({
|
|
"type": "mcp_approval_request",
|
|
"id": "apr-1",
|
|
"server_label": "srv",
|
|
"name": "dangerous_tool",
|
|
"arguments": "{}",
|
|
})
|
|
with pytest.raises(ValueError, match="ApprovalStorage is required"):
|
|
await _output_item_to_message(item)
|
|
|
|
async def test_output_item_mcp_approval_response_resolves_to_approval_response(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalResponseResource
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = OutputItemMcpApprovalResponseResource({
|
|
"type": "mcp_approval_response",
|
|
"id": "resp-1",
|
|
"approval_request_id": "apr-1",
|
|
"approve": True,
|
|
})
|
|
msg = await _output_item_to_message(item, approval_storage=storage)
|
|
assert msg.role == "user"
|
|
c = msg.contents[0]
|
|
assert c.type == "function_approval_response"
|
|
assert c.approved is True # type: ignore[attr-defined]
|
|
assert c.id == "apr-1" # type: ignore[attr-defined]
|
|
assert c.function_call.name == "delete_file" # type: ignore[attr-defined]
|
|
|
|
async def test_output_item_mcp_approval_response_without_storage_raises(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalResponseResource
|
|
|
|
item = OutputItemMcpApprovalResponseResource({
|
|
"type": "mcp_approval_response",
|
|
"id": "resp-1",
|
|
"approval_request_id": "apr-1",
|
|
"approve": False,
|
|
})
|
|
with pytest.raises(ValueError, match="ApprovalStorage is required"):
|
|
await _output_item_to_message(item)
|
|
|
|
async def test_input_item_mcp_approval_request_loads_from_storage(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMcpApprovalRequest
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = ItemMcpApprovalRequest({
|
|
"type": "mcp_approval_request",
|
|
"id": "apr-1",
|
|
"server_label": "srv",
|
|
"name": "dangerous_tool",
|
|
"arguments": "{}",
|
|
})
|
|
msg = await _item_to_message(item, approval_storage=storage)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_approval_request"
|
|
assert msg.contents[0].id == "apr-1" # type: ignore[attr-defined]
|
|
|
|
async def test_input_item_mcp_approval_response_resolves_to_approval_response(self) -> None:
|
|
from azure.ai.agentserver.responses.models import MCPApprovalResponse
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = MCPApprovalResponse({
|
|
"type": "mcp_approval_response",
|
|
"approval_request_id": "apr-1",
|
|
"approve": False,
|
|
})
|
|
msg = await _item_to_message(item, approval_storage=storage) # type: ignore[arg-type]
|
|
assert msg.role == "user"
|
|
c = msg.contents[0]
|
|
assert c.type == "function_approval_response"
|
|
assert c.approved is False # type: ignore[attr-defined]
|
|
|
|
|
|
class TestFunctionApprovalRoundTrip:
|
|
"""End-to-end round-trip tests for the function approval flow.
|
|
|
|
Turn 1: the agent emits a `function_approval_request` content; the
|
|
server emits an `mcp_approval_request` output item and persists
|
|
the original Content under the emitted id in approval storage.
|
|
Turn 2: the caller sends an `mcp_approval_response` input item back;
|
|
the server resolves it (via approval storage) into a
|
|
`function_approval_response` content delivered to the agent.
|
|
"""
|
|
|
|
async def test_non_streaming_emits_mcp_approval_request_and_persists_to_storage(self) -> None:
|
|
request_content = _make_function_approval_request_content()
|
|
agent = _make_agent(response=AgentResponse(messages=[Message(role="assistant", contents=[request_content])]))
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post(server, stream=False)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
approval_items = [item for item in body["output"] if item["type"] == "mcp_approval_request"]
|
|
assert len(approval_items) == 1
|
|
approval_request_id = approval_items[0]["id"]
|
|
assert approval_items[0]["name"] == "delete_file"
|
|
assert approval_items[0]["server_label"] == "my_server"
|
|
|
|
# Storage must contain a saved entry under the emitted request id.
|
|
loaded = await server._approval_storage.load_approval_request( # pyright: ignore[reportPrivateUsage]
|
|
approval_request_id
|
|
)
|
|
assert loaded.type == "function_approval_request"
|
|
assert loaded.function_call.name == "delete_file" # type: ignore[attr-defined]
|
|
|
|
async def test_streaming_emits_mcp_approval_request_and_persists_to_storage(self) -> None:
|
|
request_content = _make_function_approval_request_content(request_id="apr_streaming")
|
|
agent = _make_agent(stream_updates=[AgentResponseUpdate(contents=[request_content], role="assistant")])
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
|
|
approval_request_id: str | None = None
|
|
for e in events:
|
|
if e["event"] != "response.output_item.added":
|
|
continue
|
|
item = e["data"].get("item") or {}
|
|
if item.get("type") == "mcp_approval_request":
|
|
approval_request_id = item.get("id")
|
|
break
|
|
assert approval_request_id is not None
|
|
|
|
loaded = await server._approval_storage.load_approval_request( # pyright: ignore[reportPrivateUsage]
|
|
approval_request_id
|
|
)
|
|
assert loaded.type == "function_approval_request"
|
|
|
|
async def test_round_trip_approval_response_reaches_agent(self) -> None:
|
|
"""Two-turn: turn 1 emits an approval request; turn 2 sends an
|
|
approval response and the agent receives a `function_approval_response`."""
|
|
request_content = _make_function_approval_request_content()
|
|
|
|
agent = _make_multi_response_agent(
|
|
responses=[
|
|
AgentResponse(messages=[Message(role="assistant", contents=[request_content])]),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("done")])]),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
first = await _post(server, stream=False)
|
|
assert first.status_code == 200
|
|
first_body = first.json()
|
|
approval_items = [item for item in first_body["output"] if item["type"] == "mcp_approval_request"]
|
|
assert len(approval_items) == 1
|
|
approval_request_id = approval_items[0]["id"]
|
|
|
|
# Send back an approval response that references the saved request id.
|
|
second_payload: dict[str, Any] = {
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "mcp_approval_response",
|
|
"approval_request_id": approval_request_id,
|
|
"approve": True,
|
|
}
|
|
],
|
|
"stream": False,
|
|
}
|
|
second = await _post_json(server, second_payload)
|
|
assert second.status_code == 200
|
|
|
|
# The agent's second invocation must have received a
|
|
# function_approval_response content carrying the original function_call.
|
|
assert agent.run.call_count == 2
|
|
second_call_kwargs = agent.run.call_args_list[1].kwargs
|
|
approval_responses = [
|
|
c for m in second_call_kwargs["messages"] for c in m.contents if c.type == "function_approval_response"
|
|
]
|
|
assert len(approval_responses) == 1
|
|
assert approval_responses[0].approved is True
|
|
assert approval_responses[0].function_call.name == "delete_file"
|
|
|
|
async def test_round_trip_approval_response_rejected(self) -> None:
|
|
"""Same as above but the user rejects the approval; the agent must
|
|
receive `approved=False`."""
|
|
request_content = _make_function_approval_request_content()
|
|
|
|
agent = _make_multi_response_agent(
|
|
responses=[
|
|
AgentResponse(messages=[Message(role="assistant", contents=[request_content])]),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
first = await _post(server, stream=False)
|
|
approval_request_id = next(
|
|
item["id"] for item in first.json()["output"] if item["type"] == "mcp_approval_request"
|
|
)
|
|
|
|
second = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "mcp_approval_response",
|
|
"approval_request_id": approval_request_id,
|
|
"approve": False,
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
assert second.status_code == 200
|
|
|
|
second_call_kwargs = agent.run.call_args_list[1].kwargs
|
|
approval_responses = [
|
|
c for m in second_call_kwargs["messages"] for c in m.contents if c.type == "function_approval_response"
|
|
]
|
|
assert len(approval_responses) == 1
|
|
assert approval_responses[0].approved is False
|
|
|
|
async def test_approval_response_referencing_unknown_id_fails(self) -> None:
|
|
"""Sending an `mcp_approval_response` for a request id that was
|
|
never persisted must fail (storage raises KeyError)."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "mcp_approval_response",
|
|
"approval_request_id": "apr_unknown",
|
|
"approve": True,
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
# The handler raises a KeyError when the storage lookup misses;
|
|
# the hosting layer surfaces this as a 5xx response.
|
|
assert resp.status_code >= 500
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Checkpoint context path validation
|
|
|
|
|
|
class TestCheckpointContextPathValidation:
|
|
"""Regression tests for the path-traversal hardening of checkpoint storage.
|
|
|
|
These tests guard against CWE-22 in the workflow hosting path. The hosting
|
|
code joins caller-supplied identifiers (``previous_response_id``) and
|
|
server-generated identifiers (``conversation_id`` / ``response_id``) under
|
|
the configured checkpoint root. Without validation, traversal segments
|
|
such as ``../../escape`` or absolute paths cause directory creation
|
|
outside the intended root.
|
|
"""
|
|
|
|
@staticmethod
|
|
def _helper() -> Callable[[str, str], FileCheckpointStorage]:
|
|
from agent_framework_foundry_hosting._responses import ( # pyright: ignore[reportPrivateUsage]
|
|
_checkpoint_storage_for_context,
|
|
)
|
|
|
|
return _checkpoint_storage_for_context
|
|
|
|
@staticmethod
|
|
def _checkpoint_with_azure_message_role() -> WorkflowCheckpoint:
|
|
from azure.ai.agentserver.responses.models import MessageRole
|
|
|
|
return WorkflowCheckpoint(
|
|
workflow_name="wf",
|
|
graph_signature_hash="hash",
|
|
messages={
|
|
"executor": [
|
|
WorkflowMessage(
|
|
data=Message(role=MessageRole.USER, contents=[Content.from_text("hello")]),
|
|
source_id="source",
|
|
)
|
|
]
|
|
},
|
|
)
|
|
|
|
def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None:
|
|
helper = self._helper()
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
storage = helper(str(root), "resp_abc123")
|
|
assert storage.storage_path.is_dir()
|
|
assert storage.storage_path.parent == root.resolve()
|
|
|
|
def test_azure_message_role_allowlist_type_matches_generated_sdk_path(self) -> None:
|
|
assert (
|
|
_AZURE_RESPONSES_MESSAGE_ROLE_TYPE
|
|
== "azure.ai.agentserver.responses.models._generated.sdk.models.models._enums:MessageRole"
|
|
)
|
|
|
|
async def test_storage_allows_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None:
|
|
from azure.ai.agentserver.responses.models import MessageRole
|
|
|
|
helper = self._helper()
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
storage = helper(str(root), "resp_abc123")
|
|
checkpoint = self._checkpoint_with_azure_message_role()
|
|
|
|
await storage.save(checkpoint)
|
|
loaded = await storage.load(checkpoint.checkpoint_id)
|
|
|
|
loaded_message = loaded.messages["executor"][0].data
|
|
assert isinstance(loaded_message, Message)
|
|
assert type(loaded_message.role) is MessageRole
|
|
assert loaded_message.role == MessageRole.USER
|
|
assert loaded_message.text == "hello"
|
|
|
|
async def test_plain_storage_blocks_azure_message_role_checkpoint_restore(self, tmp_path: Any) -> None:
|
|
storage = FileCheckpointStorage(tmp_path / "plain")
|
|
checkpoint = self._checkpoint_with_azure_message_role()
|
|
|
|
await storage.save(checkpoint)
|
|
with pytest.raises(WorkflowCheckpointException, match="MessageRole"):
|
|
await storage.load(checkpoint.checkpoint_id)
|
|
|
|
async def test_get_latest_restores_azure_message_role(self, tmp_path: Any) -> None:
|
|
from azure.ai.agentserver.responses.models import MessageRole
|
|
|
|
helper = self._helper()
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
storage = helper(str(root), "resp_abc123")
|
|
checkpoint = self._checkpoint_with_azure_message_role()
|
|
|
|
await storage.save(checkpoint)
|
|
latest = await storage.get_latest(workflow_name="wf")
|
|
|
|
assert latest is not None
|
|
assert latest.checkpoint_id == checkpoint.checkpoint_id
|
|
latest_message = latest.messages["executor"][0].data
|
|
assert isinstance(latest_message, Message)
|
|
assert type(latest_message.role) is MessageRole
|
|
|
|
async def test_get_latest_silently_skips_without_allowlist(
|
|
self, tmp_path: Any, caplog: pytest.LogCaptureFixture
|
|
) -> None:
|
|
import logging
|
|
|
|
storage = FileCheckpointStorage(tmp_path / "plain")
|
|
checkpoint = self._checkpoint_with_azure_message_role()
|
|
|
|
await storage.save(checkpoint)
|
|
with caplog.at_level(logging.WARNING, logger="agent_framework"):
|
|
latest = await storage.get_latest(workflow_name="wf")
|
|
|
|
assert latest is None
|
|
assert any("MessageRole" in message for message in caplog.messages)
|
|
|
|
async def test_handle_inner_workflow_restores_message_role_checkpoint_from_previous_response(
|
|
self, tmp_path: Any
|
|
) -> None:
|
|
from agent_framework import WorkflowAgent
|
|
from azure.ai.agentserver.responses import ResponseContext
|
|
from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage
|
|
|
|
previous_response_id = "resp_previous"
|
|
response_id = "resp_current"
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
checkpoint_storage = self._helper()(str(root), previous_response_id)
|
|
checkpoint = self._checkpoint_with_azure_message_role()
|
|
await checkpoint_storage.save(checkpoint)
|
|
|
|
agent = MagicMock(spec=WorkflowAgent)
|
|
agent.id = "wf-agent"
|
|
agent.name = "wf"
|
|
agent.description = ""
|
|
agent.context_providers = []
|
|
agent.workflow = MagicMock()
|
|
agent.workflow.name = "wf"
|
|
agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
|
|
agent.run = AsyncMock(
|
|
side_effect=[
|
|
AgentResponse(messages=[]),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]),
|
|
]
|
|
)
|
|
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
|
|
server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage]
|
|
|
|
request = CreateResponse(model="m", input="hi", previous_response_id=previous_response_id)
|
|
context = ResponseContext(
|
|
response_id=response_id, previous_response_id=previous_response_id, mode_flags=MagicMock()
|
|
)
|
|
input_item = ItemMessage({"type": "message", "role": "user", "content": "next turn"})
|
|
|
|
with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[input_item])):
|
|
async for _ in server._handle_inner_workflow(request, context): # pyright: ignore[reportPrivateUsage]
|
|
pass
|
|
|
|
assert agent.run.call_count == 2
|
|
restore_call = agent.run.call_args_list[0]
|
|
assert restore_call.kwargs["checkpoint_id"] == checkpoint.checkpoint_id
|
|
assert restore_call.kwargs["checkpoint_storage"].storage_path == (root / previous_response_id).resolve()
|
|
|
|
new_turn_call = agent.run.call_args_list[1]
|
|
new_turn_messages = new_turn_call.args[0]
|
|
assert len(new_turn_messages) == 1
|
|
assert new_turn_messages[0].text == "next turn"
|
|
assert new_turn_call.kwargs["checkpoint_storage"].storage_path == (root / response_id).resolve()
|
|
|
|
@pytest.mark.parametrize(
|
|
"bad_id",
|
|
[
|
|
# Original MSRC repro: traversal embedded inside an id-shaped value.
|
|
# The 14 ``A``s pad the suffix to mimic the exact length of the
|
|
# ``api-made-dir<14-char-suffix>`` segment from the original report.
|
|
"caresp_x/../../service-data/api-made-dir" + "A" * 14,
|
|
# Variant report repros.
|
|
"../../escape",
|
|
"..",
|
|
".",
|
|
"...",
|
|
"/tmp/escape",
|
|
"/absolute/path",
|
|
"C:\\temp\\escape",
|
|
"..\\..\\escape",
|
|
"foo\\..\\bar",
|
|
"foo/bar",
|
|
"with\x00null",
|
|
"",
|
|
],
|
|
)
|
|
def test_traversal_and_separator_payloads_are_rejected(self, tmp_path: Any, bad_id: str) -> None:
|
|
helper = self._helper()
|
|
# Use a dedicated root *inside* tmp_path so we can assert that nothing
|
|
# was created anywhere under tmp_path (root, siblings, or above).
|
|
# Asserting against tmp_path.parent would be flaky under parallel test
|
|
# execution because tmp_path.parent is shared across tests.
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
before = sorted(p.name for p in tmp_path.iterdir())
|
|
with pytest.raises(RuntimeError):
|
|
helper(str(root), bad_id)
|
|
# No sibling/escape directory should have been created next to the root.
|
|
after = sorted(p.name for p in tmp_path.iterdir())
|
|
assert before == after, f"Unexpected filesystem artifacts created for payload {bad_id!r}"
|
|
# And nothing inside the root either.
|
|
assert list(root.iterdir()) == []
|
|
|
|
def test_non_string_context_id_is_rejected(self, tmp_path: Any) -> None:
|
|
helper = self._helper()
|
|
with pytest.raises(RuntimeError):
|
|
helper(str(tmp_path), None) # type: ignore[arg-type]
|
|
|
|
def test_url_encoded_traversal_is_treated_as_literal_segment(self, tmp_path: Any) -> None:
|
|
"""URL-encoded traversal should not decode to traversal at the filesystem layer.
|
|
|
|
The hosting layer never URL-decodes ids before using them; the helper
|
|
should accept ``%2e%2e`` as a single literal segment (no escape).
|
|
"""
|
|
helper = self._helper()
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
storage = helper(str(root), "%2e%2e")
|
|
assert storage.storage_path.parent == root.resolve()
|
|
assert storage.storage_path.name == "%2e%2e"
|
|
|
|
@pytest.mark.parametrize(
|
|
"context_field,bad_id",
|
|
[
|
|
# Restore sink: caller-controlled previous_response_id.
|
|
("previous_response_id", "../../escape"),
|
|
("previous_response_id", "/tmp/escape-abs"),
|
|
("previous_response_id", "caresp_x/../../service-data/api-made-dir" + "A" * 14),
|
|
# Restore sink: server-issued conversation_id (defense in depth).
|
|
("conversation_id", "../../escape"),
|
|
# Write sink: malicious response_id (defense in depth).
|
|
("response_id", "../../escape"),
|
|
],
|
|
)
|
|
async def test_handle_inner_workflow_rejects_malicious_context_id(
|
|
self, tmp_path: Any, context_field: str, bad_id: str
|
|
) -> None:
|
|
"""End-to-end: ``_handle_inner_workflow`` must reject malicious ids on
|
|
both the restore sink (``previous_response_id`` / ``conversation_id``)
|
|
and the write sink (``response_id``) without creating any directories.
|
|
"""
|
|
from unittest.mock import patch
|
|
|
|
from agent_framework import WorkflowAgent
|
|
from azure.ai.agentserver.responses import ResponseContext
|
|
from azure.ai.agentserver.responses.models import CreateResponse
|
|
|
|
# Build a mock that satisfies isinstance(agent, WorkflowAgent) and the
|
|
# constructor's "no existing checkpointing" guard.
|
|
agent = MagicMock(spec=WorkflowAgent)
|
|
agent.id = "wf-agent"
|
|
agent.name = "wf"
|
|
agent.description = ""
|
|
agent.context_providers = []
|
|
agent.workflow = MagicMock()
|
|
agent.workflow.name = "wf"
|
|
agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
|
|
|
|
# Constructor inspects WorkflowAgent.workflow internals; bypass setup
|
|
# by feeding a configured mock through a normal init.
|
|
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
|
|
# Re-root checkpoint storage at our isolated tmp_path so we can detect
|
|
# any escape attempt on the filesystem.
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage]
|
|
|
|
# Build a ResponseContext with the malicious id targeting the chosen sink.
|
|
kwargs: dict[str, Any] = {
|
|
"response_id": "resp_" + "a" * 48,
|
|
"mode_flags": MagicMock(),
|
|
}
|
|
if context_field == "previous_response_id":
|
|
request = CreateResponse(model="m", input="hi", previous_response_id=bad_id)
|
|
kwargs["previous_response_id"] = bad_id
|
|
elif context_field == "conversation_id":
|
|
request = CreateResponse(model="m", input="hi")
|
|
kwargs["conversation_id"] = bad_id
|
|
else: # response_id (write sink)
|
|
request = CreateResponse(model="m", input="hi")
|
|
kwargs["response_id"] = bad_id
|
|
|
|
# Avoid invoking the real input-resolution machinery, which would need
|
|
# a configured provider; we never reach the workflow run on rejection.
|
|
with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[])):
|
|
context = ResponseContext(**kwargs)
|
|
before = sorted(p.name for p in tmp_path.iterdir())
|
|
with pytest.raises(RuntimeError, match="Invalid checkpoint context id"):
|
|
async for _ in server._handle_inner_workflow(request, context): # pyright: ignore[reportPrivateUsage]
|
|
pass
|
|
after = sorted(p.name for p in tmp_path.iterdir())
|
|
|
|
assert before == after, f"Unexpected filesystem artifacts created for {context_field}={bad_id!r}"
|
|
assert list(root.iterdir()) == [], f"Checkpoint dir created inside root for {context_field}={bad_id!r}"
|
|
|
|
@pytest.mark.parametrize(
|
|
"context_field,bad_id",
|
|
[
|
|
# Restore sink: caller-controlled previous_response_id. These are
|
|
# rejected by request validation (HTTP 400) before the checkpoint
|
|
# code is reached.
|
|
("previous_response_id", "../../escape"),
|
|
("previous_response_id", "/tmp/escape-abs"),
|
|
("previous_response_id", "caresp_x/../../service-data/api-made-dir" + "A" * 14),
|
|
# Restore sink: server-issued conversation id (defense in depth).
|
|
# Reaches the checkpoint code and is rejected there, surfacing as
|
|
# an HTTP 5xx without creating any filesystem artifacts.
|
|
("conversation", "../../escape"),
|
|
("conversation", "/tmp/escape-abs"),
|
|
],
|
|
)
|
|
async def test_malicious_context_id_rejected_e2e(self, tmp_path: Any, context_field: str, bad_id: str) -> None:
|
|
"""End-to-end (ASGI-in-process): malicious context ids must be rejected
|
|
through the full HTTP pipeline, and no checkpoint directory may be
|
|
created on disk for either the validation-layer rejection
|
|
(``previous_response_id``) or the deeper checkpoint-layer rejection
|
|
(``conversation``).
|
|
|
|
The ``response_id`` write-sink is server-generated and not reachable
|
|
via the public HTTP surface, so its defense-in-depth check is covered
|
|
by the helper-level test above.
|
|
"""
|
|
from agent_framework import WorkflowAgent
|
|
|
|
# Build a mock that satisfies isinstance(agent, WorkflowAgent) and the
|
|
# constructor's "no existing checkpointing" guard.
|
|
agent = MagicMock(spec=WorkflowAgent)
|
|
agent.id = "wf-agent"
|
|
agent.name = "wf"
|
|
agent.description = ""
|
|
agent.context_providers = []
|
|
agent.workflow = MagicMock()
|
|
agent.workflow.name = "wf"
|
|
agent.workflow._runner_context.has_checkpointing = MagicMock( # pyright: ignore[reportPrivateUsage]
|
|
return_value=False
|
|
)
|
|
|
|
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
|
|
# Re-root checkpoint storage at our isolated tmp_path so we can detect
|
|
# any escape attempt on the filesystem.
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage]
|
|
|
|
payload: dict[str, Any] = {"model": "m", "input": "hi"}
|
|
if context_field == "previous_response_id":
|
|
payload["previous_response_id"] = bad_id
|
|
else: # conversation
|
|
payload["conversation"] = bad_id
|
|
|
|
before = sorted(p.name for p in tmp_path.iterdir())
|
|
transport = httpx.ASGITransport(app=server)
|
|
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
|
resp = await client.post("/responses", json=payload)
|
|
after = sorted(p.name for p in tmp_path.iterdir())
|
|
|
|
# The request must not succeed; either request validation rejects it
|
|
# (4xx) or the checkpoint layer raises and the server returns 5xx.
|
|
# Either way, no successful response may be produced.
|
|
assert resp.status_code >= 400, (
|
|
f"Expected non-2xx for {context_field}={bad_id!r}, got {resp.status_code}: {resp.text[:200]}"
|
|
)
|
|
assert before == after, (
|
|
f"Unexpected filesystem artifacts under tmp_path for {context_field}={bad_id!r}: "
|
|
f"before={before} after={after}"
|
|
)
|
|
assert list(root.iterdir()) == [], f"Checkpoint directory created inside root for {context_field}={bad_id!r}"
|
|
|
|
|
|
# region Agent lifecycle (lazy entry & OAuth consent surfacing)
|
|
|
|
|
|
def _make_consent_error(url: str = "https://consent.example.com/auth") -> Exception:
|
|
"""Build an exception wrapping a Foundry MCP gateway consent error.
|
|
|
|
Mirrors the real-world wrapping produced by ``MCPStreamableHTTPTool.__aenter__``,
|
|
which catches connection-time ``McpError``s and re-raises them as a
|
|
``ToolExecutionException`` (an ``AgentFrameworkException`` subclass) with the
|
|
original error attached via ``inner_exception``. ``consent_url_from_error``
|
|
then finds the wrapped ``McpError`` in ``exc.args``.
|
|
"""
|
|
from agent_framework.exceptions import ToolExecutionException
|
|
|
|
inner = McpError(ErrorData(code=CONSENT_ERROR_CODE, message=url))
|
|
return ToolExecutionException("MCP consent required", inner_exception=inner)
|
|
|
|
|
|
class TestConsentUrlFromError:
|
|
def test_returns_consent_url_when_inner_arg_is_consent_mcp_error(self) -> None:
|
|
exc = _make_consent_error("https://example.com/consent")
|
|
assert consent_url_from_error(exc) == "https://example.com/consent"
|
|
|
|
def test_returns_none_when_no_mcp_error_in_args(self) -> None:
|
|
assert consent_url_from_error(Exception("boom")) is None
|
|
|
|
def test_returns_none_when_mcp_error_has_different_code(self) -> None:
|
|
inner = McpError(ErrorData(code=-32000, message="some other error"))
|
|
exc = Exception("wrapped", inner)
|
|
assert consent_url_from_error(exc) is None
|
|
|
|
def test_returns_none_for_bare_mcp_error_without_wrapping(self) -> None:
|
|
# `args` of a bare McpError holds the message string, not an McpError
|
|
# instance, so it does not match the wrapping pattern produced by the
|
|
# MCP client when it bubbles consent errors up.
|
|
bare = McpError(ErrorData(code=CONSENT_ERROR_CODE, message="https://x"))
|
|
assert consent_url_from_error(bare) is None
|
|
|
|
|
|
class TestAgentLifecycle:
|
|
async def test_agent_entered_lazily_on_first_request(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
server = _make_server(agent)
|
|
# Construction must not enter the agent.
|
|
assert agent.__aenter__.await_count == 0
|
|
|
|
await _post(server, input_text="hello", stream=False)
|
|
assert agent.__aenter__.await_count == 1
|
|
|
|
async def test_agent_entered_only_once_across_requests(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
await _post(server, input_text="first", stream=False)
|
|
await _post(server, input_text="second", stream=False)
|
|
await _post(server, input_text="third", stream=False)
|
|
assert agent.__aenter__.await_count == 1
|
|
|
|
async def test_cleanup_exits_agent_and_allows_reentry(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
await _post(server, input_text="hello", stream=False)
|
|
assert agent.__aenter__.await_count == 1
|
|
assert agent.__aexit__.await_count == 0
|
|
|
|
await server._cleanup_agent() # pyright: ignore[reportPrivateUsage]
|
|
assert agent.__aexit__.await_count == 1
|
|
|
|
# Cleanup is idempotent.
|
|
await server._cleanup_agent() # pyright: ignore[reportPrivateUsage]
|
|
assert agent.__aexit__.await_count == 1
|
|
|
|
# After cleanup, a follow-up request re-enters the agent.
|
|
await _post(server, input_text="again", stream=False)
|
|
assert agent.__aenter__.await_count == 2
|
|
|
|
async def test_failed_entry_does_not_cache_stack(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
agent.__aenter__.side_effect = [_make_consent_error(), None]
|
|
server = _make_server(agent)
|
|
|
|
await _post(server, input_text="first", stream=False)
|
|
# Failed entry must leave the stack empty so the next request retries.
|
|
await _post(server, input_text="second", stream=False)
|
|
assert agent.__aenter__.await_count == 2
|
|
|
|
|
|
class TestOAuthConsentSurfacing:
|
|
async def test_non_streaming_consent_error_emits_oauth_output_item(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
agent.__aenter__.side_effect = _make_consent_error("https://consent.example.com/auth")
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post(server, input_text="hello", stream=False)
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
oauth_items = [it for it in body["output"] if it["type"] == "oauth_consent_request"]
|
|
assert len(oauth_items) == 1
|
|
assert oauth_items[0]["consent_link"] == "https://consent.example.com/auth"
|
|
assert oauth_items[0]["server_label"] == "Foundry Toolbox"
|
|
|
|
# The agent must not be run when entry fails.
|
|
agent.run.assert_not_called()
|
|
|
|
async def test_streaming_consent_error_emits_oauth_output_item(self) -> None:
|
|
agent = _make_agent(stream_updates=[AgentResponseUpdate(contents=[Content.from_text("hi")], role="assistant")])
|
|
agent.__aenter__.side_effect = _make_consent_error("https://consent.example.com/auth")
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post(server, input_text="hello", stream=True)
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[1] == "response.in_progress"
|
|
assert types[-1] == "response.completed"
|
|
|
|
added = [e for e in events if e["event"] == "response.output_item.added"]
|
|
oauth_added = [e for e in added if e["data"]["item"]["type"] == "oauth_consent_request"]
|
|
assert len(oauth_added) == 1
|
|
assert oauth_added[0]["data"]["item"]["consent_link"] == "https://consent.example.com/auth"
|
|
assert oauth_added[0]["data"]["item"]["server_label"] == "Foundry Toolbox"
|
|
|
|
done = [e for e in events if e["event"] == "response.output_item.done"]
|
|
assert any(e["data"]["item"]["type"] == "oauth_consent_request" for e in done)
|
|
|
|
agent.run.assert_not_called()
|
|
|
|
async def test_non_consent_error_during_entry_propagates(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
agent.__aenter__.side_effect = RuntimeError("boom")
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post(server, input_text="hello", stream=False)
|
|
# Non-consent errors are not swallowed: the response is marked failed
|
|
# and no `oauth_consent_request` item is emitted.
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "failed"
|
|
assert not any(it["type"] == "oauth_consent_request" for it in body.get("output", []))
|
|
agent.run.assert_not_called()
|
|
|
|
async def test_retry_after_consent_succeeds(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hello!")])])
|
|
)
|
|
agent.__aenter__.side_effect = [_make_consent_error("https://consent.example.com/auth"), None]
|
|
server = _make_server(agent)
|
|
|
|
# First request surfaces consent; agent.run is not called.
|
|
resp1 = await _post(server, input_text="first", stream=False)
|
|
assert resp1.status_code == 200
|
|
body1 = resp1.json()
|
|
oauth = [it for it in body1["output"] if it["type"] == "oauth_consent_request"]
|
|
assert len(oauth) == 1
|
|
agent.run.assert_not_called()
|
|
|
|
# After the user authenticates, the next request enters successfully.
|
|
resp2 = await _post(server, input_text="second", stream=False)
|
|
assert resp2.status_code == 200
|
|
body2 = resp2.json()
|
|
assert body2["status"] == "completed"
|
|
assert any(it["type"] == "message" for it in body2["output"])
|
|
assert agent.__aenter__.await_count == 2
|
|
agent.run.assert_awaited_once()
|
|
|
|
|
|
# endregion
|