mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Add Foundry MCP OAuth consent handling
This commit is contained in:
committed by
eavanvalkenburg
Unverified
parent
f6d5c7977b
commit
b88541d5fb
@@ -12,6 +12,7 @@ import threading
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping, Sequence
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress
|
||||
from typing import Protocol, cast
|
||||
|
||||
from agent_framework import (
|
||||
@@ -31,6 +32,7 @@ from azure.ai.agentserver.responses import (
|
||||
ResponseProviderProtocol,
|
||||
ResponsesServerOptions,
|
||||
)
|
||||
from azure.ai.agentserver.responses._id_generator import IdGenerator
|
||||
from azure.ai.agentserver.responses.hosting import ResponsesAgentServerHost
|
||||
from azure.ai.agentserver.responses.models import (
|
||||
ApplyPatchToolCallItemParam,
|
||||
@@ -108,11 +110,13 @@ from azure.ai.agentserver.responses.streaming._builders import (
|
||||
ReasoningSummaryPartBuilder,
|
||||
TextContentBuilder,
|
||||
)
|
||||
from mcp import McpError
|
||||
from typing_extensions import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# region Approval Storage
|
||||
class ApprovalStorage(Protocol):
|
||||
"""Storage for saving function approval requests."""
|
||||
|
||||
@@ -247,6 +251,32 @@ def _checkpoint_storage_for_context(root: str, context_id: str) -> FileCheckpoin
|
||||
return FileCheckpointStorage(storage_path)
|
||||
|
||||
|
||||
# endregion Approval Storage
|
||||
|
||||
# Foundry Toolbox Auth integration
|
||||
# Consent-URL error code returned by the Foundry MCP gateway.
|
||||
CONSENT_ERROR_CODE = -32006
|
||||
|
||||
|
||||
def is_consent_error(exc: BaseException) -> str | None:
|
||||
"""Check if the exception is a consent error from the Foundry MCP gateway.
|
||||
|
||||
Args:
|
||||
exc: The exception to check.
|
||||
|
||||
Returns:
|
||||
The consent error message that is the URL if the exception is a consent error, otherwise None.
|
||||
"""
|
||||
inner_exception = next((arg for arg in exc.args if isinstance(arg, McpError)), None)
|
||||
if inner_exception is not None and inner_exception.error.code == CONSENT_ERROR_CODE:
|
||||
return inner_exception.error.message
|
||||
return None
|
||||
|
||||
|
||||
# endregion Foundry Toolbox Auth integration
|
||||
|
||||
|
||||
# region ResponsesHostServer
|
||||
class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
"""A responses server host for an agent."""
|
||||
|
||||
@@ -315,8 +345,43 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
if self.config.is_hosted
|
||||
else InMemoryFunctionApprovalStorage()
|
||||
)
|
||||
# Lazy agent lifecycle: the agent (and any MCP tools it owns) is entered on
|
||||
# the first request rather than at server startup, so that authentication
|
||||
# failures during MCP connect can be surfaced to the client as an
|
||||
# `oauth_consent_request` stream event instead of crashing the server.
|
||||
self._agent_stack: AsyncExitStack | None = None
|
||||
self._agent_init_lock = asyncio.Lock()
|
||||
self.shutdown_handler(self._cleanup_agent) # pyright: ignore[reportUnknownMemberType]
|
||||
self.response_handler(self._handle_response) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
async def _ensure_agent_ready(self) -> None:
|
||||
"""Lazily enter the agent's async context exactly once.
|
||||
|
||||
On failure the partial exit stack is closed and ``_agent_stack`` is left
|
||||
as ``None`` so a subsequent request (e.g. after the user completes OAuth
|
||||
consent) can retry the connection.
|
||||
"""
|
||||
if self._agent_stack is not None:
|
||||
return
|
||||
async with self._agent_init_lock:
|
||||
if self._agent_stack is not None:
|
||||
return
|
||||
stack = AsyncExitStack()
|
||||
try:
|
||||
if isinstance(self._agent, AbstractAsyncContextManager):
|
||||
await stack.enter_async_context(self._agent)
|
||||
except BaseException:
|
||||
await stack.aclose()
|
||||
raise
|
||||
self._agent_stack = stack
|
||||
|
||||
async def _cleanup_agent(self) -> None:
|
||||
"""Close the agent's async context. Registered as the server shutdown handler."""
|
||||
stack = self._agent_stack
|
||||
if stack is not None:
|
||||
self._agent_stack = None
|
||||
await stack.aclose()
|
||||
|
||||
async def _handle_response(
|
||||
self,
|
||||
request: CreateResponse,
|
||||
@@ -359,45 +424,72 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
else:
|
||||
run_kwargs["options"] = chat_options
|
||||
|
||||
if not is_streaming_request:
|
||||
# Run the agent in non-streaming mode
|
||||
response = await self._agent.run(stream=False, **run_kwargs) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
for message in response.messages:
|
||||
for content in message.contents:
|
||||
async for item in _to_outputs(
|
||||
response_event_stream,
|
||||
content,
|
||||
approval_storage=self._approval_storage,
|
||||
):
|
||||
yield item
|
||||
|
||||
yield response_event_stream.emit_completed()
|
||||
return
|
||||
# Lazy-enter the agent (and any MCP tools it owns). If this fails with an
|
||||
# auth/consent error, surface the consent link to the client through the
|
||||
# already-opened response stream instead of crashing the request.
|
||||
try:
|
||||
await self._ensure_agent_ready()
|
||||
except Exception as ex:
|
||||
if consent_url := is_consent_error(ex):
|
||||
logger.warning("OAuth consent required for Foundry MCP gateway.")
|
||||
oauth_item = OAuthConsentRequestOutputItem(
|
||||
id=IdGenerator.new_id("oacr", ""),
|
||||
consent_link=consent_url,
|
||||
server_label="Foundry Toolbox",
|
||||
)
|
||||
builder = response_event_stream.add_output_item(oauth_item.id)
|
||||
yield builder.emit_added(oauth_item)
|
||||
yield builder.emit_done(oauth_item)
|
||||
yield response_event_stream.emit_completed()
|
||||
return
|
||||
else:
|
||||
raise
|
||||
|
||||
# Track the current active output item builder for streaming;
|
||||
# lazily created on matching content, closed when a different type arrives.
|
||||
tracker = _OutputItemTracker(response_event_stream)
|
||||
tracker: _OutputItemTracker | None = _OutputItemTracker(response_event_stream) if is_streaming_request else None
|
||||
|
||||
# Run the agent in streaming mode
|
||||
async for update in self._agent.run(stream=True, **run_kwargs): # type: ignore[reportUnknownMemberType]
|
||||
for content in update.contents:
|
||||
for event in tracker.handle(content):
|
||||
try:
|
||||
if not is_streaming_request:
|
||||
# Run the agent in non-streaming mode
|
||||
response = await self._agent.run(stream=False, **run_kwargs) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
for message in response.messages:
|
||||
for content in message.contents:
|
||||
async for item in _to_outputs(
|
||||
response_event_stream,
|
||||
content,
|
||||
approval_storage=self._approval_storage,
|
||||
):
|
||||
yield item
|
||||
else:
|
||||
if tracker is None: # pragma: no cover - defensive, set above
|
||||
raise RuntimeError("Streaming tracker was not initialized.")
|
||||
# Run the agent in streaming mode
|
||||
async for update in self._agent.run(stream=True, **run_kwargs): # type: ignore[reportUnknownMemberType]
|
||||
for content in update.contents:
|
||||
for event in tracker.handle(content):
|
||||
yield event
|
||||
if tracker.needs_async:
|
||||
async for item in _to_outputs(
|
||||
response_event_stream,
|
||||
content,
|
||||
approval_storage=self._approval_storage,
|
||||
):
|
||||
yield item
|
||||
tracker.needs_async = False
|
||||
|
||||
# Close any remaining active builder
|
||||
for event in tracker.close():
|
||||
yield event
|
||||
if tracker.needs_async:
|
||||
async for item in _to_outputs(
|
||||
response_event_stream,
|
||||
content,
|
||||
approval_storage=self._approval_storage,
|
||||
):
|
||||
yield item
|
||||
tracker.needs_async = False
|
||||
|
||||
# Close any remaining active builder
|
||||
for event in tracker.close():
|
||||
yield event
|
||||
|
||||
yield response_event_stream.emit_completed()
|
||||
except Exception:
|
||||
# Drain any in-progress streaming builder before emitting consent
|
||||
# so the resulting stream stays well-formed.
|
||||
if tracker is not None:
|
||||
for event in tracker.close():
|
||||
yield event
|
||||
yield response_event_stream.emit_completed()
|
||||
raise
|
||||
|
||||
async def _handle_inner_workflow(
|
||||
self,
|
||||
@@ -429,6 +521,11 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
if not isinstance(self._agent, WorkflowAgent):
|
||||
raise RuntimeError("Agent is not a workflow agent.")
|
||||
|
||||
# Workflow agents are not async context managers in any built-in path,
|
||||
# but call _ensure_agent_ready for symmetry with the regular path so
|
||||
# any future async resources owned by the workflow are entered here.
|
||||
await self._ensure_agent_ready()
|
||||
|
||||
# Determine the latest checkpoint (if any) so we can resume the
|
||||
# workflow's prior state for this turn. The directory is keyed by
|
||||
# the inbound context id (conversation_id when set, otherwise
|
||||
@@ -551,6 +648,8 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
await checkpoint_storage.delete(checkpoint.checkpoint_id)
|
||||
|
||||
|
||||
# endregion ResponsesHostServer
|
||||
|
||||
# region Active Builder State
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user