Add Foundry MCP OAuth consent handling

This commit is contained in:
Tao Chen
2026-05-11 16:11:00 -07:00
committed by eavanvalkenburg
Unverified
parent f6d5c7977b
commit b88541d5fb
@@ -12,6 +12,7 @@ import threading
from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping, Sequence
from contextlib import suppress
from pathlib import Path
from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress
from typing import Protocol, cast
from agent_framework import (
@@ -31,6 +32,7 @@ from azure.ai.agentserver.responses import (
ResponseProviderProtocol,
ResponsesServerOptions,
)
from azure.ai.agentserver.responses._id_generator import IdGenerator
from azure.ai.agentserver.responses.hosting import ResponsesAgentServerHost
from azure.ai.agentserver.responses.models import (
ApplyPatchToolCallItemParam,
@@ -108,11 +110,13 @@ from azure.ai.agentserver.responses.streaming._builders import (
ReasoningSummaryPartBuilder,
TextContentBuilder,
)
from mcp import McpError
from typing_extensions import Any
logger = logging.getLogger(__name__)
# region Approval Storage
class ApprovalStorage(Protocol):
"""Storage for saving function approval requests."""
@@ -247,6 +251,32 @@ def _checkpoint_storage_for_context(root: str, context_id: str) -> FileCheckpoin
return FileCheckpointStorage(storage_path)
# endregion Approval Storage
# Foundry Toolbox Auth integration
# Consent-URL error code returned by the Foundry MCP gateway.
CONSENT_ERROR_CODE = -32006
def is_consent_error(exc: BaseException) -> str | None:
"""Check if the exception is a consent error from the Foundry MCP gateway.
Args:
exc: The exception to check.
Returns:
The consent error message that is the URL if the exception is a consent error, otherwise None.
"""
inner_exception = next((arg for arg in exc.args if isinstance(arg, McpError)), None)
if inner_exception is not None and inner_exception.error.code == CONSENT_ERROR_CODE:
return inner_exception.error.message
return None
# endregion Foundry Toolbox Auth integration
# region ResponsesHostServer
class ResponsesHostServer(ResponsesAgentServerHost):
"""A responses server host for an agent."""
@@ -315,8 +345,43 @@ class ResponsesHostServer(ResponsesAgentServerHost):
if self.config.is_hosted
else InMemoryFunctionApprovalStorage()
)
# Lazy agent lifecycle: the agent (and any MCP tools it owns) is entered on
# the first request rather than at server startup, so that authentication
# failures during MCP connect can be surfaced to the client as an
# `oauth_consent_request` stream event instead of crashing the server.
self._agent_stack: AsyncExitStack | None = None
self._agent_init_lock = asyncio.Lock()
self.shutdown_handler(self._cleanup_agent) # pyright: ignore[reportUnknownMemberType]
self.response_handler(self._handle_response) # pyright: ignore[reportUnknownMemberType]
async def _ensure_agent_ready(self) -> None:
"""Lazily enter the agent's async context exactly once.
On failure the partial exit stack is closed and ``_agent_stack`` is left
as ``None`` so a subsequent request (e.g. after the user completes OAuth
consent) can retry the connection.
"""
if self._agent_stack is not None:
return
async with self._agent_init_lock:
if self._agent_stack is not None:
return
stack = AsyncExitStack()
try:
if isinstance(self._agent, AbstractAsyncContextManager):
await stack.enter_async_context(self._agent)
except BaseException:
await stack.aclose()
raise
self._agent_stack = stack
async def _cleanup_agent(self) -> None:
"""Close the agent's async context. Registered as the server shutdown handler."""
stack = self._agent_stack
if stack is not None:
self._agent_stack = None
await stack.aclose()
async def _handle_response(
self,
request: CreateResponse,
@@ -359,45 +424,72 @@ class ResponsesHostServer(ResponsesAgentServerHost):
else:
run_kwargs["options"] = chat_options
if not is_streaming_request:
# Run the agent in non-streaming mode
response = await self._agent.run(stream=False, **run_kwargs) # type: ignore[reportUnknownMemberType]
for message in response.messages:
for content in message.contents:
async for item in _to_outputs(
response_event_stream,
content,
approval_storage=self._approval_storage,
):
yield item
yield response_event_stream.emit_completed()
return
# Lazy-enter the agent (and any MCP tools it owns). If this fails with an
# auth/consent error, surface the consent link to the client through the
# already-opened response stream instead of crashing the request.
try:
await self._ensure_agent_ready()
except Exception as ex:
if consent_url := is_consent_error(ex):
logger.warning("OAuth consent required for Foundry MCP gateway.")
oauth_item = OAuthConsentRequestOutputItem(
id=IdGenerator.new_id("oacr", ""),
consent_link=consent_url,
server_label="Foundry Toolbox",
)
builder = response_event_stream.add_output_item(oauth_item.id)
yield builder.emit_added(oauth_item)
yield builder.emit_done(oauth_item)
yield response_event_stream.emit_completed()
return
else:
raise
# Track the current active output item builder for streaming;
# lazily created on matching content, closed when a different type arrives.
tracker = _OutputItemTracker(response_event_stream)
tracker: _OutputItemTracker | None = _OutputItemTracker(response_event_stream) if is_streaming_request else None
# Run the agent in streaming mode
async for update in self._agent.run(stream=True, **run_kwargs): # type: ignore[reportUnknownMemberType]
for content in update.contents:
for event in tracker.handle(content):
try:
if not is_streaming_request:
# Run the agent in non-streaming mode
response = await self._agent.run(stream=False, **run_kwargs) # type: ignore[reportUnknownMemberType]
for message in response.messages:
for content in message.contents:
async for item in _to_outputs(
response_event_stream,
content,
approval_storage=self._approval_storage,
):
yield item
else:
if tracker is None: # pragma: no cover - defensive, set above
raise RuntimeError("Streaming tracker was not initialized.")
# Run the agent in streaming mode
async for update in self._agent.run(stream=True, **run_kwargs): # type: ignore[reportUnknownMemberType]
for content in update.contents:
for event in tracker.handle(content):
yield event
if tracker.needs_async:
async for item in _to_outputs(
response_event_stream,
content,
approval_storage=self._approval_storage,
):
yield item
tracker.needs_async = False
# Close any remaining active builder
for event in tracker.close():
yield event
if tracker.needs_async:
async for item in _to_outputs(
response_event_stream,
content,
approval_storage=self._approval_storage,
):
yield item
tracker.needs_async = False
# Close any remaining active builder
for event in tracker.close():
yield event
yield response_event_stream.emit_completed()
except Exception:
# Drain any in-progress streaming builder before emitting consent
# so the resulting stream stays well-formed.
if tracker is not None:
for event in tracker.close():
yield event
yield response_event_stream.emit_completed()
raise
async def _handle_inner_workflow(
self,
@@ -429,6 +521,11 @@ class ResponsesHostServer(ResponsesAgentServerHost):
if not isinstance(self._agent, WorkflowAgent):
raise RuntimeError("Agent is not a workflow agent.")
# Workflow agents are not async context managers in any built-in path,
# but call _ensure_agent_ready for symmetry with the regular path so
# any future async resources owned by the workflow are entered here.
await self._ensure_agent_ready()
# Determine the latest checkpoint (if any) so we can resume the
# workflow's prior state for this turn. The directory is keyed by
# the inbound context id (conversation_id when set, otherwise
@@ -551,6 +648,8 @@ class ResponsesHostServer(ResponsesAgentServerHost):
await checkpoint_storage.delete(checkpoint.checkpoint_id)
# endregion ResponsesHostServer
# region Active Builder State