mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
ef86fb51d5
* Add a HarnessAgent with available features and sample * Fix formatting * Address PR comments and fix mypy error * Add web search support to HarnessAgent * Fix build warning * Apply suggestions from code review Co-authored-by: Eduard van Valkenburg <eavanvalkenburg@users.noreply.github.com> * Address PR comments * Address PR comments * Address further PR comments. * Fix markdown broken link --------- Co-authored-by: Eduard van Valkenburg <eavanvalkenburg@users.noreply.github.com>
3112 lines
124 KiB
Python
3112 lines
124 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""HTTP round-trip tests for ResponsesHostServer.
|
|
|
|
These tests exercise the full HTTP pipeline using httpx.AsyncClient with
|
|
ASGITransport — no real server process is started. Requests go through
|
|
the Starlette routing stack, the Responses API middleware, and arrive at
|
|
the registered _handle_create handler.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from collections.abc import AsyncIterator, Callable
|
|
from dataclasses import dataclass
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import httpx
|
|
import pytest
|
|
from agent_framework import (
|
|
AgentResponse,
|
|
AgentResponseUpdate,
|
|
Content,
|
|
FileCheckpointStorage,
|
|
HistoryProvider,
|
|
Message,
|
|
RawAgent,
|
|
ResponseStream,
|
|
)
|
|
from azure.ai.agentserver.responses import InMemoryResponseProvider
|
|
from mcp import McpError
|
|
from mcp.types import ErrorData
|
|
from typing_extensions import Any
|
|
|
|
from agent_framework_foundry_hosting import ResponsesHostServer
|
|
from agent_framework_foundry_hosting._responses import (
|
|
CONSENT_ERROR_CODE,
|
|
FileBasedFunctionApprovalStorage, # pyright: ignore[reportPrivateUsage]
|
|
InMemoryFunctionApprovalStorage, # pyright: ignore[reportPrivateUsage]
|
|
_item_to_message, # pyright: ignore[reportPrivateUsage]
|
|
_output_item_to_message, # pyright: ignore[reportPrivateUsage]
|
|
consent_url_from_error,
|
|
)
|
|
|
|
|
|
def _make_function_approval_request_content(
|
|
*,
|
|
request_id: str = "apr_test",
|
|
call_id: str = "call_1",
|
|
name: str = "delete_file",
|
|
arguments: str = '{"path": "/foo"}',
|
|
server_label: str = "my_server",
|
|
) -> Content:
|
|
"""Build a function_approval_request Content with an embedded function_call."""
|
|
function_call = Content.from_function_call(
|
|
call_id, name, arguments=arguments, additional_properties={"server_label": server_label}
|
|
)
|
|
return Content.from_function_approval_request(request_id, function_call)
|
|
|
|
|
|
# region Helpers
|
|
|
|
|
|
def _make_agent(
|
|
*,
|
|
response: AgentResponse | None = None,
|
|
stream_updates: list[AgentResponseUpdate] | None = None,
|
|
raw_agent: bool = True,
|
|
) -> MagicMock:
|
|
"""Create a mock agent implementing SupportsAgentRun."""
|
|
agent = MagicMock(spec=RawAgent) if raw_agent else MagicMock()
|
|
agent.id = "test-agent"
|
|
agent.name = "Test Agent"
|
|
agent.description = "A mock agent for testing"
|
|
agent.context_providers = []
|
|
|
|
if response is not None:
|
|
|
|
async def run_non_streaming(*args: Any, **kwargs: Any) -> AgentResponse:
|
|
return response
|
|
|
|
agent.run = AsyncMock(side_effect=run_non_streaming)
|
|
|
|
if stream_updates is not None:
|
|
|
|
async def _stream_gen() -> AsyncIterator[AgentResponseUpdate]:
|
|
for update in stream_updates:
|
|
yield update
|
|
|
|
def run_streaming(*args: Any, **kwargs: Any) -> Any:
|
|
if kwargs.get("stream"):
|
|
return ResponseStream(_stream_gen()) # type: ignore
|
|
raise NotImplementedError("Only streaming is configured on this mock")
|
|
|
|
agent.run = MagicMock(side_effect=run_streaming)
|
|
|
|
return agent
|
|
|
|
|
|
def _make_server(agent: MagicMock, **kwargs: Any) -> ResponsesHostServer:
|
|
"""Create a ResponsesHostServer with an in-memory store."""
|
|
return ResponsesHostServer(agent, store=InMemoryResponseProvider(), **kwargs)
|
|
|
|
|
|
async def _post(
|
|
server: ResponsesHostServer,
|
|
*,
|
|
input_text: str = "Hello",
|
|
model: str = "test-model",
|
|
stream: bool = False,
|
|
temperature: float | None = None,
|
|
top_p: float | None = None,
|
|
max_output_tokens: int | None = None,
|
|
parallel_tool_calls: bool | None = None,
|
|
) -> httpx.Response:
|
|
"""Send a POST /responses request through the ASGI transport."""
|
|
payload: dict[str, Any] = {"model": model, "input": input_text, "stream": stream}
|
|
if temperature is not None:
|
|
payload["temperature"] = temperature
|
|
if top_p is not None:
|
|
payload["top_p"] = top_p
|
|
if max_output_tokens is not None:
|
|
payload["max_output_tokens"] = max_output_tokens
|
|
if parallel_tool_calls is not None:
|
|
payload["parallel_tool_calls"] = parallel_tool_calls
|
|
|
|
transport = httpx.ASGITransport(app=server)
|
|
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
|
return await client.post("/responses", json=payload)
|
|
|
|
|
|
def _parse_sse_events(body: str) -> list[dict[str, Any]]:
|
|
"""Parse SSE text into a list of event dicts with 'event' and 'data' keys."""
|
|
events: list[dict[str, Any]] = []
|
|
current_event: str | None = None
|
|
current_data_lines: list[str] = []
|
|
|
|
for line in body.split("\n"):
|
|
if line.startswith("event: "):
|
|
current_event = line[len("event: ") :]
|
|
elif line.startswith("data: "):
|
|
current_data_lines.append(line[len("data: ") :])
|
|
elif line.strip() == "" and current_event is not None:
|
|
data_str = "\n".join(current_data_lines)
|
|
try:
|
|
data = json.loads(data_str)
|
|
except json.JSONDecodeError:
|
|
data = data_str
|
|
events.append({"event": current_event, "data": data})
|
|
current_event = None
|
|
current_data_lines = []
|
|
|
|
return events
|
|
|
|
|
|
def _sse_event_types(events: list[dict[str, Any]]) -> list[str]:
|
|
"""Extract event type strings from parsed SSE events."""
|
|
return [e["event"] for e in events]
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Initialization
|
|
|
|
|
|
class TestResponsesHostServerInit:
|
|
def test_init_basic(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
server = _make_server(agent)
|
|
assert server is not None
|
|
|
|
def test_init_rejects_history_provider_with_load_messages(self) -> None:
|
|
hp = HistoryProvider(source_id="test", load_messages=True)
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
agent.context_providers = [hp]
|
|
with pytest.raises(RuntimeError, match="history provider"):
|
|
ResponsesHostServer(agent)
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Health Check
|
|
|
|
|
|
class TestHealthCheck:
|
|
async def test_readiness(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
server = _make_server(agent)
|
|
transport = httpx.ASGITransport(app=server)
|
|
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
|
resp = await client.get("/readiness")
|
|
assert resp.status_code == 200
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Non-streaming
|
|
|
|
|
|
class TestNonStreaming:
|
|
async def test_basic_text_response(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Hello!")])])
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, input_text="Hi", stream=False)
|
|
|
|
assert resp.status_code == 200
|
|
assert "application/json" in resp.headers["content-type"]
|
|
|
|
body = resp.json()
|
|
assert body["object"] == "response"
|
|
assert body["status"] == "completed"
|
|
assert len(body["output"]) > 0
|
|
|
|
# Find the message output item with our text
|
|
text_found = False
|
|
for item in body["output"]:
|
|
assert item["type"] == "message"
|
|
for part in item.get("content", []):
|
|
if part.get("type") == "output_text" and part.get("text") == "Hello!":
|
|
text_found = True
|
|
assert text_found, f"Expected 'Hello!' in output, got: {body['output']}"
|
|
|
|
async def test_function_call_and_result(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[Content.from_function_call("call_1", "get_weather", arguments='{"loc": "NYC"}')],
|
|
),
|
|
Message(role="tool", contents=[Content.from_function_result("call_1", result="sunny")]),
|
|
Message(role="assistant", contents=[Content.from_text("The weather is sunny!")]),
|
|
]
|
|
)
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=False)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
types = [item["type"] for item in body["output"]]
|
|
assert "function_call" in types
|
|
assert "function_call_output" in types
|
|
assert "message" in types
|
|
|
|
async def test_reasoning_content(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[
|
|
Content.from_text_reasoning(text="Let me think..."),
|
|
Content.from_text("The answer is 42"),
|
|
],
|
|
),
|
|
]
|
|
)
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=False)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
types = [item["type"] for item in body["output"]]
|
|
assert "reasoning" in types
|
|
assert "message" in types
|
|
|
|
async def test_empty_response(self) -> None:
|
|
agent = _make_agent(response=AgentResponse(messages=[]))
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=False)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
async def test_chat_options_forwarded(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]),
|
|
raw_agent=True,
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(
|
|
server,
|
|
stream=False,
|
|
temperature=0.5,
|
|
top_p=0.9,
|
|
max_output_tokens=1024,
|
|
parallel_tool_calls=True,
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
agent.run.assert_awaited_once()
|
|
call_kwargs = agent.run.call_args.kwargs
|
|
assert call_kwargs["stream"] is False
|
|
options = call_kwargs["options"]
|
|
assert options["temperature"] == 0.5
|
|
assert options["top_p"] == 0.9
|
|
assert options["max_tokens"] == 1024
|
|
assert options["allow_multiple_tool_calls"] is True
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Streaming
|
|
|
|
|
|
class TestStreaming:
|
|
async def test_chat_options_forwarded(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[AgentResponseUpdate(contents=[Content.from_text("ok")], role="assistant")],
|
|
raw_agent=True,
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(
|
|
server,
|
|
stream=True,
|
|
temperature=0.5,
|
|
top_p=0.9,
|
|
max_output_tokens=1024,
|
|
parallel_tool_calls=True,
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
agent.run.assert_called_once()
|
|
call_kwargs = agent.run.call_args.kwargs
|
|
assert call_kwargs["stream"] is True
|
|
options = call_kwargs["options"]
|
|
assert options["temperature"] == 0.5
|
|
assert options["top_p"] == 0.9
|
|
assert options["max_tokens"] == 1024
|
|
assert options["allow_multiple_tool_calls"] is True
|
|
|
|
async def test_basic_text_streaming(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(contents=[Content.from_text("Hello ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text("world!")], role="assistant"),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
assert "text/event-stream" in resp.headers["content-type"]
|
|
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[1] == "response.in_progress"
|
|
assert types[-1] == "response.completed"
|
|
assert "response.output_text.delta" in types
|
|
assert types.count("response.output_text.delta") == 2
|
|
assert "response.output_text.done" in types
|
|
|
|
# Verify the accumulated text in the done event
|
|
done_events = [e for e in events if e["event"] == "response.output_text.done"]
|
|
assert len(done_events) == 1
|
|
assert done_events[0]["data"]["text"] == "Hello world!"
|
|
|
|
async def test_function_call_streaming(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "search", arguments='{"q":')],
|
|
role="assistant",
|
|
),
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "search", arguments=' "hello"}')],
|
|
role="assistant",
|
|
),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
assert types.count("response.function_call_arguments.delta") == 2
|
|
assert "response.function_call_arguments.done" in types
|
|
|
|
# Verify accumulated arguments
|
|
args_done = [e for e in events if e["event"] == "response.function_call_arguments.done"]
|
|
assert len(args_done) == 1
|
|
assert args_done[0]["data"]["arguments"] == '{"q": "hello"}'
|
|
|
|
async def test_function_call_streaming_serializes_dataclass_arguments(self) -> None:
|
|
@dataclass
|
|
class HandoffLikeRequest:
|
|
agent_response: AgentResponse
|
|
|
|
request = HandoffLikeRequest(
|
|
agent_response=AgentResponse(
|
|
messages=[Message(role="assistant", contents=[Content.from_text("Need more details")])]
|
|
)
|
|
)
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "handoff_to_refund", arguments=request)],
|
|
role="assistant",
|
|
),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
args_done = [e for e in events if e["event"] == "response.function_call_arguments.done"]
|
|
assert len(args_done) == 1
|
|
|
|
payload = json.loads(args_done[0]["data"]["arguments"])
|
|
assert payload["agent_response"]["type"] == "agent_response"
|
|
assert payload["agent_response"]["messages"][0]["contents"][0]["text"] == "Need more details"
|
|
|
|
async def test_alternating_text_and_function_call(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
# Text deltas
|
|
AgentResponseUpdate(contents=[Content.from_text("Let me ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text("search...")], role="assistant"),
|
|
# Function call argument deltas
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "search", arguments='{"q":')],
|
|
role="assistant",
|
|
),
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "search", arguments=' "x"}')],
|
|
role="assistant",
|
|
),
|
|
# More text deltas
|
|
AgentResponseUpdate(contents=[Content.from_text("Found ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text("it!")], role="assistant"),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
|
|
# 4 text deltas + 2 function call argument deltas
|
|
assert types.count("response.output_text.delta") == 4
|
|
assert types.count("response.function_call_arguments.delta") == 2
|
|
|
|
# 3 distinct output items (text, fc, text)
|
|
assert types.count("response.output_item.added") == 3
|
|
assert types.count("response.output_item.done") == 3
|
|
|
|
# Verify accumulated content
|
|
text_done = [e for e in events if e["event"] == "response.output_text.done"]
|
|
assert len(text_done) == 2
|
|
assert text_done[0]["data"]["text"] == "Let me search..."
|
|
assert text_done[1]["data"]["text"] == "Found it!"
|
|
|
|
args_done = [e for e in events if e["event"] == "response.function_call_arguments.done"]
|
|
assert len(args_done) == 1
|
|
assert args_done[0]["data"]["arguments"] == '{"q": "x"}'
|
|
|
|
async def test_reasoning_then_text_streaming(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
# Reasoning deltas
|
|
AgentResponseUpdate(contents=[Content.from_text_reasoning(text="Let me ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text_reasoning(text="think...")], role="assistant"),
|
|
# Text deltas
|
|
AgentResponseUpdate(contents=[Content.from_text("The answer ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text("is 42")], role="assistant"),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
# Reasoning + text = 2 output items
|
|
assert types.count("response.output_item.added") == 2
|
|
assert types.count("response.output_item.done") == 2
|
|
assert types.count("response.output_text.delta") == 2
|
|
|
|
# Verify accumulated text
|
|
text_done = [e for e in events if e["event"] == "response.output_text.done"]
|
|
assert len(text_done) == 1
|
|
assert text_done[0]["data"]["text"] == "The answer is 42"
|
|
|
|
async def test_empty_streaming(self) -> None:
|
|
agent = _make_agent(stream_updates=[])
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types == ["response.created", "response.in_progress", "response.completed"]
|
|
|
|
async def test_mixed_contents_in_single_update(self) -> None:
|
|
"""Text and function call in one update switches builder mid-update."""
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(
|
|
contents=[
|
|
Content.from_text("Let me search"),
|
|
Content.from_function_call("call_1", "search", arguments='{"q": "test"}'),
|
|
],
|
|
role="assistant",
|
|
),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert "response.output_text.delta" in types
|
|
assert "response.output_text.done" in types
|
|
assert "response.function_call_arguments.delta" in types
|
|
assert "response.function_call_arguments.done" in types
|
|
|
|
async def test_different_function_call_ids_produce_separate_items(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_1", "func_a", arguments='{"x":1}')],
|
|
role="assistant",
|
|
),
|
|
AgentResponseUpdate(
|
|
contents=[Content.from_function_call("call_2", "func_b", arguments='{"y":2}')],
|
|
role="assistant",
|
|
),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
# Two separate function call items
|
|
assert types.count("response.output_item.added") == 2
|
|
assert types.count("response.function_call_arguments.done") == 2
|
|
|
|
async def test_mcp_tool_call_streaming(self) -> None:
|
|
agent = _make_agent(
|
|
stream_updates=[
|
|
AgentResponseUpdate(
|
|
contents=[
|
|
Content(
|
|
type="mcp_server_tool_call",
|
|
server_name="my_server",
|
|
tool_name="search",
|
|
arguments='{"query":',
|
|
)
|
|
],
|
|
role="assistant",
|
|
),
|
|
AgentResponseUpdate(
|
|
contents=[
|
|
Content(
|
|
type="mcp_server_tool_call",
|
|
server_name="my_server",
|
|
tool_name="search",
|
|
arguments=' "test"}',
|
|
)
|
|
],
|
|
role="assistant",
|
|
),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
assert "response.output_item.added" in types
|
|
assert "response.output_item.done" in types
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region _output_item_to_message conversion
|
|
|
|
|
|
class TestOutputItemToMessage:
|
|
"""Tests for _output_item_to_message covering all supported OutputItem types."""
|
|
|
|
async def test_output_message(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemOutputMessage, OutputMessageContentOutputTextContent
|
|
|
|
item = OutputItemOutputMessage({
|
|
"type": "output_message",
|
|
"role": "assistant",
|
|
"content": [OutputMessageContentOutputTextContent({"type": "output_text", "text": "hello"})],
|
|
"status": "completed",
|
|
"id": "msg-1",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].type == "text"
|
|
assert msg.contents[0].text == "hello"
|
|
|
|
async def test_message(self) -> None:
|
|
from azure.ai.agentserver.responses.models import MessageContentInputTextContent, OutputItemMessage
|
|
|
|
item = OutputItemMessage({
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [MessageContentInputTextContent({"type": "input_text", "text": "hi"})],
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "user"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].text == "hi"
|
|
|
|
async def test_function_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemFunctionToolCall
|
|
|
|
item = OutputItemFunctionToolCall({
|
|
"type": "function_call",
|
|
"call_id": "call_1",
|
|
"name": "get_weather",
|
|
"arguments": '{"city": "NYC"}',
|
|
"status": "completed",
|
|
"id": "fc-1",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].call_id == "call_1"
|
|
assert msg.contents[0].name == "get_weather"
|
|
|
|
async def test_function_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import FunctionCallOutputItemParam
|
|
|
|
item = FunctionCallOutputItemParam({"type": "function_call_output", "call_id": "call_1", "output": "sunny"})
|
|
msg = await _output_item_to_message(item) # type: ignore[arg-type]
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].call_id == "call_1"
|
|
assert msg.contents[0].result == "sunny"
|
|
|
|
async def test_reasoning(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemReasoningItem, SummaryTextContent
|
|
|
|
item = OutputItemReasoningItem({
|
|
"type": "reasoning",
|
|
"id": "r-1",
|
|
"summary": [SummaryTextContent({"type": "summary_text", "text": "thinking hard"})],
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].text == "thinking hard"
|
|
|
|
async def test_reasoning_no_summary(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemReasoningItem
|
|
|
|
item = OutputItemReasoningItem({"type": "reasoning", "id": "r-2"})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents == []
|
|
|
|
async def test_mcp_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpToolCall
|
|
|
|
item = OutputItemMcpToolCall({
|
|
"type": "mcp_call",
|
|
"id": "mcp-1",
|
|
"server_label": "my_server",
|
|
"name": "search",
|
|
"arguments": '{"q": "test"}',
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "mcp_server_tool_call"
|
|
assert msg.contents[0].server_name == "my_server"
|
|
assert msg.contents[0].tool_name == "search"
|
|
|
|
async def test_mcp_approval_request(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalRequest
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = OutputItemMcpApprovalRequest({
|
|
"type": "mcp_approval_request",
|
|
"id": "apr-1",
|
|
"server_label": "srv",
|
|
"name": "dangerous_tool",
|
|
"arguments": "{}",
|
|
})
|
|
msg = await _output_item_to_message(item, approval_storage=storage)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_approval_request"
|
|
|
|
async def test_mcp_approval_response(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalResponseResource
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = OutputItemMcpApprovalResponseResource({
|
|
"type": "mcp_approval_response",
|
|
"id": "resp-1",
|
|
"approval_request_id": "apr-1",
|
|
"approve": True,
|
|
})
|
|
msg = await _output_item_to_message(item, approval_storage=storage)
|
|
assert msg.role == "user"
|
|
assert msg.contents[0].type == "function_approval_response"
|
|
assert msg.contents[0].approved is True
|
|
|
|
async def test_code_interpreter_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemCodeInterpreterToolCall
|
|
|
|
item = OutputItemCodeInterpreterToolCall({
|
|
"type": "code_interpreter_call",
|
|
"id": "ci-1",
|
|
"status": "completed",
|
|
"container_id": "c-1",
|
|
"code": "print('hi')",
|
|
"outputs": [],
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "code_interpreter_tool_call"
|
|
|
|
async def test_image_generation_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemImageGenToolCall
|
|
|
|
item = OutputItemImageGenToolCall({"type": "image_generation_call", "id": "ig-1", "status": "completed"})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "image_generation_tool_call"
|
|
|
|
async def test_shell_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import (
|
|
FunctionShellAction,
|
|
FunctionShellCallEnvironment,
|
|
OutputItemFunctionShellCall,
|
|
)
|
|
|
|
item = OutputItemFunctionShellCall({
|
|
"type": "shell_call",
|
|
"id": "sc-1",
|
|
"call_id": "call_sc",
|
|
"action": FunctionShellAction({"commands": ["ls", "-la"], "timeout_ms": 5000, "max_output_length": 1024}),
|
|
"status": "completed",
|
|
"environment": FunctionShellCallEnvironment({"type": "local"}),
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "shell_tool_call"
|
|
assert msg.contents[0].commands == ["ls", "-la"]
|
|
assert msg.contents[0].call_id == "call_sc"
|
|
|
|
async def test_shell_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import (
|
|
FunctionShellCallOutputContent,
|
|
FunctionShellCallOutputExitOutcome,
|
|
OutputItemFunctionShellCallOutput,
|
|
)
|
|
|
|
item = OutputItemFunctionShellCallOutput({
|
|
"type": "shell_call_output",
|
|
"id": "sco-1",
|
|
"call_id": "call_sc",
|
|
"status": "completed",
|
|
"output": [
|
|
FunctionShellCallOutputContent({
|
|
"stdout": "file.txt",
|
|
"stderr": "",
|
|
"outcome": FunctionShellCallOutputExitOutcome({"exit_code": 0}),
|
|
})
|
|
],
|
|
"max_output_length": 1024,
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "shell_tool_result"
|
|
assert msg.contents[0].call_id == "call_sc"
|
|
|
|
async def test_local_shell_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import LocalShellExecAction, OutputItemLocalShellToolCall
|
|
|
|
item = OutputItemLocalShellToolCall({
|
|
"type": "local_shell_call",
|
|
"id": "lsc-1",
|
|
"call_id": "call_lsc",
|
|
"action": LocalShellExecAction({"type": "exec", "command": ["echo", "hello"], "env": {}}),
|
|
"status": "completed",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "shell_tool_call"
|
|
assert msg.contents[0].commands == ["echo", "hello"]
|
|
|
|
async def test_local_shell_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemLocalShellToolCallOutput
|
|
|
|
item = OutputItemLocalShellToolCallOutput({
|
|
"type": "local_shell_call_output",
|
|
"id": "lsco-1",
|
|
"output": "hello\n",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "shell_tool_result"
|
|
|
|
async def test_file_search_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemFileSearchToolCall
|
|
|
|
item = OutputItemFileSearchToolCall({
|
|
"type": "file_search_call",
|
|
"id": "fs-1",
|
|
"status": "completed",
|
|
"queries": ["what is AI"],
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "file_search"
|
|
assert '"what is AI"' in (msg.contents[0].arguments or "")
|
|
|
|
async def test_web_search_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemWebSearchToolCall, WebSearchActionSearch
|
|
|
|
item = OutputItemWebSearchToolCall({
|
|
"type": "web_search_call",
|
|
"id": "ws-1",
|
|
"status": "completed",
|
|
"action": WebSearchActionSearch({"type": "search", "query": "test"}),
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "web_search"
|
|
|
|
async def test_computer_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ComputerAction, OutputItemComputerToolCall
|
|
|
|
item = OutputItemComputerToolCall({
|
|
"type": "computer_call",
|
|
"id": "cc-1",
|
|
"call_id": "call_cc",
|
|
"action": ComputerAction({"type": "click"}),
|
|
"pending_safety_checks": [],
|
|
"status": "completed",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "computer_use"
|
|
|
|
async def test_computer_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import (
|
|
ComputerScreenshotImage,
|
|
OutputItemComputerToolCallOutputResource,
|
|
)
|
|
|
|
item = OutputItemComputerToolCallOutputResource({
|
|
"type": "computer_call_output",
|
|
"call_id": "call_cc",
|
|
"output": ComputerScreenshotImage({
|
|
"type": "computer_screenshot",
|
|
"image_url": "data:image/png;base64,abc",
|
|
}),
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].call_id == "call_cc"
|
|
|
|
async def test_custom_tool_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemCustomToolCall
|
|
|
|
item = OutputItemCustomToolCall({
|
|
"type": "custom_tool_call",
|
|
"call_id": "call_ct",
|
|
"name": "my_tool",
|
|
"input": '{"key": "value"}',
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "my_tool"
|
|
assert msg.contents[0].arguments == '{"key": "value"}'
|
|
|
|
async def test_custom_tool_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemCustomToolCallOutput
|
|
|
|
item = OutputItemCustomToolCallOutput({
|
|
"type": "custom_tool_call_output",
|
|
"call_id": "call_ct",
|
|
"output": "result text",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].result == "result text"
|
|
|
|
async def test_custom_tool_call_output_with_mcp_call_id_routes_to_mcp_server_tool_result(self) -> None:
|
|
"""When the host wrote a hosted-MCP result via
|
|
`aoutput_item_custom_tool_call_output`, the persisted call_id keeps
|
|
its `mcp_*` prefix. On read, that result must reconstruct as a
|
|
`mcp_server_tool_result` Content (not `function_result`), so the
|
|
chat-client serialize layer treats it as a hosted-MCP result and
|
|
does not produce an orphan `function_call_output`.
|
|
"""
|
|
from azure.ai.agentserver.responses.models import OutputItemCustomToolCallOutput
|
|
|
|
item = OutputItemCustomToolCallOutput({
|
|
"type": "custom_tool_call_output",
|
|
"call_id": "mcp_06b686e11f118cf40169f0e5badb3081979842929d5cf04920",
|
|
"output": "found 10 cats",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert len(msg.contents) == 1
|
|
c = msg.contents[0]
|
|
assert c.type == "mcp_server_tool_result", (
|
|
f"expected mcp_server_tool_result for mcp_-prefixed call_id; got {c.type}"
|
|
)
|
|
assert c.call_id == "mcp_06b686e11f118cf40169f0e5badb3081979842929d5cf04920"
|
|
|
|
async def test_apply_patch_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ApplyPatchUpdateFileOperation, OutputItemApplyPatchToolCall
|
|
|
|
item = OutputItemApplyPatchToolCall({
|
|
"type": "apply_patch_call",
|
|
"id": "ap-1",
|
|
"call_id": "call_ap",
|
|
"status": "completed",
|
|
"operation": ApplyPatchUpdateFileOperation({
|
|
"type": "update_file",
|
|
"path": "file.py",
|
|
"diff": "+ new line",
|
|
}),
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "apply_patch"
|
|
|
|
async def test_apply_patch_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemApplyPatchToolCallOutput
|
|
|
|
item = OutputItemApplyPatchToolCallOutput({
|
|
"type": "apply_patch_call_output",
|
|
"id": "apo-1",
|
|
"call_id": "call_ap",
|
|
"status": "completed",
|
|
"output": "patch applied",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].result == "patch applied"
|
|
|
|
async def test_oauth_consent_request(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OAuthConsentRequestOutputItem
|
|
|
|
item = OAuthConsentRequestOutputItem({
|
|
"type": "oauth_consent_request",
|
|
"id": "oauth-1",
|
|
"consent_link": "https://example.com/consent",
|
|
"server_label": "my_server",
|
|
})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "oauth_consent_request"
|
|
assert msg.contents[0].consent_link == "https://example.com/consent"
|
|
|
|
async def test_structured_outputs_dict(self) -> None:
|
|
from azure.ai.agentserver.responses.models import StructuredOutputsOutputItem
|
|
|
|
item = StructuredOutputsOutputItem({"type": "structured_outputs", "id": "so-1", "output": {"answer": 42}})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "text"
|
|
assert json.loads(msg.contents[0].text or "") == {"answer": 42}
|
|
|
|
async def test_structured_outputs_string(self) -> None:
|
|
from azure.ai.agentserver.responses.models import StructuredOutputsOutputItem
|
|
|
|
item = StructuredOutputsOutputItem({"type": "structured_outputs", "id": "so-2", "output": "plain text"})
|
|
msg = await _output_item_to_message(item)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].text == "plain text"
|
|
|
|
async def test_unsupported_type_raises(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItem
|
|
|
|
item = OutputItem({"type": "some_unknown_type"})
|
|
with pytest.raises(ValueError, match="Unsupported OutputItem type: some_unknown_type"):
|
|
await _output_item_to_message(item)
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region _item_to_message conversion
|
|
|
|
|
|
class TestItemToMessage:
|
|
"""Tests for _item_to_message covering all supported Item types."""
|
|
|
|
async def test_message_with_string_content(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMessage
|
|
|
|
item = ItemMessage({"type": "message", "role": "user", "content": "hello"})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "user"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].type == "text"
|
|
assert msg.contents[0].text == "hello"
|
|
|
|
async def test_message_with_input_text_content(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMessage, MessageContentInputTextContent
|
|
|
|
item = ItemMessage({
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [MessageContentInputTextContent({"type": "input_text", "text": "hi there"})],
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "user"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].text == "hi there"
|
|
|
|
async def test_message_with_multiple_contents(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMessage, MessageContentInputTextContent
|
|
|
|
item = ItemMessage({
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
MessageContentInputTextContent({"type": "input_text", "text": "first"}),
|
|
MessageContentInputTextContent({"type": "input_text", "text": "second"}),
|
|
],
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert len(msg.contents) == 2
|
|
assert msg.contents[0].text == "first"
|
|
assert msg.contents[1].text == "second"
|
|
|
|
async def test_output_message(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemOutputMessage, OutputMessageContentOutputTextContent
|
|
|
|
item = ItemOutputMessage({
|
|
"type": "output_message",
|
|
"role": "assistant",
|
|
"content": [OutputMessageContentOutputTextContent({"type": "output_text", "text": "response"})],
|
|
"status": "completed",
|
|
"id": "msg-1",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].type == "text"
|
|
assert msg.contents[0].text == "response"
|
|
|
|
async def test_function_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemFunctionToolCall
|
|
|
|
item = ItemFunctionToolCall({
|
|
"type": "function_call",
|
|
"call_id": "call_1",
|
|
"name": "get_weather",
|
|
"arguments": '{"city": "NYC"}',
|
|
"status": "completed",
|
|
"id": "fc-1",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].call_id == "call_1"
|
|
assert msg.contents[0].name == "get_weather"
|
|
assert msg.contents[0].arguments == '{"city": "NYC"}'
|
|
|
|
async def test_function_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import FunctionCallOutputItemParam
|
|
|
|
item = FunctionCallOutputItemParam({"type": "function_call_output", "call_id": "call_1", "output": "sunny"})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].call_id == "call_1"
|
|
assert msg.contents[0].result == "sunny"
|
|
|
|
async def test_function_call_output_non_string(self) -> None:
|
|
from azure.ai.agentserver.responses.models import FunctionCallOutputItemParam
|
|
|
|
item = FunctionCallOutputItemParam({"type": "function_call_output", "call_id": "call_2", "output": 42})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].result == "42"
|
|
|
|
async def test_reasoning_with_summary(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemReasoningItem, SummaryTextContent
|
|
|
|
item = ItemReasoningItem({
|
|
"type": "reasoning",
|
|
"id": "r-1",
|
|
"summary": [SummaryTextContent({"type": "summary_text", "text": "thinking hard"})],
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert len(msg.contents) == 1
|
|
assert msg.contents[0].text == "thinking hard"
|
|
|
|
async def test_reasoning_no_summary(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemReasoningItem
|
|
|
|
item = ItemReasoningItem({"type": "reasoning", "id": "r-2"})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents == []
|
|
|
|
async def test_mcp_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMcpToolCall
|
|
|
|
item = ItemMcpToolCall({
|
|
"type": "mcp_call",
|
|
"id": "mcp-1",
|
|
"server_label": "my_server",
|
|
"name": "search",
|
|
"arguments": '{"q": "test"}',
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "mcp_server_tool_call"
|
|
assert msg.contents[0].server_name == "my_server"
|
|
assert msg.contents[0].tool_name == "search"
|
|
|
|
async def test_mcp_approval_request(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMcpApprovalRequest
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = ItemMcpApprovalRequest({
|
|
"type": "mcp_approval_request",
|
|
"id": "apr-1",
|
|
"server_label": "srv",
|
|
"name": "dangerous_tool",
|
|
"arguments": "{}",
|
|
})
|
|
msg = await _item_to_message(item, approval_storage=storage)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_approval_request"
|
|
|
|
async def test_mcp_approval_response(self) -> None:
|
|
from azure.ai.agentserver.responses.models import MCPApprovalResponse
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = MCPApprovalResponse({
|
|
"type": "mcp_approval_response",
|
|
"approval_request_id": "apr-1",
|
|
"approve": True,
|
|
})
|
|
msg = await _item_to_message(item, approval_storage=storage) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "user"
|
|
assert msg.contents[0].type == "function_approval_response"
|
|
assert msg.contents[0].approved is True
|
|
|
|
async def test_code_interpreter_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemCodeInterpreterToolCall
|
|
|
|
item = ItemCodeInterpreterToolCall({
|
|
"type": "code_interpreter_call",
|
|
"id": "ci-1",
|
|
"status": "completed",
|
|
"container_id": "c-1",
|
|
"code": "print('hi')",
|
|
"outputs": [],
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "code_interpreter_tool_call"
|
|
|
|
async def test_image_generation_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemImageGenToolCall
|
|
|
|
item = ItemImageGenToolCall({"type": "image_generation_call", "id": "ig-1", "status": "completed"})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "image_generation_tool_call"
|
|
|
|
async def test_shell_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import FunctionShellAction, FunctionShellCallItemParam
|
|
|
|
item = FunctionShellCallItemParam({
|
|
"type": "shell_call",
|
|
"call_id": "call_sc",
|
|
"action": FunctionShellAction({"commands": ["ls", "-la"], "timeout_ms": 5000, "max_output_length": 1024}),
|
|
"status": "in_progress",
|
|
})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "shell_tool_call"
|
|
assert msg.contents[0].commands == ["ls", "-la"]
|
|
assert msg.contents[0].call_id == "call_sc"
|
|
|
|
async def test_shell_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import (
|
|
FunctionShellCallOutputContent,
|
|
FunctionShellCallOutputExitOutcome,
|
|
FunctionShellCallOutputItemParam,
|
|
)
|
|
|
|
item = FunctionShellCallOutputItemParam({
|
|
"type": "shell_call_output",
|
|
"call_id": "call_sc",
|
|
"output": [
|
|
FunctionShellCallOutputContent({
|
|
"stdout": "file.txt",
|
|
"stderr": "",
|
|
"outcome": FunctionShellCallOutputExitOutcome({"exit_code": 0}),
|
|
})
|
|
],
|
|
"max_output_length": 1024,
|
|
})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "shell_tool_result"
|
|
assert msg.contents[0].call_id == "call_sc"
|
|
|
|
async def test_local_shell_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemLocalShellToolCall, LocalShellExecAction
|
|
|
|
item = ItemLocalShellToolCall({
|
|
"type": "local_shell_call",
|
|
"id": "lsc-1",
|
|
"call_id": "call_lsc",
|
|
"action": LocalShellExecAction({"type": "exec", "command": ["echo", "hello"], "env": {}}),
|
|
"status": "completed",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "shell_tool_call"
|
|
assert msg.contents[0].commands == ["echo", "hello"]
|
|
|
|
async def test_local_shell_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemLocalShellToolCallOutput
|
|
|
|
item = ItemLocalShellToolCallOutput({
|
|
"type": "local_shell_call_output",
|
|
"id": "lsco-1",
|
|
"output": "hello\n",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "shell_tool_result"
|
|
|
|
async def test_file_search_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemFileSearchToolCall
|
|
|
|
item = ItemFileSearchToolCall({
|
|
"type": "file_search_call",
|
|
"id": "fs-1",
|
|
"status": "completed",
|
|
"queries": ["what is AI"],
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "file_search"
|
|
assert '"what is AI"' in (msg.contents[0].arguments or "")
|
|
|
|
async def test_web_search_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemWebSearchToolCall
|
|
|
|
item = ItemWebSearchToolCall({
|
|
"type": "web_search_call",
|
|
"id": "ws-1",
|
|
"status": "completed",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "web_search"
|
|
|
|
async def test_computer_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ComputerAction, ItemComputerToolCall
|
|
|
|
item = ItemComputerToolCall({
|
|
"type": "computer_call",
|
|
"id": "cc-1",
|
|
"call_id": "call_cc",
|
|
"action": ComputerAction({"type": "click"}),
|
|
"pending_safety_checks": [],
|
|
"status": "completed",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "computer_use"
|
|
|
|
async def test_computer_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ComputerCallOutputItemParam, ComputerScreenshotImage
|
|
|
|
item = ComputerCallOutputItemParam({
|
|
"type": "computer_call_output",
|
|
"call_id": "call_cc",
|
|
"output": ComputerScreenshotImage({
|
|
"type": "computer_screenshot",
|
|
"image_url": "data:image/png;base64,abc",
|
|
}),
|
|
})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].call_id == "call_cc"
|
|
|
|
async def test_custom_tool_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemCustomToolCall
|
|
|
|
item = ItemCustomToolCall({
|
|
"type": "custom_tool_call",
|
|
"call_id": "call_ct",
|
|
"name": "my_tool",
|
|
"input": '{"key": "value"}',
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "my_tool"
|
|
assert msg.contents[0].arguments == '{"key": "value"}'
|
|
|
|
async def test_custom_tool_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemCustomToolCallOutput
|
|
|
|
item = ItemCustomToolCallOutput({
|
|
"type": "custom_tool_call_output",
|
|
"call_id": "call_ct",
|
|
"output": "result text",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].result == "result text"
|
|
|
|
async def test_custom_tool_call_output_non_string(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemCustomToolCallOutput
|
|
|
|
item = ItemCustomToolCallOutput({
|
|
"type": "custom_tool_call_output",
|
|
"call_id": "call_ct2",
|
|
"output": 123,
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.contents[0].result == "123"
|
|
|
|
async def test_custom_tool_call_output_with_mcp_call_id_routes_to_mcp_server_tool_result(self) -> None:
|
|
"""Issue #5546: input items carrying a hosted-MCP result (from a
|
|
prior turn that the framework wrote via
|
|
`aoutput_item_custom_tool_call_output`) must reconstruct as a
|
|
`mcp_server_tool_result` Content, not `function_result`. Otherwise
|
|
the chat-client serialize layer turns it into an orphan
|
|
`function_call_output` with `mcp_*` call_id and the Responses API
|
|
rejects the next turn.
|
|
"""
|
|
from azure.ai.agentserver.responses.models import ItemCustomToolCallOutput
|
|
|
|
item = ItemCustomToolCallOutput({
|
|
"type": "custom_tool_call_output",
|
|
"call_id": "mcp_06b686e11f118cf40169f0e5badb3081979842929d5cf04920",
|
|
"output": "found 10 cats",
|
|
})
|
|
msg = await _item_to_message(item)
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert len(msg.contents) == 1
|
|
c = msg.contents[0]
|
|
assert c.type == "mcp_server_tool_result", (
|
|
f"expected mcp_server_tool_result for mcp_-prefixed call_id; got {c.type}"
|
|
)
|
|
assert c.call_id == "mcp_06b686e11f118cf40169f0e5badb3081979842929d5cf04920"
|
|
|
|
async def test_apply_patch_call(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ApplyPatchToolCallItemParam, ApplyPatchUpdateFileOperation
|
|
|
|
item = ApplyPatchToolCallItemParam({
|
|
"type": "apply_patch_call",
|
|
"call_id": "call_ap",
|
|
"operation": ApplyPatchUpdateFileOperation({
|
|
"type": "update_file",
|
|
"path": "file.py",
|
|
"diff": "+ new line",
|
|
}),
|
|
})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_call"
|
|
assert msg.contents[0].name == "apply_patch"
|
|
|
|
async def test_apply_patch_call_output(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ApplyPatchToolCallOutputItemParam
|
|
|
|
item = ApplyPatchToolCallOutputItemParam({
|
|
"type": "apply_patch_call_output",
|
|
"call_id": "call_ap",
|
|
"output": "patch applied",
|
|
})
|
|
msg = await _item_to_message(item) # type: ignore[arg-type]
|
|
assert msg is not None
|
|
assert msg.role == "tool"
|
|
assert msg.contents[0].type == "function_result"
|
|
assert msg.contents[0].result == "patch applied"
|
|
|
|
async def test_unsupported_type_raises(self) -> None:
|
|
from azure.ai.agentserver.responses.models import Item
|
|
|
|
item = Item({"type": "some_unknown_type"})
|
|
with pytest.raises(ValueError, match="Unsupported Item type: some_unknown_type"):
|
|
await _item_to_message(item)
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Multi-turn with mixed content
|
|
|
|
|
|
async def _post_json(
|
|
server: ResponsesHostServer,
|
|
payload: dict[str, Any],
|
|
) -> httpx.Response:
|
|
"""Send a POST /responses request with a raw JSON payload."""
|
|
transport = httpx.ASGITransport(app=server)
|
|
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
|
return await client.post("/responses", json=payload)
|
|
|
|
|
|
def _make_multi_response_agent(
|
|
responses: list[AgentResponse],
|
|
stream_updates_list: list[list[AgentResponseUpdate]] | None = None,
|
|
) -> MagicMock:
|
|
"""Create a mock agent that returns different responses on successive calls."""
|
|
agent = MagicMock(spec=RawAgent)
|
|
agent.id = "test-agent"
|
|
agent.name = "Test Agent"
|
|
agent.description = "A mock agent for testing"
|
|
agent.context_providers = []
|
|
|
|
call_index = [0]
|
|
|
|
async def run_non_streaming(*args: Any, **kwargs: Any) -> AgentResponse:
|
|
idx = call_index[0]
|
|
call_index[0] += 1
|
|
return responses[idx]
|
|
|
|
async def _stream_gen(updates: list[AgentResponseUpdate]) -> AsyncIterator[AgentResponseUpdate]:
|
|
for update in updates:
|
|
yield update
|
|
|
|
def run_dispatch(*args: Any, **kwargs: Any) -> Any:
|
|
idx = call_index[0]
|
|
call_index[0] += 1
|
|
if kwargs.get("stream") and stream_updates_list is not None:
|
|
return ResponseStream(_stream_gen(stream_updates_list[idx])) # type: ignore
|
|
if not kwargs.get("stream"):
|
|
# Need to return a coroutine for non-streaming
|
|
async def _ret() -> AgentResponse:
|
|
return responses[idx]
|
|
|
|
return _ret()
|
|
raise NotImplementedError("Streaming not configured for this call index")
|
|
|
|
if stream_updates_list is not None:
|
|
agent.run = MagicMock(side_effect=run_dispatch)
|
|
else:
|
|
agent.run = AsyncMock(side_effect=run_non_streaming)
|
|
|
|
return agent
|
|
|
|
|
|
class TestMultiTurnMixedContent:
|
|
"""End-to-end multi-turn tests with mixed text and non-text content types."""
|
|
|
|
async def test_text_and_image_input_single_turn(self) -> None:
|
|
"""Agent receives a message with text and image content via URL."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("I see a cat!")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Describe this animal"},
|
|
{"type": "input_image", "image_url": "https://example.com/cat.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
# Verify agent received text + image
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert messages[0].role == "user"
|
|
assert len(messages[0].contents) == 2
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "Describe this animal"
|
|
assert messages[0].contents[1].type == "uri"
|
|
assert messages[0].contents[1].uri == "https://example.com/cat.jpg"
|
|
|
|
async def test_text_and_file_input_single_turn(self) -> None:
|
|
"""Agent receives a message with text and file content via URL."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("File received")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Summarize this document"},
|
|
{"type": "input_file", "file_url": "https://example.com/doc.pdf", "filename": "doc.pdf"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert len(messages[0].contents) == 2
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "Summarize this document"
|
|
assert messages[0].contents[1].type == "uri"
|
|
assert messages[0].contents[1].uri == "https://example.com/doc.pdf"
|
|
|
|
async def test_text_and_file_data_input_single_turn(self) -> None:
|
|
"""Agent receives a message with text and file content via inline file_data."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("File received")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Summarize this document"},
|
|
{
|
|
"type": "input_file",
|
|
"file_data": "data:application/pdf;base64,JVBERi0xLjQ=",
|
|
"filename": "doc.pdf",
|
|
},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert len(messages[0].contents) == 2
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "Summarize this document"
|
|
assert messages[0].contents[1].type == "data"
|
|
assert messages[0].contents[1].uri == "data:application/pdf;base64,JVBERi0xLjQ="
|
|
|
|
async def test_text_mime_file_data_decoded(self) -> None:
|
|
"""Agent receives a text/* file_data that is base64-decoded to plain text."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Got it")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
import base64
|
|
|
|
encoded = base64.b64encode(b"Hello, world!").decode()
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{
|
|
"type": "input_file",
|
|
"file_data": f"data:text/plain;base64,{encoded}",
|
|
"filename": "greeting.txt",
|
|
},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "[File: greeting.txt]\nHello, world!"
|
|
|
|
async def test_text_mime_file_data_invalid_base64_falls_through(self) -> None:
|
|
"""Invalid base64 in a text/* file_data falls through to URI passthrough."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Got it")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{
|
|
"type": "input_file",
|
|
"file_data": "data:text/plain;base64,!!!invalid!!!",
|
|
"filename": "bad.txt",
|
|
},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert messages[0].contents[0].type == "data"
|
|
assert messages[0].contents[0].uri == "data:text/plain;base64,!!!invalid!!!"
|
|
|
|
async def test_mixed_text_and_image_input(self) -> None:
|
|
"""Agent receives a single message with both text and image content."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Got it!")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "What's in this image?"},
|
|
{"type": "input_image", "image_url": "https://example.com/photo.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert len(messages[0].contents) == 2
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "What's in this image?"
|
|
assert messages[0].contents[1].type == "uri"
|
|
assert messages[0].contents[1].uri == "https://example.com/photo.jpg"
|
|
|
|
async def test_function_call_items_in_input(self) -> None:
|
|
"""Input contains function_call and function_call_output items."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(
|
|
messages=[Message(role="assistant", contents=[Content.from_text("Weather is sunny!")])]
|
|
)
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{"type": "message", "role": "user", "content": "What's the weather?"},
|
|
{
|
|
"type": "function_call",
|
|
"id": "fc-1",
|
|
"call_id": "call_1",
|
|
"name": "get_weather",
|
|
"arguments": '{"city": "NYC"}',
|
|
"status": "completed",
|
|
},
|
|
{"type": "function_call_output", "call_id": "call_1", "output": "sunny, 72F"},
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 3
|
|
assert messages[0].role == "user"
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[1].role == "assistant"
|
|
assert messages[1].contents[0].type == "function_call"
|
|
assert messages[1].contents[0].name == "get_weather"
|
|
assert messages[2].role == "tool"
|
|
assert messages[2].contents[0].type == "function_result"
|
|
assert messages[2].contents[0].result == "sunny, 72F"
|
|
|
|
async def test_multi_turn_text_then_text_with_image(self) -> None:
|
|
"""First turn sends text, second turn sends text + image with previous_response_id."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Send me an image")])]),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Nice cat!")])]),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1: simple text
|
|
resp1 = await _post(server, input_text="Hello", stream=False)
|
|
assert resp1.status_code == 200
|
|
response_id = resp1.json()["id"]
|
|
|
|
# Turn 2: text + image input referencing turn 1
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Here is my cat photo"},
|
|
{"type": "input_image", "image_url": "https://example.com/cat.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
"previous_response_id": response_id,
|
|
},
|
|
)
|
|
|
|
assert resp2.status_code == 200
|
|
body2 = resp2.json()
|
|
assert body2["status"] == "completed"
|
|
|
|
# Verify second call receives history from turn 1 + text+image input
|
|
second_call_messages = agent.run.call_args_list[1].kwargs["messages"]
|
|
# History: output message from turn 1 ("Send me an image")
|
|
# Input: message with text + image
|
|
assert len(second_call_messages) >= 2
|
|
# Last message should be the text+image input
|
|
last_msg = second_call_messages[-1]
|
|
assert last_msg.role == "user"
|
|
assert len(last_msg.contents) == 2
|
|
assert last_msg.contents[0].type == "text"
|
|
assert last_msg.contents[0].text == "Here is my cat photo"
|
|
assert last_msg.contents[1].type == "uri"
|
|
assert last_msg.contents[1].uri == "https://example.com/cat.jpg"
|
|
# History should include the assistant response from turn 1
|
|
history_msgs = second_call_messages[:-1]
|
|
assistant_texts = [
|
|
c.text for m in history_msgs if m.role == "assistant" for c in m.contents if c.type == "text"
|
|
]
|
|
assert "Send me an image" in assistant_texts
|
|
|
|
async def test_multi_turn_function_call_in_history(self) -> None:
|
|
"""Turn 1 produces function call + result, turn 2 sees them in history."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[Content.from_function_call("call_1", "search", arguments='{"q": "cats"}')],
|
|
),
|
|
Message(role="tool", contents=[Content.from_function_result("call_1", result="found 10 cats")]),
|
|
Message(role="assistant", contents=[Content.from_text("I found 10 cats!")]),
|
|
]
|
|
),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Here are more details")])]),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1
|
|
resp1 = await _post(server, input_text="Search for cats", stream=False)
|
|
assert resp1.status_code == 200
|
|
response_id = resp1.json()["id"]
|
|
|
|
# Verify turn 1 output has function_call, function_call_output, and message
|
|
types1 = [item["type"] for item in resp1.json()["output"]]
|
|
assert "function_call" in types1
|
|
assert "function_call_output" in types1
|
|
assert "message" in types1
|
|
|
|
# Turn 2
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": "Tell me more",
|
|
"stream": False,
|
|
"previous_response_id": response_id,
|
|
},
|
|
)
|
|
assert resp2.status_code == 200
|
|
assert resp2.json()["status"] == "completed"
|
|
|
|
# Verify turn 2 received history including function call/result
|
|
second_call_messages = agent.run.call_args_list[1].kwargs["messages"]
|
|
roles = [m.role for m in second_call_messages]
|
|
assert "assistant" in roles
|
|
assert "tool" in roles
|
|
# The function call should be in the history
|
|
fc_contents = [
|
|
c for m in second_call_messages if m.role == "assistant" for c in m.contents if c.type == "function_call"
|
|
]
|
|
assert len(fc_contents) >= 1
|
|
assert fc_contents[0].name == "search"
|
|
|
|
async def test_multi_turn_reasoning_in_history(self) -> None:
|
|
"""Turn 1 produces reasoning + text, turn 2 sees them in history."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[
|
|
Content.from_text_reasoning(text="Let me think about this..."),
|
|
Content.from_text("The answer is 42"),
|
|
],
|
|
),
|
|
]
|
|
),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Indeed, it is 42")])]),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1
|
|
resp1 = await _post(server, input_text="What is the answer?", stream=False)
|
|
assert resp1.status_code == 200
|
|
response_id = resp1.json()["id"]
|
|
types1 = [item["type"] for item in resp1.json()["output"]]
|
|
assert "reasoning" in types1
|
|
assert "message" in types1
|
|
|
|
# Turn 2
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": "Are you sure?",
|
|
"stream": False,
|
|
"previous_response_id": response_id,
|
|
},
|
|
)
|
|
assert resp2.status_code == 200
|
|
assert resp2.json()["status"] == "completed"
|
|
|
|
# Verify history includes the reasoning and text from turn 1
|
|
second_call_messages = agent.run.call_args_list[1].kwargs["messages"]
|
|
assert len(second_call_messages) >= 2 # history + new input
|
|
|
|
async def test_multi_turn_with_mixed_content_and_streaming(self) -> None:
|
|
"""Turn 1 non-streaming, turn 2 streaming with image input."""
|
|
turn2_updates = [
|
|
AgentResponseUpdate(contents=[Content.from_text("I see ")], role="assistant"),
|
|
AgentResponseUpdate(contents=[Content.from_text("a cat!")], role="assistant"),
|
|
]
|
|
|
|
agent = _make_multi_response_agent(
|
|
responses=[
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Send me an image")])]),
|
|
AgentResponse(messages=[]), # placeholder, not used for streaming
|
|
],
|
|
stream_updates_list=[
|
|
[], # placeholder for turn 1 (non-streaming)
|
|
turn2_updates,
|
|
],
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1: non-streaming text
|
|
resp1 = await _post(server, input_text="Hello", stream=False)
|
|
assert resp1.status_code == 200
|
|
response_id = resp1.json()["id"]
|
|
|
|
# Turn 2: streaming with image input
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Describe this:"},
|
|
{"type": "input_image", "image_url": "https://example.com/cat.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": True,
|
|
"previous_response_id": response_id,
|
|
},
|
|
)
|
|
|
|
assert resp2.status_code == 200
|
|
assert "text/event-stream" in resp2.headers["content-type"]
|
|
|
|
events = _parse_sse_events(resp2.text)
|
|
types = _sse_event_types(events)
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
assert "response.output_text.delta" in types
|
|
|
|
# Verify accumulated text
|
|
text_done = [e for e in events if e["event"] == "response.output_text.done"]
|
|
assert len(text_done) == 1
|
|
assert text_done[0]["data"]["text"] == "I see a cat!"
|
|
|
|
async def test_text_with_mcp_call_items(self) -> None:
|
|
"""Input contains text message + mcp_call item and the agent processes it."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(
|
|
messages=[Message(role="assistant", contents=[Content.from_text("MCP result received")])]
|
|
)
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{"type": "message", "role": "user", "content": "Search using MCP"},
|
|
{
|
|
"type": "mcp_call",
|
|
"id": "mcp-1",
|
|
"server_label": "my_server",
|
|
"name": "search",
|
|
"arguments": '{"query": "test"}',
|
|
},
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 2
|
|
assert messages[0].role == "user"
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "Search using MCP"
|
|
assert messages[1].role == "assistant"
|
|
assert messages[1].contents[0].type == "mcp_server_tool_call"
|
|
assert messages[1].contents[0].server_name == "my_server"
|
|
assert messages[1].contents[0].tool_name == "search"
|
|
|
|
async def test_three_turn_conversation_with_mixed_content(self) -> None:
|
|
"""Three-turn conversation: text → function call → image input."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Hello! How can I help?")])]),
|
|
AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[Content.from_function_call("call_1", "analyze", arguments='{"mode": "deep"}')],
|
|
),
|
|
Message(role="tool", contents=[Content.from_function_result("call_1", result="analysis complete")]),
|
|
Message(role="assistant", contents=[Content.from_text("Analysis done!")]),
|
|
]
|
|
),
|
|
AgentResponse(
|
|
messages=[Message(role="assistant", contents=[Content.from_text("The image shows a chart")])]
|
|
),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1: text
|
|
resp1 = await _post(server, input_text="Hi", stream=False)
|
|
assert resp1.status_code == 200
|
|
id1 = resp1.json()["id"]
|
|
|
|
# Turn 2: text, referencing turn 1
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": "Analyze something",
|
|
"stream": False,
|
|
"previous_response_id": id1,
|
|
},
|
|
)
|
|
assert resp2.status_code == 200
|
|
id2 = resp2.json()["id"]
|
|
|
|
# Turn 3: image input, referencing turn 2
|
|
resp3 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "What about this image?"},
|
|
{"type": "input_image", "image_url": "https://example.com/chart.png"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
"previous_response_id": id2,
|
|
},
|
|
)
|
|
|
|
assert resp3.status_code == 200
|
|
assert resp3.json()["status"] == "completed"
|
|
|
|
# Verify turn 3 received full history from turns 1+2 plus new image input
|
|
third_call_messages = agent.run.call_args_list[2].kwargs["messages"]
|
|
# Should have: history from turn 1 (assistant text) + history from turn 2
|
|
# (function_call, function_call_output, text) + new input (text + image)
|
|
assert len(third_call_messages) >= 5
|
|
|
|
# Last message should contain the image
|
|
last_msg = third_call_messages[-1]
|
|
assert last_msg.role == "user"
|
|
image_contents = [c for c in last_msg.contents if c.type == "uri"]
|
|
assert len(image_contents) == 1
|
|
assert image_contents[0].uri == "https://example.com/chart.png"
|
|
|
|
# History should include function call from turn 2
|
|
fc_contents = [
|
|
c
|
|
for m in third_call_messages[:-1]
|
|
if m.role == "assistant"
|
|
for c in m.contents
|
|
if c.type == "function_call"
|
|
]
|
|
assert any(c.name == "analyze" for c in fc_contents)
|
|
|
|
async def test_input_with_hosted_file_image(self) -> None:
|
|
"""Input contains an image referenced by file_id (hosted file)."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Image analyzed")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Analyze this image"},
|
|
{"type": "input_image", "file_id": "file-abc123"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
messages = agent.run.call_args.kwargs["messages"]
|
|
assert len(messages) == 1
|
|
assert len(messages[0].contents) == 2
|
|
assert messages[0].contents[0].type == "text"
|
|
assert messages[0].contents[0].text == "Analyze this image"
|
|
assert messages[0].contents[1].type == "hosted_file"
|
|
assert messages[0].contents[1].file_id == "file-abc123"
|
|
|
|
async def test_multi_turn_text_and_image_then_text_and_file(self) -> None:
|
|
"""Turn 1 sends text+image, turn 2 sends text+file, both in history."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("I see a landscape")])]),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Document summarized")])]),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1: text + image
|
|
resp1 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "What is in this photo?"},
|
|
{"type": "input_image", "image_url": "https://example.com/landscape.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
assert resp1.status_code == 200
|
|
id1 = resp1.json()["id"]
|
|
|
|
# Turn 2: text + file, referencing turn 1
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Now summarize this report"},
|
|
{
|
|
"type": "input_file",
|
|
"file_url": "https://example.com/report.pdf",
|
|
"filename": "report.pdf",
|
|
},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
"previous_response_id": id1,
|
|
},
|
|
)
|
|
assert resp2.status_code == 200
|
|
assert resp2.json()["status"] == "completed"
|
|
|
|
# Verify turn 2 received history from turn 1 + new text+file input
|
|
second_call_messages = agent.run.call_args_list[1].kwargs["messages"]
|
|
assert len(second_call_messages) >= 2
|
|
|
|
# History should include the assistant response from turn 1
|
|
assistant_texts = [
|
|
c.text for m in second_call_messages if m.role == "assistant" for c in m.contents if c.type == "text"
|
|
]
|
|
assert "I see a landscape" in assistant_texts
|
|
|
|
# Last message should be text + file
|
|
last_msg = second_call_messages[-1]
|
|
assert last_msg.role == "user"
|
|
assert len(last_msg.contents) == 2
|
|
assert last_msg.contents[0].type == "text"
|
|
assert last_msg.contents[0].text == "Now summarize this report"
|
|
assert last_msg.contents[1].type == "uri"
|
|
assert last_msg.contents[1].uri == "https://example.com/report.pdf"
|
|
|
|
async def test_multi_turn_function_call_then_text_and_image(self) -> None:
|
|
"""Turn 1: text + function call + result, turn 2: text + image."""
|
|
agent = _make_multi_response_agent([
|
|
AgentResponse(
|
|
messages=[
|
|
Message(
|
|
role="assistant",
|
|
contents=[Content.from_function_call("call_1", "get_info", arguments='{"id": 1}')],
|
|
),
|
|
Message(role="tool", contents=[Content.from_function_result("call_1", result="info data")]),
|
|
Message(role="assistant", contents=[Content.from_text("Here is the info")]),
|
|
]
|
|
),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("Image matches the data")])]),
|
|
])
|
|
server = _make_server(agent)
|
|
|
|
# Turn 1: text triggers function call
|
|
resp1 = await _post(server, input_text="Get info for item 1", stream=False)
|
|
assert resp1.status_code == 200
|
|
id1 = resp1.json()["id"]
|
|
|
|
types1 = [item["type"] for item in resp1.json()["output"]]
|
|
assert "function_call" in types1
|
|
assert "function_call_output" in types1
|
|
assert "message" in types1
|
|
|
|
# Turn 2: text + image referencing turn 1
|
|
resp2 = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "message",
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "input_text", "text": "Does this image match?"},
|
|
{"type": "input_image", "image_url": "https://example.com/item1.jpg"},
|
|
],
|
|
}
|
|
],
|
|
"stream": False,
|
|
"previous_response_id": id1,
|
|
},
|
|
)
|
|
assert resp2.status_code == 200
|
|
assert resp2.json()["status"] == "completed"
|
|
|
|
# Verify turn 2 received history with function call + new text+image
|
|
second_call_messages = agent.run.call_args_list[1].kwargs["messages"]
|
|
# History should contain function_call and function_result from turn 1
|
|
fc_contents = [
|
|
c for m in second_call_messages if m.role == "assistant" for c in m.contents if c.type == "function_call"
|
|
]
|
|
assert any(c.name == "get_info" for c in fc_contents)
|
|
tool_contents = [
|
|
c for m in second_call_messages if m.role == "tool" for c in m.contents if c.type == "function_result"
|
|
]
|
|
assert any(c.result == "info data" for c in tool_contents)
|
|
|
|
# Last message should be text + image
|
|
last_msg = second_call_messages[-1]
|
|
assert last_msg.role == "user"
|
|
assert len(last_msg.contents) == 2
|
|
assert last_msg.contents[0].type == "text"
|
|
assert last_msg.contents[0].text == "Does this image match?"
|
|
assert last_msg.contents[1].type == "uri"
|
|
assert last_msg.contents[1].uri == "https://example.com/item1.jpg"
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Function approval round-trip
|
|
|
|
|
|
class TestFunctionApprovalStorage:
|
|
"""Unit tests for the function approval storage classes."""
|
|
|
|
async def test_in_memory_save_and_load(self) -> None:
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
request = _make_function_approval_request_content(request_id="apr_1")
|
|
await storage.save_approval_request("apr_1", request)
|
|
loaded = await storage.load_approval_request("apr_1")
|
|
assert loaded.type == "function_approval_request"
|
|
assert loaded.id == "apr_1" # type: ignore[attr-defined]
|
|
|
|
async def test_in_memory_duplicate_save_raises(self) -> None:
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
request = _make_function_approval_request_content(request_id="apr_1")
|
|
await storage.save_approval_request("apr_1", request)
|
|
with pytest.raises(ValueError, match="already exists"):
|
|
await storage.save_approval_request("apr_1", request)
|
|
|
|
async def test_in_memory_missing_load_raises(self) -> None:
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
with pytest.raises(KeyError):
|
|
await storage.load_approval_request("missing")
|
|
|
|
async def test_file_based_save_and_load_persists_across_instances(self, tmp_path: Any) -> None:
|
|
path = tmp_path / "subdir" / "approvals.json"
|
|
storage = FileBasedFunctionApprovalStorage(str(path))
|
|
request = _make_function_approval_request_content(request_id="apr_1")
|
|
await storage.save_approval_request("apr_1", request)
|
|
|
|
# Directory + file should now exist.
|
|
assert path.exists()
|
|
|
|
# A new instance pointing at the same path can load the saved entry.
|
|
storage2 = FileBasedFunctionApprovalStorage(str(path))
|
|
loaded = await storage2.load_approval_request("apr_1")
|
|
assert loaded.type == "function_approval_request"
|
|
assert loaded.id == "apr_1" # type: ignore[attr-defined]
|
|
# The embedded function_call survives the round trip.
|
|
assert loaded.function_call.name == "delete_file" # type: ignore[attr-defined]
|
|
|
|
async def test_file_based_duplicate_save_raises(self, tmp_path: Any) -> None:
|
|
path = tmp_path / "approvals.json"
|
|
storage = FileBasedFunctionApprovalStorage(str(path))
|
|
request = _make_function_approval_request_content(request_id="apr_1")
|
|
await storage.save_approval_request("apr_1", request)
|
|
with pytest.raises(ValueError, match="already exists"):
|
|
await storage.save_approval_request("apr_1", request)
|
|
|
|
async def test_file_based_missing_load_raises(self, tmp_path: Any) -> None:
|
|
path = tmp_path / "approvals.json"
|
|
storage = FileBasedFunctionApprovalStorage(str(path))
|
|
with pytest.raises(KeyError):
|
|
await storage.load_approval_request("missing")
|
|
|
|
|
|
class TestFunctionApprovalConversion:
|
|
"""Tests for the approval-aware paths in `_item_to_message` / `_output_item_to_message`."""
|
|
|
|
async def test_output_item_mcp_approval_request_loads_from_storage(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalRequest
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = OutputItemMcpApprovalRequest({
|
|
"type": "mcp_approval_request",
|
|
"id": "apr-1",
|
|
"server_label": "srv",
|
|
"name": "dangerous_tool",
|
|
"arguments": "{}",
|
|
})
|
|
msg = await _output_item_to_message(item, approval_storage=storage)
|
|
assert msg.role == "assistant"
|
|
c = msg.contents[0]
|
|
assert c.type == "function_approval_request"
|
|
assert c.id == "apr-1" # type: ignore[attr-defined]
|
|
# The full saved Content (incl. function_call) is restored.
|
|
assert c.function_call.name == "delete_file" # type: ignore[attr-defined]
|
|
|
|
async def test_output_item_mcp_approval_request_without_storage_raises(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalRequest
|
|
|
|
item = OutputItemMcpApprovalRequest({
|
|
"type": "mcp_approval_request",
|
|
"id": "apr-1",
|
|
"server_label": "srv",
|
|
"name": "dangerous_tool",
|
|
"arguments": "{}",
|
|
})
|
|
with pytest.raises(ValueError, match="ApprovalStorage is required"):
|
|
await _output_item_to_message(item)
|
|
|
|
async def test_output_item_mcp_approval_response_resolves_to_approval_response(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalResponseResource
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = OutputItemMcpApprovalResponseResource({
|
|
"type": "mcp_approval_response",
|
|
"id": "resp-1",
|
|
"approval_request_id": "apr-1",
|
|
"approve": True,
|
|
})
|
|
msg = await _output_item_to_message(item, approval_storage=storage)
|
|
assert msg.role == "user"
|
|
c = msg.contents[0]
|
|
assert c.type == "function_approval_response"
|
|
assert c.approved is True # type: ignore[attr-defined]
|
|
assert c.id == "apr-1" # type: ignore[attr-defined]
|
|
assert c.function_call.name == "delete_file" # type: ignore[attr-defined]
|
|
|
|
async def test_output_item_mcp_approval_response_without_storage_raises(self) -> None:
|
|
from azure.ai.agentserver.responses.models import OutputItemMcpApprovalResponseResource
|
|
|
|
item = OutputItemMcpApprovalResponseResource({
|
|
"type": "mcp_approval_response",
|
|
"id": "resp-1",
|
|
"approval_request_id": "apr-1",
|
|
"approve": False,
|
|
})
|
|
with pytest.raises(ValueError, match="ApprovalStorage is required"):
|
|
await _output_item_to_message(item)
|
|
|
|
async def test_input_item_mcp_approval_request_loads_from_storage(self) -> None:
|
|
from azure.ai.agentserver.responses.models import ItemMcpApprovalRequest
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = ItemMcpApprovalRequest({
|
|
"type": "mcp_approval_request",
|
|
"id": "apr-1",
|
|
"server_label": "srv",
|
|
"name": "dangerous_tool",
|
|
"arguments": "{}",
|
|
})
|
|
msg = await _item_to_message(item, approval_storage=storage)
|
|
assert msg.role == "assistant"
|
|
assert msg.contents[0].type == "function_approval_request"
|
|
assert msg.contents[0].id == "apr-1" # type: ignore[attr-defined]
|
|
|
|
async def test_input_item_mcp_approval_response_resolves_to_approval_response(self) -> None:
|
|
from azure.ai.agentserver.responses.models import MCPApprovalResponse
|
|
|
|
storage = InMemoryFunctionApprovalStorage()
|
|
saved = _make_function_approval_request_content(request_id="apr-1")
|
|
await storage.save_approval_request("apr-1", saved)
|
|
|
|
item = MCPApprovalResponse({
|
|
"type": "mcp_approval_response",
|
|
"approval_request_id": "apr-1",
|
|
"approve": False,
|
|
})
|
|
msg = await _item_to_message(item, approval_storage=storage) # type: ignore[arg-type]
|
|
assert msg.role == "user"
|
|
c = msg.contents[0]
|
|
assert c.type == "function_approval_response"
|
|
assert c.approved is False # type: ignore[attr-defined]
|
|
|
|
|
|
class TestFunctionApprovalRoundTrip:
|
|
"""End-to-end round-trip tests for the function approval flow.
|
|
|
|
Turn 1: the agent emits a `function_approval_request` content; the
|
|
server emits an `mcp_approval_request` output item and persists
|
|
the original Content under the emitted id in approval storage.
|
|
Turn 2: the caller sends an `mcp_approval_response` input item back;
|
|
the server resolves it (via approval storage) into a
|
|
`function_approval_response` content delivered to the agent.
|
|
"""
|
|
|
|
async def test_non_streaming_emits_mcp_approval_request_and_persists_to_storage(self) -> None:
|
|
request_content = _make_function_approval_request_content()
|
|
agent = _make_agent(response=AgentResponse(messages=[Message(role="assistant", contents=[request_content])]))
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post(server, stream=False)
|
|
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
approval_items = [item for item in body["output"] if item["type"] == "mcp_approval_request"]
|
|
assert len(approval_items) == 1
|
|
approval_request_id = approval_items[0]["id"]
|
|
assert approval_items[0]["name"] == "delete_file"
|
|
assert approval_items[0]["server_label"] == "my_server"
|
|
|
|
# Storage must contain a saved entry under the emitted request id.
|
|
loaded = await server._approval_storage.load_approval_request( # pyright: ignore[reportPrivateUsage]
|
|
approval_request_id
|
|
)
|
|
assert loaded.type == "function_approval_request"
|
|
assert loaded.function_call.name == "delete_file" # type: ignore[attr-defined]
|
|
|
|
async def test_streaming_emits_mcp_approval_request_and_persists_to_storage(self) -> None:
|
|
request_content = _make_function_approval_request_content(request_id="apr_streaming")
|
|
agent = _make_agent(stream_updates=[AgentResponseUpdate(contents=[request_content], role="assistant")])
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post(server, stream=True)
|
|
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
assert types[0] == "response.created"
|
|
assert types[-1] == "response.completed"
|
|
|
|
approval_request_id: str | None = None
|
|
for e in events:
|
|
if e["event"] != "response.output_item.added":
|
|
continue
|
|
item = e["data"].get("item") or {}
|
|
if item.get("type") == "mcp_approval_request":
|
|
approval_request_id = item.get("id")
|
|
break
|
|
assert approval_request_id is not None
|
|
|
|
loaded = await server._approval_storage.load_approval_request( # pyright: ignore[reportPrivateUsage]
|
|
approval_request_id
|
|
)
|
|
assert loaded.type == "function_approval_request"
|
|
|
|
async def test_round_trip_approval_response_reaches_agent(self) -> None:
|
|
"""Two-turn: turn 1 emits an approval request; turn 2 sends an
|
|
approval response and the agent receives a `function_approval_response`."""
|
|
request_content = _make_function_approval_request_content()
|
|
|
|
agent = _make_multi_response_agent(
|
|
responses=[
|
|
AgentResponse(messages=[Message(role="assistant", contents=[request_content])]),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("done")])]),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
first = await _post(server, stream=False)
|
|
assert first.status_code == 200
|
|
first_body = first.json()
|
|
approval_items = [item for item in first_body["output"] if item["type"] == "mcp_approval_request"]
|
|
assert len(approval_items) == 1
|
|
approval_request_id = approval_items[0]["id"]
|
|
|
|
# Send back an approval response that references the saved request id.
|
|
second_payload: dict[str, Any] = {
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "mcp_approval_response",
|
|
"approval_request_id": approval_request_id,
|
|
"approve": True,
|
|
}
|
|
],
|
|
"stream": False,
|
|
}
|
|
second = await _post_json(server, second_payload)
|
|
assert second.status_code == 200
|
|
|
|
# The agent's second invocation must have received a
|
|
# function_approval_response content carrying the original function_call.
|
|
assert agent.run.call_count == 2
|
|
second_call_kwargs = agent.run.call_args_list[1].kwargs
|
|
approval_responses = [
|
|
c for m in second_call_kwargs["messages"] for c in m.contents if c.type == "function_approval_response"
|
|
]
|
|
assert len(approval_responses) == 1
|
|
assert approval_responses[0].approved is True
|
|
assert approval_responses[0].function_call.name == "delete_file"
|
|
|
|
async def test_round_trip_approval_response_rejected(self) -> None:
|
|
"""Same as above but the user rejects the approval; the agent must
|
|
receive `approved=False`."""
|
|
request_content = _make_function_approval_request_content()
|
|
|
|
agent = _make_multi_response_agent(
|
|
responses=[
|
|
AgentResponse(messages=[Message(role="assistant", contents=[request_content])]),
|
|
AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]),
|
|
]
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
first = await _post(server, stream=False)
|
|
approval_request_id = next(
|
|
item["id"] for item in first.json()["output"] if item["type"] == "mcp_approval_request"
|
|
)
|
|
|
|
second = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "mcp_approval_response",
|
|
"approval_request_id": approval_request_id,
|
|
"approve": False,
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
assert second.status_code == 200
|
|
|
|
second_call_kwargs = agent.run.call_args_list[1].kwargs
|
|
approval_responses = [
|
|
c for m in second_call_kwargs["messages"] for c in m.contents if c.type == "function_approval_response"
|
|
]
|
|
assert len(approval_responses) == 1
|
|
assert approval_responses[0].approved is False
|
|
|
|
async def test_approval_response_referencing_unknown_id_fails(self) -> None:
|
|
"""Sending an `mcp_approval_response` for a request id that was
|
|
never persisted must fail (storage raises KeyError)."""
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post_json(
|
|
server,
|
|
{
|
|
"model": "test-model",
|
|
"input": [
|
|
{
|
|
"type": "mcp_approval_response",
|
|
"approval_request_id": "apr_unknown",
|
|
"approve": True,
|
|
}
|
|
],
|
|
"stream": False,
|
|
},
|
|
)
|
|
# The handler raises a KeyError when the storage lookup misses;
|
|
# the hosting layer surfaces this as a 5xx response.
|
|
assert resp.status_code >= 500
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Checkpoint context path validation
|
|
|
|
|
|
class TestCheckpointContextPathValidation:
|
|
"""Regression tests for the path-traversal hardening of checkpoint storage.
|
|
|
|
These tests guard against CWE-22 in the workflow hosting path. The hosting
|
|
code joins caller-supplied identifiers (``previous_response_id``) and
|
|
server-generated identifiers (``conversation_id`` / ``response_id``) under
|
|
the configured checkpoint root. Without validation, traversal segments
|
|
such as ``../../escape`` or absolute paths cause directory creation
|
|
outside the intended root.
|
|
"""
|
|
|
|
@staticmethod
|
|
def _helper() -> Callable[[str, str], FileCheckpointStorage]:
|
|
from agent_framework_foundry_hosting._responses import ( # pyright: ignore[reportPrivateUsage]
|
|
_checkpoint_storage_for_context,
|
|
)
|
|
|
|
return _checkpoint_storage_for_context
|
|
|
|
def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None:
|
|
helper = self._helper()
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
storage = helper(str(root), "resp_abc123")
|
|
assert storage.storage_path.is_dir()
|
|
assert storage.storage_path.parent == root.resolve()
|
|
|
|
@pytest.mark.parametrize(
|
|
"bad_id",
|
|
[
|
|
# Original MSRC repro: traversal embedded inside an id-shaped value.
|
|
# The 14 ``A``s pad the suffix to mimic the exact length of the
|
|
# ``api-made-dir<14-char-suffix>`` segment from the original report.
|
|
"caresp_x/../../service-data/api-made-dir" + "A" * 14,
|
|
# Variant report repros.
|
|
"../../escape",
|
|
"..",
|
|
".",
|
|
"...",
|
|
"/tmp/escape",
|
|
"/absolute/path",
|
|
"C:\\temp\\escape",
|
|
"..\\..\\escape",
|
|
"foo\\..\\bar",
|
|
"foo/bar",
|
|
"with\x00null",
|
|
"",
|
|
],
|
|
)
|
|
def test_traversal_and_separator_payloads_are_rejected(self, tmp_path: Any, bad_id: str) -> None:
|
|
helper = self._helper()
|
|
# Use a dedicated root *inside* tmp_path so we can assert that nothing
|
|
# was created anywhere under tmp_path (root, siblings, or above).
|
|
# Asserting against tmp_path.parent would be flaky under parallel test
|
|
# execution because tmp_path.parent is shared across tests.
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
before = sorted(p.name for p in tmp_path.iterdir())
|
|
with pytest.raises(RuntimeError):
|
|
helper(str(root), bad_id)
|
|
# No sibling/escape directory should have been created next to the root.
|
|
after = sorted(p.name for p in tmp_path.iterdir())
|
|
assert before == after, f"Unexpected filesystem artifacts created for payload {bad_id!r}"
|
|
# And nothing inside the root either.
|
|
assert list(root.iterdir()) == []
|
|
|
|
def test_non_string_context_id_is_rejected(self, tmp_path: Any) -> None:
|
|
helper = self._helper()
|
|
with pytest.raises(RuntimeError):
|
|
helper(str(tmp_path), None) # type: ignore[arg-type]
|
|
|
|
def test_url_encoded_traversal_is_treated_as_literal_segment(self, tmp_path: Any) -> None:
|
|
"""URL-encoded traversal should not decode to traversal at the filesystem layer.
|
|
|
|
The hosting layer never URL-decodes ids before using them; the helper
|
|
should accept ``%2e%2e`` as a single literal segment (no escape).
|
|
"""
|
|
helper = self._helper()
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
storage = helper(str(root), "%2e%2e")
|
|
assert storage.storage_path.parent == root.resolve()
|
|
assert storage.storage_path.name == "%2e%2e"
|
|
|
|
@pytest.mark.parametrize(
|
|
"context_field,bad_id",
|
|
[
|
|
# Restore sink: caller-controlled previous_response_id.
|
|
("previous_response_id", "../../escape"),
|
|
("previous_response_id", "/tmp/escape-abs"),
|
|
("previous_response_id", "caresp_x/../../service-data/api-made-dir" + "A" * 14),
|
|
# Restore sink: server-issued conversation_id (defense in depth).
|
|
("conversation_id", "../../escape"),
|
|
# Write sink: malicious response_id (defense in depth).
|
|
("response_id", "../../escape"),
|
|
],
|
|
)
|
|
async def test_handle_inner_workflow_rejects_malicious_context_id(
|
|
self, tmp_path: Any, context_field: str, bad_id: str
|
|
) -> None:
|
|
"""End-to-end: ``_handle_inner_workflow`` must reject malicious ids on
|
|
both the restore sink (``previous_response_id`` / ``conversation_id``)
|
|
and the write sink (``response_id``) without creating any directories.
|
|
"""
|
|
from unittest.mock import patch
|
|
|
|
from agent_framework import WorkflowAgent
|
|
from azure.ai.agentserver.responses import ResponseContext
|
|
from azure.ai.agentserver.responses.models import CreateResponse
|
|
|
|
# Build a mock that satisfies isinstance(agent, WorkflowAgent) and the
|
|
# constructor's "no existing checkpointing" guard.
|
|
agent = MagicMock(spec=WorkflowAgent)
|
|
agent.id = "wf-agent"
|
|
agent.name = "wf"
|
|
agent.description = ""
|
|
agent.context_providers = []
|
|
agent.workflow = MagicMock()
|
|
agent.workflow.name = "wf"
|
|
agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
|
|
|
|
# Constructor inspects WorkflowAgent.workflow internals; bypass setup
|
|
# by feeding a configured mock through a normal init.
|
|
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
|
|
# Re-root checkpoint storage at our isolated tmp_path so we can detect
|
|
# any escape attempt on the filesystem.
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage]
|
|
|
|
# Build a ResponseContext with the malicious id targeting the chosen sink.
|
|
kwargs: dict[str, Any] = {
|
|
"response_id": "resp_" + "a" * 48,
|
|
"mode_flags": MagicMock(),
|
|
}
|
|
if context_field == "previous_response_id":
|
|
request = CreateResponse(model="m", input="hi", previous_response_id=bad_id)
|
|
kwargs["previous_response_id"] = bad_id
|
|
elif context_field == "conversation_id":
|
|
request = CreateResponse(model="m", input="hi")
|
|
kwargs["conversation_id"] = bad_id
|
|
else: # response_id (write sink)
|
|
request = CreateResponse(model="m", input="hi")
|
|
kwargs["response_id"] = bad_id
|
|
|
|
# Avoid invoking the real input-resolution machinery, which would need
|
|
# a configured provider; we never reach the workflow run on rejection.
|
|
with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[])):
|
|
context = ResponseContext(**kwargs)
|
|
before = sorted(p.name for p in tmp_path.iterdir())
|
|
with pytest.raises(RuntimeError, match="Invalid checkpoint context id"):
|
|
async for _ in server._handle_inner_workflow(request, context): # pyright: ignore[reportPrivateUsage]
|
|
pass
|
|
after = sorted(p.name for p in tmp_path.iterdir())
|
|
|
|
assert before == after, f"Unexpected filesystem artifacts created for {context_field}={bad_id!r}"
|
|
assert list(root.iterdir()) == [], f"Checkpoint dir created inside root for {context_field}={bad_id!r}"
|
|
|
|
@pytest.mark.parametrize(
|
|
"context_field,bad_id",
|
|
[
|
|
# Restore sink: caller-controlled previous_response_id. These are
|
|
# rejected by request validation (HTTP 400) before the checkpoint
|
|
# code is reached.
|
|
("previous_response_id", "../../escape"),
|
|
("previous_response_id", "/tmp/escape-abs"),
|
|
("previous_response_id", "caresp_x/../../service-data/api-made-dir" + "A" * 14),
|
|
# Restore sink: server-issued conversation id (defense in depth).
|
|
# Reaches the checkpoint code and is rejected there, surfacing as
|
|
# an HTTP 5xx without creating any filesystem artifacts.
|
|
("conversation", "../../escape"),
|
|
("conversation", "/tmp/escape-abs"),
|
|
],
|
|
)
|
|
async def test_malicious_context_id_rejected_e2e(self, tmp_path: Any, context_field: str, bad_id: str) -> None:
|
|
"""End-to-end (ASGI-in-process): malicious context ids must be rejected
|
|
through the full HTTP pipeline, and no checkpoint directory may be
|
|
created on disk for either the validation-layer rejection
|
|
(``previous_response_id``) or the deeper checkpoint-layer rejection
|
|
(``conversation``).
|
|
|
|
The ``response_id`` write-sink is server-generated and not reachable
|
|
via the public HTTP surface, so its defense-in-depth check is covered
|
|
by the helper-level test above.
|
|
"""
|
|
from agent_framework import WorkflowAgent
|
|
|
|
# Build a mock that satisfies isinstance(agent, WorkflowAgent) and the
|
|
# constructor's "no existing checkpointing" guard.
|
|
agent = MagicMock(spec=WorkflowAgent)
|
|
agent.id = "wf-agent"
|
|
agent.name = "wf"
|
|
agent.description = ""
|
|
agent.context_providers = []
|
|
agent.workflow = MagicMock()
|
|
agent.workflow.name = "wf"
|
|
agent.workflow._runner_context.has_checkpointing = MagicMock( # pyright: ignore[reportPrivateUsage]
|
|
return_value=False
|
|
)
|
|
|
|
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
|
|
# Re-root checkpoint storage at our isolated tmp_path so we can detect
|
|
# any escape attempt on the filesystem.
|
|
root = tmp_path / "root"
|
|
root.mkdir()
|
|
server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage]
|
|
|
|
payload: dict[str, Any] = {"model": "m", "input": "hi"}
|
|
if context_field == "previous_response_id":
|
|
payload["previous_response_id"] = bad_id
|
|
else: # conversation
|
|
payload["conversation"] = bad_id
|
|
|
|
before = sorted(p.name for p in tmp_path.iterdir())
|
|
transport = httpx.ASGITransport(app=server)
|
|
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
|
resp = await client.post("/responses", json=payload)
|
|
after = sorted(p.name for p in tmp_path.iterdir())
|
|
|
|
# The request must not succeed; either request validation rejects it
|
|
# (4xx) or the checkpoint layer raises and the server returns 5xx.
|
|
# Either way, no successful response may be produced.
|
|
assert resp.status_code >= 400, (
|
|
f"Expected non-2xx for {context_field}={bad_id!r}, got {resp.status_code}: {resp.text[:200]}"
|
|
)
|
|
assert before == after, (
|
|
f"Unexpected filesystem artifacts under tmp_path for {context_field}={bad_id!r}: "
|
|
f"before={before} after={after}"
|
|
)
|
|
assert list(root.iterdir()) == [], f"Checkpoint directory created inside root for {context_field}={bad_id!r}"
|
|
|
|
|
|
# region Agent lifecycle (lazy entry & OAuth consent surfacing)
|
|
|
|
|
|
def _make_consent_error(url: str = "https://consent.example.com/auth") -> Exception:
|
|
"""Build an exception wrapping a Foundry MCP gateway consent error.
|
|
|
|
Mirrors the real-world wrapping produced by ``MCPStreamableHTTPTool.__aenter__``,
|
|
which catches connection-time ``McpError``s and re-raises them as a
|
|
``ToolExecutionException`` (an ``AgentFrameworkException`` subclass) with the
|
|
original error attached via ``inner_exception``. ``consent_url_from_error``
|
|
then finds the wrapped ``McpError`` in ``exc.args``.
|
|
"""
|
|
from agent_framework.exceptions import ToolExecutionException
|
|
|
|
inner = McpError(ErrorData(code=CONSENT_ERROR_CODE, message=url))
|
|
return ToolExecutionException("MCP consent required", inner_exception=inner)
|
|
|
|
|
|
class TestConsentUrlFromError:
|
|
def test_returns_consent_url_when_inner_arg_is_consent_mcp_error(self) -> None:
|
|
exc = _make_consent_error("https://example.com/consent")
|
|
assert consent_url_from_error(exc) == "https://example.com/consent"
|
|
|
|
def test_returns_none_when_no_mcp_error_in_args(self) -> None:
|
|
assert consent_url_from_error(Exception("boom")) is None
|
|
|
|
def test_returns_none_when_mcp_error_has_different_code(self) -> None:
|
|
inner = McpError(ErrorData(code=-32000, message="some other error"))
|
|
exc = Exception("wrapped", inner)
|
|
assert consent_url_from_error(exc) is None
|
|
|
|
def test_returns_none_for_bare_mcp_error_without_wrapping(self) -> None:
|
|
# `args` of a bare McpError holds the message string, not an McpError
|
|
# instance, so it does not match the wrapping pattern produced by the
|
|
# MCP client when it bubbles consent errors up.
|
|
bare = McpError(ErrorData(code=CONSENT_ERROR_CODE, message="https://x"))
|
|
assert consent_url_from_error(bare) is None
|
|
|
|
|
|
class TestAgentLifecycle:
|
|
async def test_agent_entered_lazily_on_first_request(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
server = _make_server(agent)
|
|
# Construction must not enter the agent.
|
|
assert agent.__aenter__.await_count == 0
|
|
|
|
await _post(server, input_text="hello", stream=False)
|
|
assert agent.__aenter__.await_count == 1
|
|
|
|
async def test_agent_entered_only_once_across_requests(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
await _post(server, input_text="first", stream=False)
|
|
await _post(server, input_text="second", stream=False)
|
|
await _post(server, input_text="third", stream=False)
|
|
assert agent.__aenter__.await_count == 1
|
|
|
|
async def test_cleanup_exits_agent_and_allows_reentry(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
server = _make_server(agent)
|
|
|
|
await _post(server, input_text="hello", stream=False)
|
|
assert agent.__aenter__.await_count == 1
|
|
assert agent.__aexit__.await_count == 0
|
|
|
|
await server._cleanup_agent() # pyright: ignore[reportPrivateUsage]
|
|
assert agent.__aexit__.await_count == 1
|
|
|
|
# Cleanup is idempotent.
|
|
await server._cleanup_agent() # pyright: ignore[reportPrivateUsage]
|
|
assert agent.__aexit__.await_count == 1
|
|
|
|
# After cleanup, a follow-up request re-enters the agent.
|
|
await _post(server, input_text="again", stream=False)
|
|
assert agent.__aenter__.await_count == 2
|
|
|
|
async def test_failed_entry_does_not_cache_stack(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
agent.__aenter__.side_effect = [_make_consent_error(), None]
|
|
server = _make_server(agent)
|
|
|
|
await _post(server, input_text="first", stream=False)
|
|
# Failed entry must leave the stack empty so the next request retries.
|
|
await _post(server, input_text="second", stream=False)
|
|
assert agent.__aenter__.await_count == 2
|
|
|
|
|
|
class TestOAuthConsentSurfacing:
|
|
async def test_non_streaming_consent_error_emits_oauth_output_item(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
agent.__aenter__.side_effect = _make_consent_error("https://consent.example.com/auth")
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post(server, input_text="hello", stream=False)
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "completed"
|
|
|
|
oauth_items = [it for it in body["output"] if it["type"] == "oauth_consent_request"]
|
|
assert len(oauth_items) == 1
|
|
assert oauth_items[0]["consent_link"] == "https://consent.example.com/auth"
|
|
assert oauth_items[0]["server_label"] == "Foundry Toolbox"
|
|
|
|
# The agent must not be run when entry fails.
|
|
agent.run.assert_not_called()
|
|
|
|
async def test_streaming_consent_error_emits_oauth_output_item(self) -> None:
|
|
agent = _make_agent(stream_updates=[AgentResponseUpdate(contents=[Content.from_text("hi")], role="assistant")])
|
|
agent.__aenter__.side_effect = _make_consent_error("https://consent.example.com/auth")
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post(server, input_text="hello", stream=True)
|
|
assert resp.status_code == 200
|
|
events = _parse_sse_events(resp.text)
|
|
types = _sse_event_types(events)
|
|
|
|
assert types[0] == "response.created"
|
|
assert types[1] == "response.in_progress"
|
|
assert types[-1] == "response.completed"
|
|
|
|
added = [e for e in events if e["event"] == "response.output_item.added"]
|
|
oauth_added = [e for e in added if e["data"]["item"]["type"] == "oauth_consent_request"]
|
|
assert len(oauth_added) == 1
|
|
assert oauth_added[0]["data"]["item"]["consent_link"] == "https://consent.example.com/auth"
|
|
assert oauth_added[0]["data"]["item"]["server_label"] == "Foundry Toolbox"
|
|
|
|
done = [e for e in events if e["event"] == "response.output_item.done"]
|
|
assert any(e["data"]["item"]["type"] == "oauth_consent_request" for e in done)
|
|
|
|
agent.run.assert_not_called()
|
|
|
|
async def test_non_consent_error_during_entry_propagates(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])])
|
|
)
|
|
agent.__aenter__.side_effect = RuntimeError("boom")
|
|
server = _make_server(agent)
|
|
|
|
resp = await _post(server, input_text="hello", stream=False)
|
|
# Non-consent errors are not swallowed: the response is marked failed
|
|
# and no `oauth_consent_request` item is emitted.
|
|
assert resp.status_code == 200
|
|
body = resp.json()
|
|
assert body["status"] == "failed"
|
|
assert not any(it["type"] == "oauth_consent_request" for it in body.get("output", []))
|
|
agent.run.assert_not_called()
|
|
|
|
async def test_retry_after_consent_succeeds(self) -> None:
|
|
agent = _make_agent(
|
|
response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hello!")])])
|
|
)
|
|
agent.__aenter__.side_effect = [_make_consent_error("https://consent.example.com/auth"), None]
|
|
server = _make_server(agent)
|
|
|
|
# First request surfaces consent; agent.run is not called.
|
|
resp1 = await _post(server, input_text="first", stream=False)
|
|
assert resp1.status_code == 200
|
|
body1 = resp1.json()
|
|
oauth = [it for it in body1["output"] if it["type"] == "oauth_consent_request"]
|
|
assert len(oauth) == 1
|
|
agent.run.assert_not_called()
|
|
|
|
# After the user authenticates, the next request enters successfully.
|
|
resp2 = await _post(server, input_text="second", stream=False)
|
|
assert resp2.status_code == 200
|
|
body2 = resp2.json()
|
|
assert body2["status"] == "completed"
|
|
assert any(it["type"] == "message" for it in body2["output"])
|
|
assert agent.__aenter__.await_count == 2
|
|
agent.run.assert_awaited_once()
|
|
|
|
|
|
# endregion
|