Mirror of https://github.com/microsoft/agent-framework.git (synced 2026-06-16 21:04:09 +08:00)
Python: add agent-framework-hosting-entra identity-link helpers (#5644)
* feat(hosting-entra): add Entra (Azure AD) identity-linking channel

New ``agent-framework-hosting-entra`` package implementing a Microsoft Entra
OAuth-based identity-linking channel for the Hosting framework. Mounts a small
set of routes (``/entra/login``, ``/entra/callback``, ``/entra/whoami``) that
walk a user through an Entra/Azure AD authorization-code flow and stick the
resulting verified identity (``oid`` / ``email`` / ``tid``) onto the host's
identity table so later requests on any other channel (Responses, Telegram, …)
can be linked to the same user.

Surface (re-exported from ``agent_framework_hosting_entra``):

- ``EntraChannel`` -- concrete ``Channel`` implementation. Owns the three
  Starlette routes, signs/verifies short-lived ``state`` tokens to bind the
  round-trip to the originating channel, exchanges the authorization code for
  an ID token via MSAL, and writes the verified identity into the host's
  identity store via the standard ``ChannelIdentity`` plumbing so cross-channel
  push (e.g. send a Telegram message to the user who completed the link from
  Responses) works without the channels having to coordinate directly.
- 14 unit tests covering route wiring, ``state`` issue / verify, callback
  exchange happy + failure paths, and identity-store write.

Registers the package in ``python/pyproject.toml`` ``[tool.uv.sources]`` and
adds the matching pyright ``executionEnvironments`` entry.

Stacks on PR-2 (Hosting core); independent of PR-3 / PR-4 / PR-6. The
cross-channel sample (``local_identity_link/``) that demonstrates this
end-to-end alongside Responses + Telegram lands in PR-8 (samples).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix(hosting-entra): close IDOR + reflected-XSS + open-redirect on the OAuth flow

Three SECURITY-CRITICAL fixes flagged in round-2 review.

1. IDOR on /auth/start (3198518308). Without authentication the endpoint
   accepted (channel, channel_id) from the query string and bound *whoever
   signed in* to that pair. An attacker could bind their own Entra oid to a
   victim's per-channel id (e.g. `telegram:<victim_chat_id>`), redirecting all
   of the victim's future inbound traffic to the attacker's isolation key.

   Fix: introduce link_token_secret + mint_start_url(channel, id, ...). When
   set, /auth/start requires `exp` + `sig` (HMAC-SHA256 over
   `channel|channel_id|expires_at`) before issuing the redirect. Channels that
   hand out start URLs (a Telegram /link command after verifying the inbound
   webhook signature) call mint_start_url so the token proves the (channel, id)
   pair was authorised by the channel that owns the surface. Unsigned mode is
   opt-in and logs a loud WARNING at startup *and* on every accepted request.

2. Reflected XSS on /auth/callback (3198520256, 3198527896). `error`,
   `error_description`, channel_key (from the unauthenticated /start query),
   and `upn` (from a Graph response) flowed straight into the text/html
   response body unescaped. With the IDOR above, an attacker could stash
   `<script>` payloads in `channel` or `id` and serve them from the auth host's
   origin (full XSS on the auth surface: cookies/storage of anything else
   mounted there).

   Fix: html.escape() every value before HTML output.

3. Open redirect on `return_to` (3198524746). Accepted any URL.

   Fix: `_validate_return_to` allows only relative paths starting with `/`
   (and not `//`) or absolute URLs whose host equals the configured
   `public_base_url` host. Validated at /start mint time AND defensively
   re-validated at /callback before redirect.

12 new tests cover signed-token rejection (missing/forged/expired), mint helper
requirements, startup warning visibility, XSS escaping on both error and
success paths, and the open-redirect allowlist (external rejected, relative
accepted, same-origin accepted, protocol-relative `//evil.example/` rejected).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* test(hosting): drop redundant @pytest.mark.asyncio decorators

asyncio_mode = "auto" is configured in pyproject.toml across the hosting
packages, so individual @pytest.mark.asyncio decorators are unnecessary.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Committed by GitHub (unverified)
Parent: cdea9fa956
Commit: fe89da15b6
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
@@ -0,0 +1,39 @@
# agent-framework-hosting-entra

Microsoft Entra (Azure AD) identity-linking sidecar channel for
[agent-framework-hosting](../hosting). Owns the OAuth 2.0 Authorization Code
flow that binds a per-channel id (e.g. a Telegram chat id) to the user's
Entra object id, so multiple non-Entra channels can share a single
`entra:<oid>` isolation key.

## Usage

```python
from pathlib import Path
from agent_framework_hosting import AgentFrameworkHost
from agent_framework_hosting_entra import (
    EntraIdentityLinkChannel,
    EntraIdentityStore,
)

store = EntraIdentityStore(Path("./identity_links.json"))

host = AgentFrameworkHost(
    target=my_agent,
    channels=[
        EntraIdentityLinkChannel(
            store=store,
            tenant_id="<tenant id>",
            client_id="<entra app id>",
            client_secret="<entra app secret>",
            public_base_url="https://your.host",
        ),
        # ... other channels whose run hooks call store.lookup(...)
    ],
)
host.serve()
```

For tenants that disallow client secrets, pass `certificate_path=` (and
optionally `certificate_password=`) instead of `client_secret`. The PEM
layout matches the one used by `agent-framework-hosting-teams`.
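
The usage example mentions other channels whose run hooks call `store.lookup(...)`. A rough sketch of that resolution step follows. `InMemoryIdentityStore` and `resolve_isolation_key` are hypothetical names made up for this illustration; the stand-in store only mimics the `lookup(channel_key) -> str | None` shape so the snippet runs without the package installed, and how a real channel wires this into its run hook is not shown here.

```python
from __future__ import annotations


class InMemoryIdentityStore:
    """Hypothetical stand-in with the same ``lookup`` shape as ``EntraIdentityStore``."""

    def __init__(self, links: dict[str, str] | None = None) -> None:
        self._links = dict(links or {})

    def lookup(self, channel_key: str) -> str | None:
        # Returns the linked ``entra:<oid>`` key, or None when no link exists.
        return self._links.get(channel_key)


def resolve_isolation_key(store: InMemoryIdentityStore, channel: str, channel_id: str) -> str:
    """Prefer the linked Entra key; fall back to the channel-native key."""
    channel_key = f"{channel}:{channel_id}"
    return store.lookup(channel_key) or channel_key


store = InMemoryIdentityStore({"telegram:123": "entra:00000000-0000-0000-0000-000000000001"})

linked = resolve_isolation_key(store, "telegram", "123")
unlinked = resolve_isolation_key(store, "telegram", "456")
```

Falling back to the channel-native key keeps unlinked users working; whether that fallback is appropriate depends on how strictly a deployment wants to require a completed link.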
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft. All rights reserved.

"""Microsoft Entra (Azure AD) identity channel for :mod:`agent_framework_hosting`."""

from ._channel import (
    EntraIdentityLinkChannel,
    EntraIdentityStore,
    entra_isolation_key,
)

__all__ = [
    "EntraIdentityLinkChannel",
    "EntraIdentityStore",
    "entra_isolation_key",
]
@@ -0,0 +1,505 @@
# Copyright (c) Microsoft. All rights reserved.

"""Microsoft Entra (Azure AD) identity-linking sidecar channel.

Implements the OAuth 2.0 Authorization Code flow against Entra so users on
non-Entra channels (Telegram, Responses callers without a verified token,
etc.) can bind their per-channel id to a stable ``entra:<oid>`` isolation
key. Once the link is established, channel run-hooks can call
:meth:`EntraIdentityStore.lookup` and rewrite the request to use the Entra
key instead of the channel-native id.

Two credential modes are supported:

* ``client_secret``: confidential-client secret.
* ``certificate_path``: PEM bundle (private key + cert) for tenants that
  disallow secrets. The Teams channel uses the same PEM layout; see
  :mod:`agent_framework_hosting_teams` for the openssl recipe.
"""

from __future__ import annotations

import asyncio
import hashlib
import hmac
import html
import json
import secrets
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib.parse import urlencode, urlparse

import httpx
import msal
from agent_framework_hosting import (
    ChannelContext,
    ChannelContribution,
    logger,
)
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse, Response
from starlette.routing import Route


def entra_isolation_key(oid: str) -> str:
    """Canonical isolation key for a user identified by Entra object id."""
    return f"entra:{oid}"


class EntraIdentityStore:
    """Tiny JSON-backed mapping ``<channel>:<channel_id> → entra:<oid>``.

    Production deployments should swap this for a real KV store. Single-file
    JSON is fine for samples because writes are infrequent (only during the
    OAuth callback) and we serialize them under an asyncio lock.
    """

    def __init__(self, path: Path) -> None:
        """Open an identity store backed by ``path``.

        Loads any existing JSON document; an unreadable or corrupt file is
        logged and replaced with an empty in-memory map so callers always
        get a usable store.
        """
        self._path = path
        self._lock = asyncio.Lock()
        self._data: dict[str, str] = {}
        if path.exists():
            try:
                self._data = json.loads(path.read_text())
            except Exception:
                logger.exception("identity store load failed; starting empty")

    def lookup(self, channel_key: str) -> str | None:
        """Return the linked ``entra:<oid>`` key for a per-channel id, or ``None``."""
        return self._data.get(channel_key)

    async def link(self, channel_key: str, oid: str) -> None:
        """Bind ``channel_key`` (e.g. ``telegram:123``) to the Entra ``oid`` and persist.

        Overwrites any existing mapping for ``channel_key`` and rewrites the
        backing JSON file under the lock so concurrent callers cannot race.
        """
        async with self._lock:
            self._data[channel_key] = entra_isolation_key(oid)
            self._path.write_text(json.dumps(self._data, indent=2, sort_keys=True))

    async def unlink(self, channel_key: str) -> None:
        """Remove the mapping for ``channel_key``; no-op if absent.

        The file is only rewritten when an entry actually existed so we
        don't churn disk on idempotent unlink calls.
        """
        async with self._lock:
            if self._data.pop(channel_key, None) is not None:
                self._path.write_text(json.dumps(self._data, indent=2, sort_keys=True))


@dataclass
class _PendingAuth:
    """In-memory record of an authorize redirect waiting for its OAuth callback."""

    channel: str
    channel_id: str
    expires_at: float
    return_to: str | None = None


def _link_html(body: str, *, status: int = 200) -> HTMLResponse:
    """Wrap ``body`` in a minimal HTML shell suitable for browser link UIs."""
    return HTMLResponse(
        f"<!doctype html><html><body style='font-family:system-ui;padding:2rem;max-width:40rem'>{body}</body></html>",
        status_code=status,
    )


def _load_certificate_credential(certificate_path: str | Path, certificate_password: bytes | None) -> dict[str, str]:
    """Build the ``msal`` certificate credential dict from a PEM bundle.

    Expects ``certificate_path`` to point at a single PEM containing the
    private key followed by the X.509 certificate (the layout produced by
    ``cat key.pem cert.pem > combined.pem``).
    """
    pem_bytes = Path(certificate_path).read_bytes()
    private_key = serialization.load_pem_private_key(pem_bytes, password=certificate_password)
    cert = x509.load_pem_x509_certificate(pem_bytes)

    private_key_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode()
    public_cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode()
    # SHA-1 thumbprint is required by the Entra ``client_assertion`` spec for cert auth; it is not a security choice.
    thumbprint = cert.fingerprint(hashes.SHA1()).hex()  # noqa: S303
    return {
        "private_key": private_key_pem,
        "thumbprint": thumbprint,
        "public_certificate": public_cert_pem,
    }


class EntraIdentityLinkChannel:
    """Sidecar Channel exposing ``GET /auth/start`` and ``GET /auth/callback``.

    Demonstrates that ``Channel`` is a general extensibility point, not just
    for chat surfaces. Owns the Entra OAuth Authorization Code flow used to
    bind a per-channel id (e.g. Telegram chat id) to the user's Entra object
    id.

    Two credential modes are supported (mutually exclusive):

    * ``client_secret``: classic confidential-client secret.
    * ``certificate_path``: PEM bundle (private key + certificate) for
      tenants that disallow secrets. See :mod:`agent_framework_hosting_teams`
      for an ``openssl`` recipe; the same PEM works here.

    Flow (OAuth 2.0 Authorization Code, confidential client):

    1. ``GET /auth/start?channel=<name>&id=<channel_id>`` mints a one-shot
       ``state`` token and 302s to the Entra ``authorize`` endpoint.
    2. User signs in; Entra calls ``GET /auth/callback?code=...&state=...``.
    3. We exchange the code for a token (via ``msal`` so secret + cert auth
       look identical at the call site), call Microsoft Graph ``/me`` to
       read ``id`` (oid), persist ``<channel>:<id> → entra:<oid>``, and
       respond with a friendly HTML page (or 302 to ``return_to``).

    Tokens never leave the host process; only the ``oid`` claim is stored.
    """

    name = "identity"
    path = "/auth"

    _AUTHORITY_TEMPLATE = "https://login.microsoftonline.com/{tenant}"
    _GRAPH_ME = "https://graph.microsoft.com/v1.0/me"
    _PENDING_TTL_SECONDS = 600  # 10 minutes

    def __init__(
        self,
        *,
        store: EntraIdentityStore,
        tenant_id: str,
        client_id: str,
        public_base_url: str,
        client_secret: str | None = None,
        certificate_path: str | Path | None = None,
        certificate_password: bytes | None = None,
        scope: str = "openid profile User.Read",
        link_token_secret: str | None = None,
        link_token_ttl_seconds: int = 600,
    ) -> None:
        if bool(client_secret) == bool(certificate_path):
            raise ValueError("EntraIdentityLinkChannel: pass exactly one of client_secret or certificate_path.")
        if certificate_path is not None:
            credential: str | dict[str, str] = _load_certificate_credential(certificate_path, certificate_password)
            self._auth_kind = "certificate"
        else:
            credential = client_secret  # type: ignore[assignment]
            self._auth_kind = "client_secret"

        self._store = store
        self._tenant_id = tenant_id
        self._client_id = client_id
        self._public_base_url = public_base_url.rstrip("/")
        self._scopes = [s for s in scope.split() if s and s.lower() not in {"openid", "profile", "offline_access"}]
        # MSAL ConfidentialClientApplication is sync; we wrap blocking calls
        # in ``asyncio.to_thread`` because token endpoint calls do real I/O.
        self._msal_app = msal.ConfidentialClientApplication(
            client_id=client_id,
            authority=self._AUTHORITY_TEMPLATE.format(tenant=tenant_id),
            client_credential=credential,
        )
        self._pending: dict[str, _PendingAuth] = {}
        self._http: httpx.AsyncClient | None = None
        # ``link_token_secret`` is the HMAC key that gates ``/auth/start``.
        # Without it any open-internet caller can mint a binding for an
        # arbitrary ``(channel, channel_id)`` pair and IDOR the victim's
        # isolation key (see PR review on 0026 for the threat model).
        # Optional only so dev-mode samples without the integration in
        # place don't have to scramble for a secret; unsigned mode logs
        # a loud warning at startup and wire-time.
        self._link_token_secret = link_token_secret.encode("utf-8") if link_token_secret else None
        self._link_token_ttl = link_token_ttl_seconds
        # Allowed redirect-back hosts: relative paths and same-origin only.
        # ``return_to`` from the unauthenticated /start query string is
        # otherwise an open redirect (auth-host phishing vector).
        parsed = urlparse(self._public_base_url)
        self._allowed_return_host = parsed.netloc.lower() if parsed.netloc else None

    @property
    def redirect_uri(self) -> str:
        """The fully-qualified OAuth redirect URI registered with Entra ID.

        Computed from ``public_base_url`` plus the channel's mount path so
        operators can copy it straight into the app registration's reply URLs.
        """
        return f"{self._public_base_url}{self.path}/callback"

    def contribute(self, context: "ChannelContext") -> "ChannelContribution":
        """Mount the ``/start`` and ``/callback`` routes plus lifecycle hooks."""
        return ChannelContribution(
            routes=[
                Route("/start", self._handle_start, methods=["GET"]),
                Route("/callback", self._handle_callback, methods=["GET"]),
            ],
            on_startup=[self._on_startup],
            on_shutdown=[self._on_shutdown],
        )

    async def _on_startup(self) -> None:
        """Open the shared HTTP client used for Microsoft Graph calls."""
        self._http = httpx.AsyncClient(timeout=15.0)
        if self._link_token_secret is None:
            logger.warning(
                "EntraIdentityLinkChannel running WITHOUT link_token_secret. "
                "GET /auth/start accepts unauthenticated (channel, id) pairs, "
                "which means any open-internet caller can bind their Entra "
                "account to a victim's per-channel id (IDOR on the identity "
                "store). Pass link_token_secret=<random>, mint URLs via "
                "mint_start_url(...), and gate /start in front of the "
                "channel that issues those URLs."
            )
        logger.info(
            "EntraIdentityLinkChannel ready (auth=%s, signed_start=%s); redirect_uri=%s",
            self._auth_kind,
            self._link_token_secret is not None,
            self.redirect_uri,
        )

    async def _on_shutdown(self) -> None:
        """Close the Graph HTTP client; safe to call when never started."""
        if self._http is not None:
            await self._http.aclose()

    # -- link-token helpers ----------------------------------------------- #

    def _sign_link_token(self, channel: str, channel_id: str, expires_at: int) -> str:
        """Sign ``(channel, channel_id, expires_at)`` with HMAC-SHA256."""
        if self._link_token_secret is None:  # pragma: no cover - guarded by callers
            raise RuntimeError("link_token_secret is required to mint link tokens")
        msg = f"{channel}|{channel_id}|{expires_at}".encode()
        return hmac.new(self._link_token_secret, msg, hashlib.sha256).hexdigest()

    def _verify_link_token(self, channel: str, channel_id: str, expires_at: int, signature: str) -> bool:
        """Constant-time verify the link-token signature and TTL."""
        if self._link_token_secret is None:  # pragma: no cover - guarded by callers
            return False
        if expires_at < int(time.time()):
            return False
        expected = self._sign_link_token(channel, channel_id, expires_at)
        # Compare as bytes: ``hmac.compare_digest`` raises ``TypeError`` on
        # non-ASCII ``str`` input, and ``signature`` comes from the query string.
        return hmac.compare_digest(expected.encode(), signature.encode())

    def mint_start_url(self, channel: str, channel_id: str, return_to: str | None = None) -> str:
        """Return a short-lived signed URL for ``GET /auth/start``.

        Required when ``link_token_secret`` is set. Channels that issue
        these URLs (e.g. a Telegram ``/link`` command after verifying the
        inbound webhook signature) call this helper so the resulting URL
        proves the caller authorised the ``(channel, channel_id)`` binding.
        The URL stays valid until ``link_token_ttl_seconds`` has elapsed.

        Without this layer ``GET /auth/start`` is an IDOR vector: any
        anonymous caller can bind a victim's per-channel id to their own
        Entra ``oid``.
        """
        if self._link_token_secret is None:
            raise RuntimeError("mint_start_url requires link_token_secret in the constructor")
        if return_to is not None:
            self._validate_return_to(return_to)  # fail fast at mint time
        expires_at = int(time.time()) + self._link_token_ttl
        sig = self._sign_link_token(channel, str(channel_id), expires_at)
        params = {
            "channel": channel,
            "id": str(channel_id),
            "exp": str(expires_at),
            "sig": sig,
        }
        if return_to:
            params["return_to"] = return_to
        return f"{self._public_base_url}{self.path}/start?{urlencode(params)}"

    def _validate_return_to(self, return_to: str) -> None:
        """Reject open-redirect targets.

        Allows: relative paths starting with ``/`` (but not ``//``), or
        absolute URLs whose host equals the configured ``public_base_url``'s
        host. Rejects everything else with ``ValueError``.
        """
        if return_to.startswith("/") and not return_to.startswith("//"):
            return  # relative path, safe.
        parsed = urlparse(return_to)
        # No early return for host-less values: a browser may resolve input
        # such as ``http:evil.example`` to an absolute external URL, so these
        # must fall through to the ValueError below.
        if self._allowed_return_host and parsed.netloc.lower() == self._allowed_return_host:
            return
        raise ValueError(
            f"return_to must be a relative path or same-origin URL "
            f"(public_base_url host={self._allowed_return_host!r}); got {return_to!r}"
        )

    def authorize_url_for(self, channel: str, channel_id: str, return_to: str | None = None) -> str:
        """Mint a one-shot authorize URL the user can visit to bind their account."""
        state = secrets.token_urlsafe(24)
        self._gc_pending()
        self._pending[state] = _PendingAuth(
            channel=channel,
            channel_id=str(channel_id),
            expires_at=time.monotonic() + self._PENDING_TTL_SECONDS,
            return_to=return_to,
        )
        return str(
            self._msal_app.get_authorization_request_url(
                scopes=self._scopes,
                redirect_uri=self.redirect_uri,
                state=state,
                prompt="select_account",
            )
        )

    def _gc_pending(self) -> None:
        """Drop expired pending-auth entries so the in-memory map cannot grow unbounded."""
        now = time.monotonic()
        for key, entry in list(self._pending.items()):
            if entry.expires_at < now:
                self._pending.pop(key, None)

    async def _handle_start(self, request: Request) -> Response:
        """``GET /start?channel=&id=&return_to=&exp=&sig=``: redirect to Entra to sign in.

        **Security model.** When ``link_token_secret`` is set the
        request must include ``exp`` + ``sig``, an HMAC over
        ``(channel, channel_id, expires_at)`` minted by
        :meth:`mint_start_url`. Without that gate, any open-internet
        caller can bind a victim's per-channel id (e.g.
        ``telegram:<victim_chat>``) to their own Entra ``oid``: the
        callback would persist
        ``"telegram:<victim>" -> "entra:<attacker_oid>"`` and any
        future inbound message from the victim would resolve to the
        attacker's isolation key. We make the unsigned mode opt-in
        with a loud startup warning so the dev-mode default doesn't
        ship to production.

        ``return_to`` is validated against the configured
        ``public_base_url`` host (or restricted to relative paths) to
        prevent open-redirect phishing on a successful sign-in.
        """
        channel = request.query_params.get("channel")
        channel_id = request.query_params.get("id")
        return_to = request.query_params.get("return_to")
        if not channel or not channel_id:
            return _link_html("Missing 'channel' or 'id' query parameter.", status=400)

        if self._link_token_secret is not None:
            sig = request.query_params.get("sig")
            exp_raw = request.query_params.get("exp")
            try:
                exp = int(exp_raw) if exp_raw else 0
            except ValueError:
                exp = 0
            if not sig or not exp or not self._verify_link_token(channel, channel_id, exp, sig):
                logger.warning(
                    "EntraIdentityLinkChannel /start rejected: missing/invalid signed link-token (channel=%s, id=%s)",
                    channel,
                    channel_id,
                )
                return _link_html("Invalid or expired sign-in link.", status=403)
        else:
            # See _on_startup warning. Logged on every wire access so
            # operators can't miss the IDOR exposure in their access logs.
            logger.warning(
                "EntraIdentityLinkChannel /start accepted UNSIGNED request "
                "for (channel=%s, id=%s); set link_token_secret to require "
                "HMAC-signed link tokens minted via mint_start_url().",
                channel,
                channel_id,
            )
        if return_to is not None:
            try:
                self._validate_return_to(return_to)
            except ValueError as exc:
                logger.warning("EntraIdentityLinkChannel /start invalid return_to: %s", exc)
                return _link_html("Invalid return_to URL.", status=400)
        url = self.authorize_url_for(channel, channel_id, return_to=return_to)
        return RedirectResponse(url, status_code=302)

    async def _handle_callback(self, request: Request) -> Response:
        """``GET /callback``: finish the OAuth flow and persist the link.

        Exchanges the authorization code for a token, reads the user's
        ``id``/``userPrincipalName`` from Microsoft Graph, then stores the
        ``channel:channel_id -> entra:<oid>`` mapping in the identity store.
        Renders a small HTML page so a browser-based flow has something to
        show; if ``return_to`` was supplied (and validated at /start time
        against the same-origin allowlist) the response is a 302 redirect
        to it instead.

        All values that flow into HTML output (``error``, ``error_description``,
        ``channel_key``, ``upn``) are passed through :func:`html.escape` to
        avoid reflected XSS: both the OAuth-error path and the
        sign-in-success body would otherwise execute attacker-controlled
        markup on the auth host's origin.
        """
        if self._http is None:  # pragma: no cover - guarded by lifecycle
            raise RuntimeError("entra identity channel not started")
        if error := request.query_params.get("error"):
            description = request.query_params.get("error_description", "")
            return _link_html(
                f"Sign-in failed: {html.escape(error)}<br>{html.escape(description)}",
                status=400,
            )

        code = request.query_params.get("code")
        state = request.query_params.get("state")
        pending = self._pending.pop(state or "", None)
        if not code or pending is None or pending.expires_at < time.monotonic():
            return _link_html("Invalid or expired sign-in state. Please retry.", status=400)

        # MSAL handles client_secret vs client_assertion (cert) under the hood.
        result: dict[str, Any] = await asyncio.to_thread(
            self._msal_app.acquire_token_by_authorization_code,
            code,
            scopes=self._scopes,
            redirect_uri=self.redirect_uri,
        )
        if "access_token" not in result:
            logger.warning("Entra token exchange failed: %s", result)
            err_text = result.get("error_description") or result.get("error") or "unknown error"
            return _link_html(
                f"Token exchange failed: {html.escape(str(err_text))}",
                status=502,
            )
        access_token = result["access_token"]

        me = await self._http.get(self._GRAPH_ME, headers={"Authorization": f"Bearer {access_token}"})
        if me.status_code != 200:
            return _link_html("Could not read user profile from Microsoft Graph.", status=502)
        profile = me.json()
        oid = profile.get("id")
        upn = profile.get("userPrincipalName") or profile.get("displayName") or oid
        if not oid:
            return _link_html("Profile response missing 'id'.", status=502)

        channel_key = f"{pending.channel}:{pending.channel_id}"
        await self._store.link(channel_key, oid)
        logger.info("Linked %s → entra:%s (%s)", channel_key, oid, upn)

        if pending.return_to:
            # ``return_to`` was already validated at /start time against
            # the allowlist (relative path or same-origin only). Re-check
            # defensively to harden against any future code path that
            # bypasses the /start gate.
            try:
                self._validate_return_to(pending.return_to)
                return RedirectResponse(pending.return_to, status_code=302)
            except ValueError:
                logger.warning(
                    "EntraIdentityLinkChannel /callback dropping invalid return_to: %s",
                    pending.return_to,
                )
        return _link_html(
            f"<h2>Linked</h2><p>{html.escape(channel_key)} is now bound to "
            f"<b>{html.escape(str(upn))}</b>.</p>"
            "<p>You can close this window and return to your chat.</p>"
        )
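
The `return_to` allowlist rule stated in the commit message (relative paths starting with `/` but not `//`, or absolute URLs whose host matches the `public_base_url` host) can be exercised in isolation. The sketch below is a standalone approximation, not the channel's method: `is_safe_return_to` and the sample targets are hypothetical, and it returns a boolean instead of raising `ValueError`.

```python
from urllib.parse import urlparse


def is_safe_return_to(return_to: str, allowed_host: str) -> bool:
    """True for same-site relative paths or URLs on ``allowed_host`` only."""
    # A single leading slash is a path on the current origin; a double
    # slash is protocol-relative and can point at any host.
    if return_to.startswith("/") and not return_to.startswith("//"):
        return True
    parsed = urlparse(return_to)
    # Host-less values (for example ``http:evil.example``) are rejected
    # rather than assumed to be relative.
    return bool(parsed.netloc) and parsed.netloc.lower() == allowed_host.lower()


ALLOWED_HOST = "your.host"  # host part of the configured public_base_url

results = {
    target: is_safe_return_to(target, ALLOWED_HOST)
    for target in (
        "/chat",
        "https://your.host/done",
        "HTTPS://Your.Host/done",
        "//evil.example/",
        "https://evil.example/phish",
        "http:evil.example",
    )
}
```

The comparison is on the full `netloc`, so a URL carrying an explicit port or userinfo does not match a bare host; that is stricter than a hostname-only check and may or may not be what a given deployment wants.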
@@ -0,0 +1,108 @@
[project]
name = "agent-framework-hosting-entra"
description = "Microsoft Entra (Azure AD) OAuth-based identity-linking channel for agent-framework-hosting."
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
readme = "README.md"
requires-python = ">=3.10"
version = "1.0.0a260424"
license-files = ["LICENSE"]
urls.homepage = "https://aka.ms/agent-framework"
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true"
urls.issues = "https://github.com/microsoft/agent-framework/issues"
classifiers = [
    "License :: OSI Approved :: MIT License",
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Developers",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
    "Programming Language :: Python :: 3.14",
    "Typing :: Typed",
]
dependencies = [
    "agent-framework-core>=1.2.0,<2",
    "agent-framework-hosting==1.0.0a260424",
    "httpx>=0.27,<1",
    "msal>=1.28,<2",
    "cryptography>=42",
]

[tool.uv]
prerelease = "if-necessary-or-explicit"
environments = [
    "sys_platform == 'darwin'",
    "sys_platform == 'linux'",
    "sys_platform == 'win32'"
]

[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"

[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
timeout = 120
markers = [
    "integration: marks tests as integration tests that require external services",
]

[tool.ruff]
extend = "../../pyproject.toml"

[tool.coverage.run]
omit = [
    "**/__init__.py"
]

[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
include = ["agent_framework_hosting_entra"]
|
||||
exclude = ['tests']
|
||||
# MSAL token results and Microsoft Graph responses arrive as loosely-typed
# JSON-ish maps. Strict ``Unknown`` reporting on every ``.get(...)`` adds noise
# without catching real bugs; narrowing happens via runtime isinstance checks instead.
reportUnknownArgumentType = "none"
reportUnknownMemberType = "none"
reportUnknownVariableType = "none"
reportUnknownLambdaType = "none"
reportOptionalMemberAccess = "none"

[tool.mypy]
plugins = ['pydantic.mypy']
strict = true
python_version = "3.10"
ignore_missing_imports = true
disallow_untyped_defs = true
no_implicit_optional = true
check_untyped_defs = true
warn_return_any = true
show_error_codes = true
warn_unused_ignores = false
disallow_incomplete_defs = true
disallow_untyped_decorators = true

[tool.bandit]
targets = ["agent_framework_hosting_entra"]
exclude_dirs = ["tests"]

[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"

[tool.poe.tasks.mypy]
help = "Run MyPy for this package."
cmd = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_hosting_entra"

[tool.poe.tasks.test]
help = "Run the default unit test suite for this package."
cmd = 'pytest -m "not integration" --cov=agent_framework_hosting_entra --cov-report=term-missing:skip-covered tests'

[build-system]
requires = ["flit-core >= 3.11,<4.0"]
build-backend = "flit_core.buildapi"
@@ -0,0 +1,464 @@
# Copyright (c) Microsoft. All rights reserved.

"""Unit tests for :mod:`agent_framework_hosting_entra`.

The MSAL ``ConfidentialClientApplication`` and Microsoft Graph calls are
mocked out so no network access is required. Live OAuth, certificate auth,
and full webhook flow are out of scope here.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from starlette.applications import Starlette
from starlette.testclient import TestClient

from agent_framework_hosting_entra import (
    EntraIdentityLinkChannel,
    EntraIdentityStore,
    entra_isolation_key,
)


def test_entra_isolation_key_format() -> None:
    assert entra_isolation_key("abc123") == "entra:abc123"


class TestEntraIdentityStore:
    async def test_link_writes_entra_namespaced_value(self, tmp_path: Path) -> None:
        store = EntraIdentityStore(tmp_path / "links.json")
        await store.link("telegram:42", "oid-xyz")
        assert store.lookup("telegram:42") == "entra:oid-xyz"
        # Persisted to disk.
        saved = json.loads((tmp_path / "links.json").read_text())
        assert saved == {"telegram:42": "entra:oid-xyz"}

    async def test_unlink_removes_entry(self, tmp_path: Path) -> None:
        store = EntraIdentityStore(tmp_path / "links.json")
        await store.link("telegram:42", "oid")
        await store.unlink("telegram:42")
        assert store.lookup("telegram:42") is None
        assert json.loads((tmp_path / "links.json").read_text()) == {}

    async def test_unlink_unknown_is_noop(self, tmp_path: Path) -> None:
        store = EntraIdentityStore(tmp_path / "links.json")
        await store.unlink("telegram:never")  # must not raise
        assert not (tmp_path / "links.json").exists()

    def test_loads_existing_file(self, tmp_path: Path) -> None:
        path = tmp_path / "links.json"
        path.write_text(json.dumps({"telegram:1": "entra:abc"}))
        store = EntraIdentityStore(path)
        assert store.lookup("telegram:1") == "entra:abc"

    def test_corrupt_file_starts_empty(self, tmp_path: Path) -> None:
        path = tmp_path / "links.json"
        path.write_text("not-json")
        store = EntraIdentityStore(path)
        assert store.lookup("anything") is None


class TestEntraIdentityLinkChannelConfig:
    def test_rejects_neither_credential(self, tmp_path: Path) -> None:
        with pytest.raises(ValueError, match="exactly one"):
            EntraIdentityLinkChannel(
                store=EntraIdentityStore(tmp_path / "x.json"),
                tenant_id="t",
                client_id="c",
                public_base_url="https://example.com",
            )

    def test_rejects_both_credentials(self, tmp_path: Path) -> None:
        with pytest.raises(ValueError, match="exactly one"):
            EntraIdentityLinkChannel(
                store=EntraIdentityStore(tmp_path / "x.json"),
                tenant_id="t",
                client_id="c",
                public_base_url="https://example.com",
                client_secret="s",
                certificate_path="/tmp/does-not-exist.pem",
            )

    def test_redirect_uri_strips_trailing_slash(self, tmp_path: Path) -> None:
        with patch(
            "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication",
            MagicMock(),
        ):
            ch = EntraIdentityLinkChannel(
                store=EntraIdentityStore(tmp_path / "x.json"),
                tenant_id="t",
                client_id="c",
                public_base_url="https://example.com/",
                client_secret="s",
            )
        assert ch.redirect_uri == "https://example.com/auth/callback"


class TestEntraIdentityLinkChannelRoutes:
    def _make_channel(self, tmp_path: Path, msal_app: MagicMock) -> tuple[EntraIdentityLinkChannel, EntraIdentityStore]:
        store = EntraIdentityStore(tmp_path / "links.json")
        with patch(
            "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication",
            return_value=msal_app,
        ):
            ch = EntraIdentityLinkChannel(
                store=store,
                tenant_id="tenant-1",
                client_id="client-1",
                public_base_url="https://example.com",
                client_secret="s",
            )
        return ch, store

    def _mount_app(self, ch: EntraIdentityLinkChannel) -> Starlette:
        # We don't depend on AgentFrameworkHost here — wire the routes
        # directly so we can exercise the channel in isolation.
        from starlette.routing import Mount

        contribution = ch.contribute(MagicMock())
        return Starlette(routes=[Mount(ch.path, routes=contribution.routes)])

    def test_start_missing_params_returns_400(self, tmp_path: Path) -> None:
        msal_app = MagicMock()
        ch, _ = self._make_channel(tmp_path, msal_app)
        with TestClient(self._mount_app(ch)) as client:
            r = client.get("/auth/start", follow_redirects=False)
        assert r.status_code == 400

    def test_start_redirects_to_authorize_url(self, tmp_path: Path) -> None:
        msal_app = MagicMock()
        msal_app.get_authorization_request_url.return_value = (
            "https://login.microsoftonline.com/tenant-1/oauth2/v2.0/authorize?state=X"
        )
        ch, _ = self._make_channel(tmp_path, msal_app)
        with TestClient(self._mount_app(ch)) as client:
            r = client.get(
                "/auth/start",
                params={"channel": "telegram", "id": "42"},
                follow_redirects=False,
            )
        assert r.status_code == 302
        assert "login.microsoftonline.com" in r.headers["location"]

    def test_callback_invalid_state_returns_400(self, tmp_path: Path) -> None:
        msal_app = MagicMock()
        ch, _ = self._make_channel(tmp_path, msal_app)
        ch._http = MagicMock(aclose=AsyncMock())
        with TestClient(self._mount_app(ch)) as client:
            r = client.get("/auth/callback", params={"code": "c", "state": "unknown"})
        assert r.status_code == 400

    def test_callback_links_oid_on_success(self, tmp_path: Path) -> None:
        msal_app = MagicMock()
        msal_app.get_authorization_request_url.return_value = (
            "https://login.microsoftonline.com/tenant-1/authorize?state=X"
        )
        msal_app.acquire_token_by_authorization_code.return_value = {"access_token": "t"}
        ch, store = self._make_channel(tmp_path, msal_app)

        # Fake the Graph /me call.
        graph_response = MagicMock()
        graph_response.status_code = 200
        graph_response.json = MagicMock(return_value={"id": "oid-xyz", "userPrincipalName": "user@x"})
        ch._http = MagicMock()
        ch._http.get = AsyncMock(return_value=graph_response)
        ch._http.aclose = AsyncMock()

        # Mint a real state via the public API so the pending dict is populated.
        ch.authorize_url_for("telegram", "42")
        state = next(iter(ch._pending.keys()))

        with TestClient(self._mount_app(ch)) as client:
            r = client.get("/auth/callback", params={"code": "abc", "state": state})
        assert r.status_code == 200
        assert store.lookup("telegram:42") == "entra:oid-xyz"

    def test_callback_token_failure_returns_502(self, tmp_path: Path) -> None:
        msal_app = MagicMock()
        msal_app.get_authorization_request_url.return_value = "https://x"
        msal_app.acquire_token_by_authorization_code.return_value = {
            "error": "invalid_grant",
            "error_description": "expired",
        }
        ch, store = self._make_channel(tmp_path, msal_app)
        ch._http = MagicMock(aclose=AsyncMock())
        ch.authorize_url_for("telegram", "42")
        state = next(iter(ch._pending.keys()))
        with TestClient(self._mount_app(ch)) as client:
            r = client.get("/auth/callback", params={"code": "c", "state": state})
        assert r.status_code == 502
        assert store.lookup("telegram:42") is None


# --------------------------------------------------------------------------- #
# Round-2 security hardening                                                  #
# --------------------------------------------------------------------------- #


class TestSignedLinkToken:
    """`/auth/start` must reject unsigned/forged requests when secret is set."""

    def _make_signed_channel(
        self, tmp_path: Path, msal_app: MagicMock, *, secret: str = "test-secret"
    ) -> EntraIdentityLinkChannel:
        store = EntraIdentityStore(tmp_path / "links.json")
        with patch(
            "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication",
            return_value=msal_app,
        ):
            return EntraIdentityLinkChannel(
                store=store,
                tenant_id="tenant-1",
                client_id="client-1",
                public_base_url="https://example.com",
                client_secret="s",
                link_token_secret=secret,
            )

    def _mount(self, ch: EntraIdentityLinkChannel) -> Starlette:
        from starlette.routing import Mount

        contribution = ch.contribute(MagicMock())
        return Starlette(routes=[Mount(ch.path, routes=contribution.routes)])

    def test_start_rejects_unsigned_request_when_secret_set(self, tmp_path: Path) -> None:
        msal_app = MagicMock()
        ch = self._make_signed_channel(tmp_path, msal_app)
        with TestClient(self._mount(ch)) as client:
            r = client.get(
                "/auth/start",
                params={"channel": "telegram", "id": "42"},
                follow_redirects=False,
            )
        assert r.status_code == 403

    def test_start_rejects_forged_signature(self, tmp_path: Path) -> None:
        msal_app = MagicMock()
        ch = self._make_signed_channel(tmp_path, msal_app)
        with TestClient(self._mount(ch)) as client:
            r = client.get(
                "/auth/start",
                params={
                    "channel": "telegram",
                    "id": "42",
                    "exp": "9999999999",
                    "sig": "deadbeef",
                },
                follow_redirects=False,
            )
        assert r.status_code == 403

    def test_start_accepts_valid_signed_url(self, tmp_path: Path) -> None:
        msal_app = MagicMock()
        msal_app.get_authorization_request_url.return_value = (
            "https://login.microsoftonline.com/tenant-1/authorize?state=X"
        )
        ch = self._make_signed_channel(tmp_path, msal_app)
        url = ch.mint_start_url("telegram", "42")
        # Strip the host prefix to call via the in-process client.
        path_and_query = url.split("https://example.com", 1)[1]
        with TestClient(self._mount(ch)) as client:
            r = client.get(path_and_query, follow_redirects=False)
        assert r.status_code == 302

    def test_start_rejects_expired_signed_url(self, tmp_path: Path) -> None:
        import time as time_module
        from urllib.parse import urlencode

        msal_app = MagicMock()
        ch = self._make_signed_channel(tmp_path, msal_app)
        # Hand-craft an expired-but-otherwise-valid token.
        expired = int(time_module.time()) - 60
        sig = ch._sign_link_token("telegram", "42", expired)  # type: ignore[attr-defined]  # pyright: ignore[reportPrivateUsage]
        params = {"channel": "telegram", "id": "42", "exp": str(expired), "sig": sig}
        with TestClient(self._mount(ch)) as client:
            r = client.get(f"/auth/start?{urlencode(params)}", follow_redirects=False)
        assert r.status_code == 403

    def test_mint_start_url_requires_secret(self, tmp_path: Path) -> None:
        import pytest

        msal_app = MagicMock()
        store = EntraIdentityStore(tmp_path / "links.json")
        with patch(
            "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication",
            return_value=msal_app,
        ):
            ch = EntraIdentityLinkChannel(
                store=store,
                tenant_id="tenant-1",
                client_id="client-1",
                public_base_url="https://example.com",
                client_secret="s",
            )
        with pytest.raises(RuntimeError, match="link_token_secret"):
            ch.mint_start_url("telegram", "42")

    def test_unsigned_mode_logs_warning_at_startup(self, tmp_path: Path, caplog: Any) -> None:
        import asyncio as asyncio_mod
        import logging

        msal_app = MagicMock()
        store = EntraIdentityStore(tmp_path / "links.json")
        with patch(
            "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication",
            return_value=msal_app,
        ):
            ch = EntraIdentityLinkChannel(
                store=store,
                tenant_id="tenant-1",
                client_id="client-1",
                public_base_url="https://example.com",
                client_secret="s",
            )
        with caplog.at_level(logging.WARNING, logger="agent_framework.hosting"):
            asyncio_mod.run(ch._on_startup())  # pyright: ignore[reportPrivateUsage]
            asyncio_mod.run(ch._on_shutdown())  # pyright: ignore[reportPrivateUsage]
        assert any("WITHOUT link_token_secret" in r.message for r in caplog.records)


class TestXssEscaping:
    """All inbound query/profile values must be HTML-escaped before output."""

    def _setup(self, tmp_path: Path) -> tuple[EntraIdentityLinkChannel, EntraIdentityStore, MagicMock]:
        store = EntraIdentityStore(tmp_path / "links.json")
        msal_app = MagicMock()
        msal_app.get_authorization_request_url.return_value = "https://x"
        with patch(
            "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication",
            return_value=msal_app,
        ):
            ch = EntraIdentityLinkChannel(
                store=store,
                tenant_id="tenant-1",
                client_id="client-1",
                public_base_url="https://example.com",
                client_secret="s",
            )
        return ch, store, msal_app

    def _mount(self, ch: EntraIdentityLinkChannel) -> Starlette:
        from starlette.routing import Mount

        contribution = ch.contribute(MagicMock())
        return Starlette(routes=[Mount(ch.path, routes=contribution.routes)])

    def test_callback_error_param_is_escaped(self, tmp_path: Path) -> None:
        ch, _, _ = self._setup(tmp_path)
        ch._http = MagicMock(aclose=AsyncMock())
        with TestClient(self._mount(ch)) as client:
            r = client.get(
                "/auth/callback",
                params={
                    "error": "<script>alert(1)</script>",
                    "error_description": "<img onerror=x>",
                },
            )
        assert r.status_code == 400
        assert "<script>" not in r.text
        assert "&lt;script&gt;" in r.text
        assert "&lt;img" in r.text

    def test_callback_success_escapes_channel_key_and_upn(self, tmp_path: Path) -> None:
        ch, store, msal_app = self._setup(tmp_path)
        msal_app.acquire_token_by_authorization_code.return_value = {"access_token": "t"}
        graph_response = MagicMock()
        graph_response.status_code = 200
        graph_response.json = MagicMock(
            return_value={"id": "oid-1", "userPrincipalName": "<script>alert(1)</script>@x"}
        )
        ch._http = MagicMock(aclose=AsyncMock())
        ch._http.get = AsyncMock(return_value=graph_response)
        # Mint a binding via authorize_url_for (channel-side trusted call).
        ch.authorize_url_for("<svg/onload=alert(1)>", "42")
        state = next(iter(ch._pending.keys()))
        with TestClient(self._mount(ch)) as client:
            r = client.get("/auth/callback", params={"code": "abc", "state": state})
        assert r.status_code == 200
        assert "<script>" not in r.text
        assert "<svg/" not in r.text
        assert "&lt;svg/onload=alert(1)&gt;" in r.text
        assert "&lt;script&gt;" in r.text


class TestReturnToOpenRedirect:
    """`return_to` must be relative or same-origin only."""

    def _make(self, tmp_path: Path) -> EntraIdentityLinkChannel:
        store = EntraIdentityStore(tmp_path / "links.json")
        msal_app = MagicMock()
        msal_app.get_authorization_request_url.return_value = (
            "https://login.microsoftonline.com/tenant-1/authorize?state=X"
        )
        with patch(
            "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication",
            return_value=msal_app,
        ):
            return EntraIdentityLinkChannel(
                store=store,
                tenant_id="tenant-1",
                client_id="client-1",
                public_base_url="https://example.com",
                client_secret="s",
            )

    def _mount(self, ch: EntraIdentityLinkChannel) -> Starlette:
        from starlette.routing import Mount

        contribution = ch.contribute(MagicMock())
        return Starlette(routes=[Mount(ch.path, routes=contribution.routes)])

    def test_start_rejects_external_return_to(self, tmp_path: Path) -> None:
        ch = self._make(tmp_path)
        with TestClient(self._mount(ch)) as client:
            r = client.get(
                "/auth/start",
                params={"channel": "telegram", "id": "42", "return_to": "https://evil.example/"},
                follow_redirects=False,
            )
        assert r.status_code == 400

    def test_start_accepts_relative_return_to(self, tmp_path: Path) -> None:
        ch = self._make(tmp_path)
        with TestClient(self._mount(ch)) as client:
            r = client.get(
                "/auth/start",
                params={"channel": "telegram", "id": "42", "return_to": "/done"},
                follow_redirects=False,
            )
        assert r.status_code == 302

    def test_start_accepts_same_origin_return_to(self, tmp_path: Path) -> None:
        ch = self._make(tmp_path)
        with TestClient(self._mount(ch)) as client:
            r = client.get(
                "/auth/start",
                params={
                    "channel": "telegram",
                    "id": "42",
                    "return_to": "https://example.com/done",
                },
                follow_redirects=False,
            )
        assert r.status_code == 302

    def test_protocol_relative_return_to_rejected(self, tmp_path: Path) -> None:
        ch = self._make(tmp_path)
        with TestClient(self._mount(ch)) as client:
            r = client.get(
                "/auth/start",
                params={
                    "channel": "telegram",
                    "id": "42",
                    "return_to": "//evil.example/",
                },
                follow_redirects=False,
            )
        # //evil.example/ — Python's urlparse treats this as netloc=evil.example,
        # which is NOT same-origin, so it must be rejected.
        assert r.status_code == 400
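The open-redirect tests above pin down the accepted shapes of `return_to`: a relative path or a URL on the channel's own origin, with external and protocol-relative targets rejected. The following sketch shows one way to implement such a check with `urllib.parse.urlsplit`. The function name `is_safe_return_to` and the extra backslash rule are assumptions made for illustration; the package's real validator is not shown in this diff and may behave differently.

```python
from urllib.parse import urlsplit


def is_safe_return_to(return_to: str, public_base_url: str) -> bool:
    """Allow only relative paths or URLs on the same origin as public_base_url."""
    # Some browsers treat "\" like "/", so "/\evil.example" could leave the
    # origin. Rejecting backslashes outright is an extra precaution (assumption).
    if "\\" in return_to:
        return False
    target = urlsplit(return_to)
    if not target.scheme and not target.netloc:
        # Purely relative: require an absolute path such as "/done".
        return return_to.startswith("/")
    base = urlsplit(public_base_url)
    # "//evil.example/" parses with an empty scheme and netloc "evil.example",
    # so it fails this comparison just like a fully external URL does.
    return (target.scheme.lower(), target.netloc.lower()) == (
        base.scheme.lower(),
        base.netloc.lower(),
    )


BASE = "https://example.com"
print(is_safe_return_to("/done", BASE), is_safe_return_to("//evil.example/", BASE))
```

Comparing the parsed scheme and host, rather than testing a string prefix, avoids accepting look-alike hosts such as `https://example.com.evil.example/`.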