mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: add agent-framework-hosting-discord channel (#6081)
* Add Discord hosting channel Add an alpha agent-framework-hosting-discord package backed by Discord HTTP Interactions. The channel verifies signed slash-command requests, registers commands, runs hosted agents and ChannelCommand handlers, supports originating response hooks, streams by editing the original interaction response, and can push through Discord channel ids. Factor standard channel response-hook context application into hosting core so both host fan-out and originating channel replies use one helper. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address Discord review chunking feedback Ensure Discord command replies are chunked and streaming preview edits stay under Discord's content limit while final streamed replies continue through the chunked reply path. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * small fix in init * updated lock --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
6b822853eb
commit
e8c22caaeb
@@ -34,6 +34,7 @@ Status is grouped into these buckets:
|
||||
| `agent-framework-foundry-local` | `python/packages/foundry_local` | `beta` |
|
||||
| `agent-framework-gemini` | `python/packages/gemini` | `alpha` |
|
||||
| `agent-framework-github-copilot` | `python/packages/github_copilot` | `beta` |
|
||||
| `agent-framework-hosting-discord` | `python/packages/hosting-discord` | `alpha` |
|
||||
| `agent-framework-hyperlight` | `python/packages/hyperlight` | `beta` |
|
||||
| `agent-framework-lab` | `python/packages/lab` | `beta` |
|
||||
| `agent-framework-mem0` | `python/packages/mem0` | `beta` |
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# agent-framework-hosting-discord
|
||||
|
||||
Discord HTTP Interactions channel for [agent-framework-hosting](../hosting).
|
||||
The channel exposes a signed Starlette route for Discord slash commands, maps a
|
||||
configurable slash command to the hosted agent, maps `ChannelCommand` instances
|
||||
to native Discord commands, and supports push to Discord channel ids.
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from agent_framework_hosting import AgentFrameworkHost
|
||||
from agent_framework_hosting_discord import DiscordChannel
|
||||
|
||||
host = AgentFrameworkHost(
|
||||
target=my_agent,
|
||||
channels=[
|
||||
DiscordChannel(
|
||||
application_id="<discord application id>",
|
||||
public_key="<discord public key>",
|
||||
bot_token="<discord bot token>",
|
||||
guild_id="<guild id for fast dev command registration>",
|
||||
)
|
||||
],
|
||||
)
|
||||
host.serve()
|
||||
```
|
||||
|
||||
Configure the Discord Developer Portal interaction endpoint as:
|
||||
|
||||
```text
|
||||
https://<your-host>/discord/interactions
|
||||
```
|
||||
|
||||
The channel verifies Discord's `X-Signature-Ed25519` header against the raw
|
||||
request body before parsing JSON. `skip_signature_verification=True` exists only
|
||||
for local tests and should not be used on a public endpoint.
|
||||
|
||||
## Slash commands
|
||||
|
||||
By default, `/ask prompt:<text>` invokes the hosted agent. Additional
|
||||
`ChannelCommand` instances are registered as Discord slash commands with an
|
||||
optional `input` string option:
|
||||
|
||||
```python
|
||||
from agent_framework_hosting import ChannelCommand
|
||||
|
||||
async def reset(ctx):
|
||||
await ctx.reply("Reset acknowledged")
|
||||
|
||||
DiscordChannel(
|
||||
application_id="...",
|
||||
public_key="...",
|
||||
bot_token="...",
|
||||
commands=[ChannelCommand("reset", "Reset the conversation", reset)],
|
||||
)
|
||||
```
|
||||
|
||||
When `guild_id` is set, commands are registered only for that guild and usually
|
||||
appear quickly. Global command registration can take much longer to propagate.
|
||||
If `register_commands=True` but `bot_token` is omitted, the channel logs a
|
||||
warning and assumes commands were registered outside the host.
|
||||
|
||||
## Identity, sessions, and push
|
||||
|
||||
The default isolation key is `discord:<guild-or-dm>:<channel_id>:<user_id>`,
|
||||
which keeps each user private inside a Discord channel or thread. Pass
|
||||
`isolation_key_factory=` to use a different scope.
|
||||
|
||||
`ChannelIdentity.native_id` is the Discord user id. Push requires
|
||||
`identity.attributes["channel_id"]`; the first slice intentionally does not
|
||||
create DM channels as a fallback.
|
||||
|
||||
## Streaming
|
||||
|
||||
Set `streaming=True` to consume the host stream and edit the original Discord
|
||||
interaction response as text accumulates. Edits are debounced with
|
||||
`edit_interval` to avoid excessive Discord REST calls.
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Discord channel for ``agent-framework-hosting``."""
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from ._channel import DiscordChannel, DiscordIsolationKeyFactory, discord_isolation_key
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
__version__ = "0.0.0"
|
||||
|
||||
__all__ = [
|
||||
"DiscordChannel",
|
||||
"DiscordIsolationKeyFactory",
|
||||
"__version__",
|
||||
"discord_isolation_key",
|
||||
]
|
||||
@@ -0,0 +1,657 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Discord HTTP Interactions channel."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable, Coroutine, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from agent_framework import AgentResponse, AgentResponseUpdate, Content, Message, ResponseStream
|
||||
from agent_framework_hosting import (
|
||||
ChannelCommand,
|
||||
ChannelCommandContext,
|
||||
ChannelContext,
|
||||
ChannelContribution,
|
||||
ChannelIdentity,
|
||||
ChannelRequest,
|
||||
ChannelResponseHook,
|
||||
ChannelRunHook,
|
||||
ChannelSession,
|
||||
ChannelStreamTransformHook,
|
||||
HostedRunResult,
|
||||
apply_channel_response_hook,
|
||||
apply_run_hook,
|
||||
)
|
||||
from nacl.exceptions import BadSignatureError
|
||||
from nacl.signing import VerifyKey
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.routing import Route
|
||||
|
||||
logger = logging.getLogger("agent_framework.hosting.discord")
|
||||
|
||||
DiscordInteraction = Mapping[str, Any]
|
||||
DiscordIsolationKeyFactory = Callable[[DiscordInteraction], str]
|
||||
|
||||
_DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
_DISCORD_MAX_BODY_BYTES = 1024 * 1024
|
||||
_DISCORD_MAX_CONTENT_LEN = 2000
|
||||
_INTERACTION_PING = 1
|
||||
_INTERACTION_APPLICATION_COMMAND = 2
|
||||
_RESPONSE_PONG = 1
|
||||
_RESPONSE_DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE = 5
|
||||
_OPTION_STRING = 3
|
||||
_APPLICATION_COMMAND_CHAT_INPUT = 1
|
||||
_COMMAND_NAME_RE = re.compile(r"^[a-z0-9_-]{1,32}$")
|
||||
|
||||
|
||||
def discord_isolation_key(guild_id: str | None, channel_id: str, user_id: str) -> str:
|
||||
"""Build the default Discord isolation key.
|
||||
|
||||
Args:
|
||||
guild_id: Discord guild id, or ``None`` for a DM interaction.
|
||||
channel_id: Discord channel or thread id.
|
||||
user_id: Discord user id.
|
||||
|
||||
Returns:
|
||||
A stable host isolation key scoped to guild/channel/user.
|
||||
"""
|
||||
scope = guild_id or "dm"
|
||||
return f"discord:{scope}:{channel_id}:{user_id}"
|
||||
|
||||
|
||||
def _default_isolation_key(interaction: DiscordInteraction) -> str:
|
||||
user = _user_from_interaction(interaction)
|
||||
user_id = _require_string(user.get("id"), "interaction user id")
|
||||
channel_id = _require_string(interaction.get("channel_id"), "interaction channel_id")
|
||||
guild_id = _string_or_none(interaction.get("guild_id"))
|
||||
return discord_isolation_key(guild_id, channel_id, user_id)
|
||||
|
||||
|
||||
def _text_result(text: str) -> HostedRunResult[AgentResponse]:
|
||||
"""Build a host delivery payload from text accumulated by this channel."""
|
||||
return HostedRunResult(AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text(text=text)])]))
|
||||
|
||||
|
||||
class DiscordChannel:
|
||||
"""Discord channel backed by signed HTTP Interactions."""
|
||||
|
||||
name = "discord"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_id: str,
|
||||
public_key: str,
|
||||
bot_token: str | None = None,
|
||||
guild_id: str | None = None,
|
||||
path: str = "/discord",
|
||||
agent_command: str = "ask",
|
||||
agent_command_description: str = "Ask the agent",
|
||||
agent_command_option: str = "prompt",
|
||||
register_commands: bool = True,
|
||||
commands: Sequence[ChannelCommand] | None = None,
|
||||
run_hook: ChannelRunHook | None = None,
|
||||
response_hook: ChannelResponseHook | None = None,
|
||||
stream_transform_hook: ChannelStreamTransformHook | None = None,
|
||||
streaming: bool = False,
|
||||
isolation_key_factory: DiscordIsolationKeyFactory | None = None,
|
||||
skip_signature_verification: bool = False,
|
||||
edit_interval: float = 1.0,
|
||||
max_body_bytes: int = _DISCORD_MAX_BODY_BYTES,
|
||||
api_base_url: str = _DISCORD_API_BASE,
|
||||
) -> None:
|
||||
"""Configure the Discord channel.
|
||||
|
||||
Keyword Args:
|
||||
application_id: Discord application id.
|
||||
public_key: Discord application public key as lowercase or
|
||||
uppercase hex. Used to verify interaction signatures.
|
||||
bot_token: Bot token used to register slash commands and push
|
||||
messages to Discord channel ids. Interaction webhook replies
|
||||
do not require this token.
|
||||
guild_id: Optional guild id for guild-scoped slash command
|
||||
registration. Recommended for development because global
|
||||
command registration can take a long time to propagate.
|
||||
path: Host mount path. The interaction route is contributed as
|
||||
``/interactions`` below this path.
|
||||
agent_command: Slash command name that invokes the hosted agent.
|
||||
agent_command_description: Description for the agent slash command.
|
||||
agent_command_option: String option name that carries the prompt.
|
||||
register_commands: Whether startup should register slash commands
|
||||
through Discord REST when ``bot_token`` is configured.
|
||||
commands: Additional host ``ChannelCommand`` instances to expose
|
||||
as Discord slash commands.
|
||||
run_hook: Optional hook that can rewrite the channel request before
|
||||
it reaches the host.
|
||||
response_hook: Optional hook that can rewrite the hosted result
|
||||
before the originating Discord response is serialized.
|
||||
stream_transform_hook: Optional per-update transform hook applied
|
||||
while streaming.
|
||||
streaming: Whether the agent command should call ``run_stream``
|
||||
and edit the original interaction response as deltas arrive.
|
||||
isolation_key_factory: Optional callable that receives the raw
|
||||
Discord interaction and returns a host isolation key.
|
||||
skip_signature_verification: Disable Ed25519 verification. Use
|
||||
only for local tests; never expose publicly with this enabled.
|
||||
edit_interval: Minimum seconds between streaming edits to the
|
||||
original Discord interaction response.
|
||||
max_body_bytes: Maximum raw interaction request body size.
|
||||
api_base_url: Discord API base URL. Primarily useful for tests.
|
||||
|
||||
Raises:
|
||||
ValueError: If public key hex or command names are invalid, or if
|
||||
command names collide.
|
||||
"""
|
||||
self.application_id = application_id
|
||||
self.public_key = public_key
|
||||
self.bot_token = bot_token
|
||||
self.guild_id = guild_id
|
||||
self.path = path
|
||||
self.agent_command = agent_command
|
||||
self.agent_command_description = agent_command_description
|
||||
self.agent_command_option = agent_command_option
|
||||
self.register_commands = register_commands
|
||||
self._commands: set[ChannelCommand] = set(commands) or {} # type: ignore
|
||||
self._command_by_name = {command.name: command for command in self._commands}
|
||||
self._run_hook = run_hook
|
||||
self.response_hook = response_hook
|
||||
self._stream_transform_hook = stream_transform_hook
|
||||
self._streaming = streaming
|
||||
self._isolation_key_factory = isolation_key_factory or _default_isolation_key
|
||||
self._skip_signature_verification = skip_signature_verification
|
||||
self._edit_interval = edit_interval
|
||||
self._max_body_bytes = max_body_bytes
|
||||
self._api_base_url = api_base_url.rstrip("/")
|
||||
self._ctx: ChannelContext | None = None
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
self._validate_configuration()
|
||||
try:
|
||||
self._verify_key = VerifyKey(bytes.fromhex(public_key))
|
||||
except ValueError as exc:
|
||||
raise ValueError("DiscordChannel public_key must be a valid Ed25519 public key hex string") from exc
|
||||
|
||||
def contribute(self, context: ChannelContext) -> ChannelContribution:
|
||||
"""Register the Discord interaction route and lifecycle hooks."""
|
||||
self._ctx = context
|
||||
return ChannelContribution(
|
||||
routes=[Route("/interactions", self._handle, methods=["POST"])],
|
||||
commands=self._commands,
|
||||
on_startup=[self._on_startup],
|
||||
on_shutdown=[self._on_shutdown],
|
||||
)
|
||||
|
||||
async def push(self, identity: ChannelIdentity, payload: HostedRunResult[Any]) -> None:
|
||||
"""Push a hosted result to a Discord channel.
|
||||
|
||||
Args:
|
||||
identity: Destination identity. ``identity.attributes`` must carry
|
||||
``channel_id``.
|
||||
payload: Hosted run result to render as Discord message text.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the channel has no bot token for Discord REST.
|
||||
ValueError: If ``channel_id`` is missing from the identity.
|
||||
"""
|
||||
channel_id = _string_or_none(identity.attributes.get("channel_id"))
|
||||
if channel_id is None:
|
||||
raise ValueError("Discord push requires identity.attributes['channel_id']")
|
||||
if self.bot_token is None:
|
||||
raise RuntimeError("DiscordChannel.push requires bot_token to send channel messages")
|
||||
await self._send_channel_messages(channel_id, _payload_text(payload))
|
||||
|
||||
async def _on_startup(self) -> None:
|
||||
"""Open the Discord REST client and optionally register slash commands."""
|
||||
self._ensure_http()
|
||||
if self._skip_signature_verification:
|
||||
logger.warning(
|
||||
"DiscordChannel running with skip_signature_verification=True. "
|
||||
"Use only for local tests; public Discord endpoints must verify signatures."
|
||||
)
|
||||
if not self.register_commands:
|
||||
return
|
||||
if self.bot_token is None:
|
||||
logger.warning(
|
||||
"DiscordChannel register_commands=True but bot_token is not configured; "
|
||||
"slash commands must be registered outside the host."
|
||||
)
|
||||
return
|
||||
if self.guild_id is None:
|
||||
logger.warning(
|
||||
"DiscordChannel registering global slash commands; Discord can take a long time "
|
||||
"to propagate global command changes. Set guild_id for faster development updates."
|
||||
)
|
||||
try:
|
||||
await self._register_commands()
|
||||
except (RuntimeError, httpx.HTTPError):
|
||||
logger.exception("DiscordChannel slash command registration failed; continuing startup")
|
||||
|
||||
async def _on_shutdown(self) -> None:
|
||||
"""Drain in-flight interaction tasks and close the Discord REST client."""
|
||||
if self._tasks:
|
||||
await asyncio.gather(*self._tasks, return_exceptions=True)
|
||||
if self._http is not None:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
|
||||
async def _handle(self, request: Request) -> Response:
|
||||
"""Handle one Discord interaction webhook request."""
|
||||
raw_body = await request.body()
|
||||
if len(raw_body) > self._max_body_bytes:
|
||||
return JSONResponse({"error": "request body too large"}, status_code=413)
|
||||
if not self._skip_signature_verification and not self._verify_signature(request, raw_body):
|
||||
return JSONResponse({"error": "invalid signature"}, status_code=401)
|
||||
try:
|
||||
body = json.loads(raw_body.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
return JSONResponse({"error": "invalid JSON"}, status_code=400)
|
||||
if not isinstance(body, Mapping):
|
||||
return JSONResponse({"error": "interaction body must be a JSON object"}, status_code=400)
|
||||
interaction = cast("DiscordInteraction", body)
|
||||
|
||||
interaction_type = interaction.get("type")
|
||||
if interaction_type == _INTERACTION_PING:
|
||||
return JSONResponse({"type": _RESPONSE_PONG})
|
||||
if interaction_type != _INTERACTION_APPLICATION_COMMAND:
|
||||
return JSONResponse({"error": f"unsupported interaction type: {interaction_type!r}"}, status_code=400)
|
||||
|
||||
self._schedule(self._dispatch_application_command(interaction))
|
||||
return JSONResponse({"type": _RESPONSE_DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE})
|
||||
|
||||
async def _dispatch_application_command(self, interaction: DiscordInteraction) -> None:
|
||||
token = _require_string(interaction.get("token"), "interaction token")
|
||||
try:
|
||||
name = _application_command_name(interaction)
|
||||
if name == self.agent_command:
|
||||
await self._run_agent_command(interaction, token)
|
||||
return
|
||||
command = self._command_by_name.get(name)
|
||||
if command is None:
|
||||
await self._edit_original(token, f"Unknown Discord command: {name}")
|
||||
return
|
||||
await self._run_channel_command(command, interaction, token)
|
||||
except Exception:
|
||||
logger.exception("DiscordChannel interaction handling failed")
|
||||
await self._try_edit_original(token, "Sorry, something went wrong while handling that Discord command.")
|
||||
raise
|
||||
|
||||
async def _run_agent_command(self, interaction: DiscordInteraction, token: str) -> None:
|
||||
if self._ctx is None:
|
||||
raise RuntimeError("DiscordChannel was not contributed to a host.")
|
||||
prompt = _string_option(interaction, self.agent_command_option)
|
||||
if prompt is None:
|
||||
await self._edit_original(token, f"Missing required `{self.agent_command_option}` option.")
|
||||
return
|
||||
request = self._build_request(
|
||||
interaction,
|
||||
operation="message.create",
|
||||
input_value=prompt,
|
||||
stream=self._streaming,
|
||||
)
|
||||
if self._run_hook is not None:
|
||||
request = await apply_run_hook(
|
||||
self._run_hook,
|
||||
request,
|
||||
target=self._ctx.target,
|
||||
protocol_request=interaction,
|
||||
)
|
||||
if request.stream:
|
||||
await self._run_streaming(request, token)
|
||||
return
|
||||
result = await self._ctx.run(request)
|
||||
include_originating = await self._ctx.deliver_response(request, result)
|
||||
if include_originating:
|
||||
result = await apply_channel_response_hook(self, result, request=request, originating=True)
|
||||
await self._edit_original_with_result(token, result)
|
||||
else:
|
||||
await self._edit_original(token, "Sent.")
|
||||
|
||||
async def _run_channel_command(
|
||||
self,
|
||||
command: ChannelCommand,
|
||||
interaction: DiscordInteraction,
|
||||
token: str,
|
||||
) -> None:
|
||||
command_input = _string_option(interaction, "input")
|
||||
request = self._build_request(
|
||||
interaction,
|
||||
operation="command.invoke",
|
||||
input_value=f"/{command.name}" if command_input is None else f"/{command.name} {command_input}",
|
||||
stream=False,
|
||||
)
|
||||
reply = _DiscordInteractionReply(self, token)
|
||||
await command.handle(ChannelCommandContext(request=request, reply=reply))
|
||||
if not reply.sent:
|
||||
await self._edit_original(token, "Done.")
|
||||
|
||||
async def _run_streaming(self, request: ChannelRequest, token: str) -> None:
|
||||
if self._ctx is None:
|
||||
raise RuntimeError("DiscordChannel was not contributed to a host.")
|
||||
stream: ResponseStream[AgentResponseUpdate, AgentResponse] = self._ctx.run_stream(request)
|
||||
accumulated: list[str] = []
|
||||
last_edit = 0.0
|
||||
async for update in stream:
|
||||
transformed: AgentResponseUpdate | None = update
|
||||
if self._stream_transform_hook is not None:
|
||||
maybe = self._stream_transform_hook(update)
|
||||
if isinstance(maybe, Awaitable):
|
||||
transformed = await cast("Awaitable[AgentResponseUpdate | None]", maybe)
|
||||
else:
|
||||
transformed = maybe
|
||||
if transformed is None:
|
||||
continue
|
||||
chunk = _update_text(transformed)
|
||||
if not chunk:
|
||||
continue
|
||||
accumulated.append(chunk)
|
||||
now = time.monotonic()
|
||||
if self._edit_interval <= 0 or now - last_edit >= self._edit_interval:
|
||||
await self._edit_original(token, _stream_preview_content("".join(accumulated)))
|
||||
last_edit = now
|
||||
|
||||
final = _text_result("".join(accumulated))
|
||||
include_originating = await self._ctx.deliver_response(request, final)
|
||||
if include_originating:
|
||||
final = await apply_channel_response_hook(self, final, request=request, originating=True)
|
||||
await self._edit_original_with_result(token, final)
|
||||
else:
|
||||
await self._edit_original(token, "Sent.")
|
||||
|
||||
def _build_request(
|
||||
self,
|
||||
interaction: DiscordInteraction,
|
||||
*,
|
||||
operation: str,
|
||||
input_value: Any,
|
||||
stream: bool,
|
||||
) -> ChannelRequest:
|
||||
identity = self._identity_from_interaction(interaction)
|
||||
command_name = _application_command_name(interaction)
|
||||
metadata = {
|
||||
"interaction_id": _string_or_none(interaction.get("id")),
|
||||
"application_id": self.application_id,
|
||||
"guild_id": _string_or_none(interaction.get("guild_id")),
|
||||
"channel_id": _string_or_none(interaction.get("channel_id")),
|
||||
"user_id": identity.native_id,
|
||||
"command": command_name,
|
||||
}
|
||||
clean_metadata = {key: value for key, value in metadata.items() if value is not None}
|
||||
return ChannelRequest(
|
||||
channel=self.name,
|
||||
operation=operation,
|
||||
input=input_value,
|
||||
session=ChannelSession(isolation_key=self._isolation_key_factory(interaction)),
|
||||
metadata=clean_metadata,
|
||||
attributes=clean_metadata,
|
||||
stream=stream,
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
def _identity_from_interaction(self, interaction: DiscordInteraction) -> ChannelIdentity:
|
||||
user = _user_from_interaction(interaction)
|
||||
user_id = _require_string(user.get("id"), "interaction user id")
|
||||
attributes = {
|
||||
"username": _string_or_none(user.get("username")),
|
||||
"global_name": _string_or_none(user.get("global_name")),
|
||||
"guild_id": _string_or_none(interaction.get("guild_id")),
|
||||
"channel_id": _string_or_none(interaction.get("channel_id")),
|
||||
"application_id": self.application_id,
|
||||
}
|
||||
return ChannelIdentity(
|
||||
channel=self.name,
|
||||
native_id=user_id,
|
||||
attributes={key: value for key, value in attributes.items() if value is not None},
|
||||
)
|
||||
|
||||
def _verify_signature(self, request: Request, raw_body: bytes) -> bool:
|
||||
signature = request.headers.get("x-signature-ed25519")
|
||||
timestamp = request.headers.get("x-signature-timestamp")
|
||||
if not signature or not timestamp:
|
||||
return False
|
||||
try:
|
||||
self._verify_key.verify(timestamp.encode("utf-8") + raw_body, bytes.fromhex(signature))
|
||||
except (BadSignatureError, ValueError):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _schedule(self, coro: Coroutine[Any, Any, None]) -> None:
|
||||
task = asyncio.create_task(coro)
|
||||
self._tasks.add(task)
|
||||
task.add_done_callback(self._on_task_done)
|
||||
|
||||
def _on_task_done(self, task: asyncio.Task[None]) -> None:
|
||||
self._tasks.discard(task)
|
||||
try:
|
||||
task.result()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception:
|
||||
logger.exception("DiscordChannel background task failed")
|
||||
|
||||
def _ensure_http(self) -> httpx.AsyncClient:
|
||||
if self._http is None:
|
||||
self._http = httpx.AsyncClient(base_url=self._api_base_url, timeout=30.0)
|
||||
return self._http
|
||||
|
||||
async def _register_commands(self) -> None:
|
||||
http = self._ensure_http()
|
||||
path = f"/applications/{self.application_id}/commands"
|
||||
if self.guild_id is not None:
|
||||
path = f"/applications/{self.application_id}/guilds/{self.guild_id}/commands"
|
||||
response = await http.put(path, headers=self._bot_headers(), json=self._command_payloads())
|
||||
_raise_for_discord_error(response, "register slash commands")
|
||||
|
||||
async def _edit_original_with_result(self, token: str, payload: HostedRunResult[Any]) -> None:
|
||||
chunks = _split_content(_payload_text(payload))
|
||||
await self._edit_original(token, chunks[0])
|
||||
for chunk in chunks[1:]:
|
||||
await self._send_followup(token, chunk)
|
||||
|
||||
async def _edit_original(self, token: str, content: str) -> None:
|
||||
http = self._ensure_http()
|
||||
response = await http.patch(
|
||||
f"/webhooks/{self.application_id}/{token}/messages/@original",
|
||||
json={"content": _normalize_content(content)},
|
||||
)
|
||||
_raise_for_discord_error(response, "edit interaction response")
|
||||
|
||||
async def _try_edit_original(self, token: str, content: str) -> None:
|
||||
try:
|
||||
await self._edit_original(token, content)
|
||||
except (RuntimeError, httpx.HTTPError):
|
||||
logger.exception("DiscordChannel failed to edit interaction error response")
|
||||
|
||||
async def _send_followup(self, token: str, content: str) -> None:
|
||||
http = self._ensure_http()
|
||||
response = await http.post(
|
||||
f"/webhooks/{self.application_id}/{token}",
|
||||
json={"content": _normalize_content(content)},
|
||||
)
|
||||
_raise_for_discord_error(response, "send interaction follow-up")
|
||||
|
||||
async def _send_channel_messages(self, channel_id: str, content: str) -> None:
|
||||
http = self._ensure_http()
|
||||
for chunk in _split_content(content):
|
||||
response = await http.post(
|
||||
f"/channels/{channel_id}/messages",
|
||||
headers=self._bot_headers(),
|
||||
json={"content": chunk},
|
||||
)
|
||||
_raise_for_discord_error(response, "send channel message")
|
||||
|
||||
def _bot_headers(self) -> dict[str, str]:
|
||||
if self.bot_token is None:
|
||||
raise RuntimeError("Discord bot token is required for this operation")
|
||||
return {"Authorization": f"Bot {self.bot_token}"}
|
||||
|
||||
def _command_payloads(self) -> list[dict[str, Any]]:
|
||||
payloads = [
|
||||
{
|
||||
"type": _APPLICATION_COMMAND_CHAT_INPUT,
|
||||
"name": self.agent_command,
|
||||
"description": self.agent_command_description,
|
||||
"options": [
|
||||
{
|
||||
"type": _OPTION_STRING,
|
||||
"name": self.agent_command_option,
|
||||
"description": "Prompt for the agent.",
|
||||
"required": True,
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
for command in self._commands:
|
||||
payloads.append({
|
||||
"type": _APPLICATION_COMMAND_CHAT_INPUT,
|
||||
"name": command.name,
|
||||
"description": command.description,
|
||||
"options": [
|
||||
{
|
||||
"type": _OPTION_STRING,
|
||||
"name": "input",
|
||||
"description": "Optional command input.",
|
||||
"required": False,
|
||||
}
|
||||
],
|
||||
})
|
||||
return payloads
|
||||
|
||||
def _validate_configuration(self) -> None:
|
||||
names = [self.agent_command, *(command.name for command in self._commands)]
|
||||
for name in names:
|
||||
if not _COMMAND_NAME_RE.fullmatch(name):
|
||||
raise ValueError(
|
||||
"Discord command names must be lowercase ASCII letters, numbers, hyphen, "
|
||||
f"or underscore, and 1-32 characters long: {name!r}"
|
||||
)
|
||||
if not _COMMAND_NAME_RE.fullmatch(self.agent_command_option):
|
||||
raise ValueError(
|
||||
"Discord agent_command_option must be lowercase ASCII letters, numbers, hyphen, "
|
||||
f"or underscore, and 1-32 characters long: {self.agent_command_option!r}"
|
||||
)
|
||||
if len(set(names)) != len(names):
|
||||
raise ValueError("Discord command names must be unique; agent_command cannot collide with commands")
|
||||
if self._edit_interval < 0:
|
||||
raise ValueError("edit_interval must be >= 0")
|
||||
if self._max_body_bytes <= 0:
|
||||
raise ValueError("max_body_bytes must be > 0")
|
||||
|
||||
|
||||
class _DiscordInteractionReply:
|
||||
"""Reply helper that edits the deferred response first, then sends follow-ups."""
|
||||
|
||||
def __init__(self, channel: DiscordChannel, token: str) -> None:
|
||||
self._channel = channel
|
||||
self._token = token
|
||||
self.sent = False
|
||||
|
||||
async def __call__(self, body: str) -> None:
|
||||
chunks = _split_content(body)
|
||||
if not self.sent:
|
||||
await self._channel._edit_original(self._token, chunks[0]) # pyright: ignore[reportPrivateUsage]
|
||||
self.sent = True
|
||||
for chunk in chunks[1:]:
|
||||
await self._channel._send_followup(self._token, chunk) # pyright: ignore[reportPrivateUsage]
|
||||
return
|
||||
for chunk in chunks:
|
||||
await self._channel._send_followup(self._token, chunk) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def _user_from_interaction(interaction: DiscordInteraction) -> Mapping[str, Any]:
|
||||
member = interaction.get("member")
|
||||
if isinstance(member, Mapping):
|
||||
member_user = member.get("user")
|
||||
if isinstance(member_user, Mapping):
|
||||
return member_user
|
||||
user = interaction.get("user")
|
||||
if isinstance(user, Mapping):
|
||||
return user
|
||||
raise ValueError("Discord interaction is missing user information")
|
||||
|
||||
|
||||
def _application_command_name(interaction: DiscordInteraction) -> str:
|
||||
data = interaction.get("data")
|
||||
if not isinstance(data, Mapping):
|
||||
raise ValueError("Discord application command interaction is missing data")
|
||||
return _require_string(data.get("name"), "application command name")
|
||||
|
||||
|
||||
def _string_option(interaction: DiscordInteraction, name: str) -> str | None:
|
||||
data = interaction.get("data")
|
||||
if not isinstance(data, Mapping):
|
||||
return None
|
||||
options = data.get("options")
|
||||
if not isinstance(options, Sequence) or isinstance(options, (str, bytes)):
|
||||
return None
|
||||
for option in options:
|
||||
if not isinstance(option, Mapping):
|
||||
continue
|
||||
if option.get("name") != name:
|
||||
continue
|
||||
value = option.get("value")
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
return None
|
||||
|
||||
|
||||
def _payload_text(payload: HostedRunResult[Any]) -> str:
|
||||
text = getattr(payload.result, "text", None)
|
||||
if isinstance(text, str) and text:
|
||||
return text
|
||||
messages = getattr(payload.result, "messages", None)
|
||||
if isinstance(messages, Sequence):
|
||||
for message in reversed(messages):
|
||||
message_text = getattr(message, "text", None)
|
||||
if isinstance(message_text, str) and message_text:
|
||||
return message_text
|
||||
return "(no response)"
|
||||
|
||||
|
||||
def _update_text(update: AgentResponseUpdate) -> str:
|
||||
parts: list[str] = []
|
||||
for content in update.contents:
|
||||
text = getattr(content, "text", None)
|
||||
if isinstance(text, str) and text:
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def _split_content(content: str) -> list[str]:
|
||||
normalized = _normalize_content(content)
|
||||
return [normalized[i : i + _DISCORD_MAX_CONTENT_LEN] for i in range(0, len(normalized), _DISCORD_MAX_CONTENT_LEN)]
|
||||
|
||||
|
||||
def _stream_preview_content(content: str) -> str:
|
||||
return _split_content(content)[0]
|
||||
|
||||
|
||||
def _normalize_content(content: str) -> str:
|
||||
return content if content else "(no response)"
|
||||
|
||||
|
||||
def _string_or_none(value: Any) -> str | None:
|
||||
return value if isinstance(value, str) and value else None
|
||||
|
||||
|
||||
def _require_string(value: Any, field_name: str) -> str:
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
raise ValueError(f"Discord {field_name} must be a non-empty string")
|
||||
|
||||
|
||||
def _raise_for_discord_error(response: httpx.Response, action: str) -> None:
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
body = response.text[:500]
|
||||
raise RuntimeError(f"Discord {action} failed with HTTP {response.status_code}: {body}") from exc
|
||||
@@ -0,0 +1,107 @@
|
||||
[project]
|
||||
name = "agent-framework-hosting-discord"
|
||||
description = "Discord channel for agent-framework-hosting."
|
||||
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
version = "1.0.0a260526"
|
||||
license-files = ["LICENSE"]
|
||||
urls.homepage = "https://aka.ms/agent-framework"
|
||||
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
|
||||
urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true"
|
||||
urls.issues = "https://github.com/microsoft/agent-framework/issues"
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python :: 3.14",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
dependencies = [
|
||||
"agent-framework-core>=1.2.0,<2",
|
||||
"agent-framework-hosting>=1.0.0a260424,<2",
|
||||
"httpx>=0.27,<1",
|
||||
"PyNaCl>=1.2.0,<2",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
prerelease = "if-necessary-or-explicit"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
"sys_platform == 'linux'",
|
||||
"sys_platform == 'win32'"
|
||||
]
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"**/__init__.py"
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
include = ["agent_framework_hosting_discord"]
|
||||
exclude = ['tests']
|
||||
# Discord interactions arrive as loosely-typed JSON maps. Runtime guards narrow
|
||||
# payloads where needed; strict Unknown reporting on every `.get()` is noisy.
|
||||
reportUnknownArgumentType = "none"
|
||||
reportUnknownMemberType = "none"
|
||||
reportUnknownVariableType = "none"
|
||||
reportUnknownLambdaType = "none"
|
||||
reportOptionalMemberAccess = "none"
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
strict = true
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
check_untyped_defs = true
|
||||
warn_return_any = true
|
||||
show_error_codes = true
|
||||
warn_unused_ignores = false
|
||||
disallow_incomplete_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
|
||||
[tool.bandit]
|
||||
targets = ["agent_framework_hosting_discord"]
|
||||
exclude_dirs = ["tests"]
|
||||
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks.mypy]
|
||||
help = "Run MyPy for this package."
|
||||
cmd = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_hosting_discord"
|
||||
|
||||
[tool.poe.tasks.test]
|
||||
help = "Run the default unit test suite for this package."
|
||||
cmd = 'pytest -m "not integration" --cov=agent_framework_hosting_discord --cov-report=term-missing:skip-covered tests'
|
||||
|
||||
[build-system]
|
||||
requires = ["flit-core >= 3.11,<4.0"]
|
||||
build-backend = "flit_core.buildapi"
|
||||
|
||||
@@ -0,0 +1,680 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from agent_framework import AgentResponse, AgentResponseUpdate, Content, Message
|
||||
from agent_framework_hosting import (
|
||||
ChannelCommand,
|
||||
ChannelCommandContext,
|
||||
ChannelRequest,
|
||||
ChannelResponseContext,
|
||||
HostedRunResult,
|
||||
)
|
||||
from nacl.signing import SigningKey
|
||||
from starlette.applications import Starlette
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from agent_framework_hosting_discord import DiscordChannel, discord_isolation_key
|
||||
|
||||
|
||||
def _run_result(text: str) -> HostedRunResult[AgentResponse]:
|
||||
return HostedRunResult(AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text(text=text)])]))
|
||||
|
||||
|
||||
def _interaction(command: str = "ask", *, prompt: str = "hello", token: str = "token") -> dict[str, Any]:
|
||||
return {
|
||||
"id": "interaction-1",
|
||||
"type": 2,
|
||||
"application_id": "app-1",
|
||||
"token": token,
|
||||
"guild_id": "guild-1",
|
||||
"channel_id": "channel-1",
|
||||
"member": {
|
||||
"user": {
|
||||
"id": "user-1",
|
||||
"username": "ada",
|
||||
"global_name": "Ada",
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"name": command,
|
||||
"options": [{"name": "prompt", "type": 3, "value": prompt}],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _headers(signing_key: SigningKey, body: bytes) -> dict[str, str]:
|
||||
timestamp = "1234567890"
|
||||
signature = signing_key.sign(timestamp.encode("utf-8") + body).signature.hex()
|
||||
return {
|
||||
"x-signature-ed25519": signature,
|
||||
"x-signature-timestamp": timestamp,
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
class _FakeContext:
|
||||
def __init__(self, *, text: str = "agent reply", include_originating: bool = True) -> None:
|
||||
self.target = object()
|
||||
self.text = text
|
||||
self.include_originating = include_originating
|
||||
self.requests: list[ChannelRequest] = []
|
||||
self.delivered: list[tuple[ChannelRequest, HostedRunResult[Any]]] = []
|
||||
self.stream: _FakeStream | None = None
|
||||
|
||||
async def run(self, request: ChannelRequest) -> HostedRunResult[AgentResponse]:
|
||||
self.requests.append(request)
|
||||
return _run_result(self.text)
|
||||
|
||||
def run_stream(self, request: ChannelRequest) -> _FakeStream:
|
||||
self.requests.append(request)
|
||||
if self.stream is None:
|
||||
self.stream = _FakeStream(["a", "b"])
|
||||
return self.stream
|
||||
|
||||
async def deliver_response(self, request: ChannelRequest, payload: HostedRunResult[Any]) -> bool:
|
||||
self.delivered.append((request, payload))
|
||||
return self.include_originating
|
||||
|
||||
|
||||
class _FakeStream:
|
||||
def __init__(self, chunks: list[str]) -> None:
|
||||
self._chunks = chunks
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[AgentResponseUpdate]:
|
||||
return self._iter()
|
||||
|
||||
async def _iter(self) -> AsyncIterator[AgentResponseUpdate]:
|
||||
for chunk in self._chunks:
|
||||
yield AgentResponseUpdate(contents=[Content.from_text(text=chunk)], role="assistant")
|
||||
|
||||
|
||||
class _DiscordRecorder:
|
||||
def __init__(self) -> None:
|
||||
self.requests: list[httpx.Request] = []
|
||||
self.json_payloads: list[Any] = []
|
||||
|
||||
def transport(self) -> httpx.MockTransport:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
self.requests.append(request)
|
||||
if request.content:
|
||||
self.json_payloads.append(json.loads(request.content.decode("utf-8")))
|
||||
return httpx.Response(200, json={"ok": True})
|
||||
|
||||
return httpx.MockTransport(handler)
|
||||
|
||||
|
||||
def test_discord_isolation_key_scopes_to_guild_channel_user() -> None:
|
||||
assert discord_isolation_key("guild", "channel", "user") == "discord:guild:channel:user"
|
||||
assert discord_isolation_key(None, "dm-channel", "user") == "discord:dm:dm-channel:user"
|
||||
|
||||
|
||||
def test_ping_requires_valid_signature_and_returns_pong() -> None:
|
||||
signing_key = SigningKey.generate()
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=signing_key.verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
)
|
||||
app = Starlette(routes=list(channel.contribute(_FakeContext()).routes)) # type: ignore[arg-type]
|
||||
body = json.dumps({"type": 1}).encode("utf-8")
|
||||
|
||||
with TestClient(app) as client:
|
||||
ok = client.post("/interactions", content=body, headers=_headers(signing_key, body))
|
||||
bad = client.post(
|
||||
"/interactions",
|
||||
content=body,
|
||||
headers={
|
||||
**_headers(signing_key, body),
|
||||
"x-signature-ed25519": "00" * 64,
|
||||
},
|
||||
)
|
||||
|
||||
assert ok.status_code == 200
|
||||
assert ok.json() == {"type": 1}
|
||||
assert bad.status_code == 401
|
||||
|
||||
|
||||
def test_request_validation_errors() -> None:
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
skip_signature_verification=True,
|
||||
max_body_bytes=2,
|
||||
)
|
||||
app = Starlette(routes=list(channel.contribute(_FakeContext()).routes)) # type: ignore[arg-type]
|
||||
unsupported_channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
skip_signature_verification=True,
|
||||
)
|
||||
unsupported_app = Starlette(routes=list(unsupported_channel.contribute(_FakeContext()).routes)) # type: ignore[arg-type]
|
||||
|
||||
with TestClient(app) as client:
|
||||
too_large = client.post("/interactions", content=b"{}x")
|
||||
invalid_json = client.post("/interactions", content=b"{")
|
||||
with TestClient(unsupported_app) as client:
|
||||
non_object = client.post("/interactions", json=[])
|
||||
unsupported = client.post("/interactions", json={"type": 99})
|
||||
|
||||
assert too_large.status_code == 413
|
||||
assert invalid_json.status_code == 400
|
||||
assert non_object.status_code == 400
|
||||
assert unsupported.status_code == 400
|
||||
|
||||
|
||||
def test_constructor_validates_discord_configuration() -> None:
|
||||
public_key = SigningKey.generate().verify_key.encode().hex()
|
||||
|
||||
with pytest.raises(ValueError, match="public_key"):
|
||||
DiscordChannel(application_id="app-1", public_key="not-hex")
|
||||
with pytest.raises(ValueError, match="command names"):
|
||||
DiscordChannel(application_id="app-1", public_key=public_key, agent_command="Ask")
|
||||
with pytest.raises(ValueError, match="unique"):
|
||||
DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=public_key,
|
||||
commands=[ChannelCommand(name="ask", description="Ask again", handle=lambda _ctx: _noop())],
|
||||
)
|
||||
with pytest.raises(ValueError, match="edit_interval"):
|
||||
DiscordChannel(application_id="app-1", public_key=public_key, edit_interval=-1)
|
||||
with pytest.raises(ValueError, match="max_body_bytes"):
|
||||
DiscordChannel(application_id="app-1", public_key=public_key, max_body_bytes=0)
|
||||
|
||||
|
||||
async def test_agent_command_runs_host_and_edits_original_response() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
context = _FakeContext(text="agent says hi")
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
skip_signature_verification=True,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_agent_command(_interaction(prompt="what now?"), "token")
|
||||
|
||||
assert context.requests[0].operation == "message.create"
|
||||
assert context.requests[0].input == "what now?"
|
||||
assert context.requests[0].session is not None
|
||||
assert context.requests[0].session.isolation_key == "discord:guild-1:channel-1:user-1"
|
||||
assert context.requests[0].identity is not None
|
||||
assert context.requests[0].identity.native_id == "user-1"
|
||||
assert context.requests[0].identity.attributes["channel_id"] == "channel-1"
|
||||
assert len(context.delivered) == 1
|
||||
assert recorder.requests[0].method == "PATCH"
|
||||
assert recorder.requests[0].url.path == "/webhooks/app-1/token/messages/@original"
|
||||
assert recorder.json_payloads[0] == {"content": "agent says hi"}
|
||||
|
||||
|
||||
async def test_run_hook_can_rewrite_agent_request() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
context = _FakeContext(text="agent says hi")
|
||||
|
||||
async def hook(request: ChannelRequest, **_: Any) -> ChannelRequest:
|
||||
return ChannelRequest(
|
||||
channel=request.channel,
|
||||
operation=request.operation,
|
||||
input="rewritten",
|
||||
session=request.session,
|
||||
metadata=request.metadata,
|
||||
attributes=request.attributes,
|
||||
stream=request.stream,
|
||||
identity=request.identity,
|
||||
response_target=request.response_target,
|
||||
)
|
||||
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
run_hook=hook,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_agent_command(_interaction(prompt="original"), "token")
|
||||
|
||||
assert context.requests[0].input == "rewritten"
|
||||
|
||||
|
||||
async def test_response_hook_rewrites_originating_reply() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
context = _FakeContext(text="original")
|
||||
|
||||
async def hook(result: HostedRunResult[Any], *, context: ChannelResponseContext) -> HostedRunResult[Any]:
|
||||
assert context.originating is True
|
||||
assert result.result.text == "original"
|
||||
return _run_result("rewritten")
|
||||
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
response_hook=hook,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_agent_command(_interaction(), "token")
|
||||
|
||||
assert recorder.json_payloads[-1] == {"content": "rewritten"}
|
||||
|
||||
|
||||
async def test_deliver_response_false_acknowledges_without_originating_payload() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
context = _FakeContext(text="fanout only", include_originating=False)
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_agent_command(_interaction(), "token")
|
||||
|
||||
assert recorder.json_payloads[-1] == {"content": "Sent."}
|
||||
|
||||
|
||||
async def test_missing_prompt_edits_original_without_calling_host() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
context = _FakeContext(text="should not run")
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
interaction = _interaction()
|
||||
interaction["data"]["options"] = []
|
||||
|
||||
await channel._run_agent_command(interaction, "token")
|
||||
|
||||
assert context.requests == []
|
||||
assert recorder.json_payloads[-1] == {"content": "Missing required `prompt` option."}
|
||||
|
||||
|
||||
async def test_dispatch_application_command_routes_agent_command() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
context = _FakeContext(text="dispatched")
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._dispatch_application_command(_interaction(command="ask"))
|
||||
|
||||
assert context.requests[0].operation == "message.create"
|
||||
assert recorder.json_payloads[-1] == {"content": "dispatched"}
|
||||
|
||||
|
||||
async def test_channel_command_handler_receives_context_and_replies() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
captured: list[ChannelCommandContext] = []
|
||||
|
||||
async def handler(ctx: ChannelCommandContext) -> None:
|
||||
captured.append(ctx)
|
||||
await ctx.reply("reset done")
|
||||
|
||||
command = ChannelCommand(name="reset", description="Reset", handle=handler)
|
||||
context = _FakeContext()
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
commands=[command],
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
interaction = _interaction(command="reset")
|
||||
interaction["data"]["options"] = [{"name": "input", "type": 3, "value": "please"}]
|
||||
|
||||
await channel._run_channel_command(command, interaction, "token")
|
||||
|
||||
assert captured
|
||||
assert captured[0].request.operation == "command.invoke"
|
||||
assert captured[0].request.input == "/reset please"
|
||||
assert recorder.json_payloads == [{"content": "reset done"}]
|
||||
|
||||
|
||||
async def test_channel_command_reply_sends_followups_after_first_edit() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
|
||||
async def handler(ctx: ChannelCommandContext) -> None:
|
||||
await ctx.reply("first")
|
||||
await ctx.reply("second")
|
||||
|
||||
command = ChannelCommand(name="reset", description="Reset", handle=handler)
|
||||
context = _FakeContext()
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
commands=[command],
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_channel_command(command, _interaction(command="reset"), "token")
|
||||
|
||||
assert [request.method for request in recorder.requests] == ["PATCH", "POST"]
|
||||
assert recorder.json_payloads == [{"content": "first"}, {"content": "second"}]
|
||||
|
||||
|
||||
async def test_channel_command_reply_chunks_long_content() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
|
||||
async def handler(ctx: ChannelCommandContext) -> None:
|
||||
await ctx.reply("a" * 2001)
|
||||
|
||||
command = ChannelCommand(name="reset", description="Reset", handle=handler)
|
||||
context = _FakeContext()
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
commands=[command],
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_channel_command(command, _interaction(command="reset"), "token")
|
||||
|
||||
assert [request.method for request in recorder.requests] == ["PATCH", "POST"]
|
||||
assert [len(payload["content"]) for payload in recorder.json_payloads] == [2000, 1]
|
||||
|
||||
|
||||
async def test_channel_command_edits_done_when_handler_does_not_reply() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
|
||||
async def handler(_ctx: ChannelCommandContext) -> None:
|
||||
return None
|
||||
|
||||
command = ChannelCommand(name="reset", description="Reset", handle=handler)
|
||||
context = _FakeContext()
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
commands=[command],
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_channel_command(command, _interaction(command="reset"), "token")
|
||||
|
||||
assert recorder.json_payloads == [{"content": "Done."}]
|
||||
|
||||
|
||||
async def test_unknown_command_edits_error_response() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._dispatch_application_command(_interaction(command="missing"))
|
||||
|
||||
assert recorder.json_payloads == [{"content": "Unknown Discord command: missing"}]
|
||||
|
||||
|
||||
async def test_startup_bulk_registers_guild_commands() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
command = ChannelCommand(name="reset", description="Reset", handle=lambda _ctx: _noop())
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
bot_token="bot-token",
|
||||
guild_id="guild-1",
|
||||
commands=[command],
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._on_startup()
|
||||
|
||||
assert recorder.requests[0].method == "PUT"
|
||||
assert recorder.requests[0].url.path == "/applications/app-1/guilds/guild-1/commands"
|
||||
assert recorder.requests[0].headers["authorization"] == "Bot bot-token"
|
||||
assert [payload["name"] for payload in recorder.json_payloads[0]] == ["ask", "reset"]
|
||||
|
||||
|
||||
async def test_global_startup_registration_warns_about_propagation(caplog: pytest.LogCaptureFixture) -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
bot_token="bot-token",
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._on_startup()
|
||||
|
||||
assert recorder.requests[0].url.path == "/applications/app-1/commands"
|
||||
assert "global slash commands" in caplog.text
|
||||
|
||||
|
||||
async def test_startup_warns_when_registration_has_no_bot_token(caplog: pytest.LogCaptureFixture) -> None:
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
)
|
||||
|
||||
await channel._on_startup()
|
||||
await channel._on_shutdown()
|
||||
|
||||
assert "slash commands must be registered outside the host" in caplog.text
|
||||
|
||||
|
||||
async def test_originating_reply_sends_followup_chunks() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
context = _FakeContext(text="a" * 2001)
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_agent_command(_interaction(), "token")
|
||||
|
||||
assert [request.method for request in recorder.requests] == ["PATCH", "POST"]
|
||||
assert [len(payload["content"]) for payload in recorder.json_payloads] == [2000, 1]
|
||||
|
||||
|
||||
async def test_push_requires_channel_id_and_sends_chunked_messages() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
bot_token="bot-token",
|
||||
register_commands=False,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel.push(
|
||||
identity=channel._identity_from_interaction(_interaction()), # pyright: ignore[reportPrivateUsage]
|
||||
payload=_run_result("a" * 2001),
|
||||
)
|
||||
|
||||
assert [request.url.path for request in recorder.requests] == [
|
||||
"/channels/channel-1/messages",
|
||||
"/channels/channel-1/messages",
|
||||
]
|
||||
assert [len(payload["content"]) for payload in recorder.json_payloads] == [2000, 1]
|
||||
|
||||
|
||||
async def test_push_renders_no_response_for_unknown_payload_shape() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
bot_token="bot-token",
|
||||
register_commands=False,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel.push(
|
||||
identity=channel._identity_from_interaction(_interaction()), # pyright: ignore[reportPrivateUsage]
|
||||
payload=HostedRunResult(object()),
|
||||
)
|
||||
|
||||
assert recorder.json_payloads == [{"content": "(no response)"}]
|
||||
|
||||
|
||||
async def test_push_requires_bot_token_and_channel_id() -> None:
|
||||
identity = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
)._identity_from_interaction(_interaction()) # pyright: ignore[reportPrivateUsage]
|
||||
no_bot_token = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
)
|
||||
no_channel_id = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
bot_token="bot-token",
|
||||
register_commands=False,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="bot_token"):
|
||||
await no_bot_token.push(identity=identity, payload=_run_result("hello"))
|
||||
with pytest.raises(ValueError, match="channel_id"):
|
||||
await no_channel_id.push(
|
||||
identity=type(identity)(channel=identity.channel, native_id=identity.native_id, attributes={}),
|
||||
payload=_run_result("hello"),
|
||||
)
|
||||
|
||||
|
||||
async def test_streaming_edits_original_and_delivers_final_response() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
context = _FakeContext()
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
streaming=True,
|
||||
edit_interval=0,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_agent_command(_interaction(), "token")
|
||||
|
||||
assert [payload["content"] for payload in recorder.json_payloads] == ["a", "ab", "ab"]
|
||||
assert len(context.delivered) == 1
|
||||
assert context.delivered[0][1].result.text == "ab"
|
||||
|
||||
|
||||
async def test_streaming_preview_is_limited_and_final_reply_is_chunked() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
context = _FakeContext()
|
||||
context.stream = _FakeStream(["a" * 2001])
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
streaming=True,
|
||||
edit_interval=0,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_agent_command(_interaction(), "token")
|
||||
|
||||
assert [request.method for request in recorder.requests] == ["PATCH", "PATCH", "POST"]
|
||||
assert [len(payload["content"]) for payload in recorder.json_payloads] == [2000, 2000, 1]
|
||||
assert len(context.delivered[0][1].result.text) == 2001
|
||||
|
||||
|
||||
async def test_stream_transform_hook_can_drop_updates_and_disable_originating_reply() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
context = _FakeContext(include_originating=False)
|
||||
|
||||
async def hook(update: AgentResponseUpdate) -> AgentResponseUpdate | None:
|
||||
if update.text == "a":
|
||||
return None
|
||||
return update
|
||||
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
streaming=True,
|
||||
stream_transform_hook=hook,
|
||||
edit_interval=0,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_agent_command(_interaction(), "token")
|
||||
|
||||
assert [payload["content"] for payload in recorder.json_payloads] == ["b", "Sent."]
|
||||
assert context.delivered[0][1].result.text == "b"
|
||||
|
||||
|
||||
async def test_stream_transform_hook_can_synchronously_rewrite_updates() -> None:
|
||||
recorder = _DiscordRecorder()
|
||||
context = _FakeContext()
|
||||
|
||||
def hook(_update: AgentResponseUpdate) -> AgentResponseUpdate:
|
||||
return AgentResponseUpdate(contents=[Content.from_text(text="x")], role="assistant")
|
||||
|
||||
channel = DiscordChannel(
|
||||
application_id="app-1",
|
||||
public_key=SigningKey.generate().verify_key.encode().hex(),
|
||||
register_commands=False,
|
||||
streaming=True,
|
||||
stream_transform_hook=hook,
|
||||
edit_interval=0,
|
||||
api_base_url="https://discord.test",
|
||||
)
|
||||
channel.contribute(context) # type: ignore[arg-type]
|
||||
channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport())
|
||||
|
||||
await channel._run_agent_command(_interaction(), "token")
|
||||
|
||||
assert [payload["content"] for payload in recorder.json_payloads] == ["x", "xx", "xx"]
|
||||
|
||||
|
||||
async def _noop() -> None:
|
||||
return None
|
||||
@@ -71,6 +71,7 @@ from ._types import (
|
||||
RetryPolicy,
|
||||
TaskHandle,
|
||||
TaskStatus,
|
||||
apply_channel_response_hook,
|
||||
apply_response_hook,
|
||||
apply_run_hook,
|
||||
)
|
||||
@@ -134,6 +135,7 @@ __all__ = [
|
||||
"TaskHandle",
|
||||
"TaskStatus",
|
||||
"__version__",
|
||||
"apply_channel_response_hook",
|
||||
"apply_response_hook",
|
||||
"apply_run_hook",
|
||||
"get_current_isolation_keys",
|
||||
|
||||
@@ -81,15 +81,13 @@ from ._types import (
|
||||
ChannelPush,
|
||||
ChannelPushCodec,
|
||||
ChannelRequest,
|
||||
ChannelResponseContext,
|
||||
ChannelResponseHook,
|
||||
DurableTaskPayloadMode,
|
||||
DurableTaskRunner,
|
||||
HostedRunResult,
|
||||
HostStatePaths,
|
||||
PushPayloadNotSerializable,
|
||||
ResponseTargetKind,
|
||||
apply_response_hook,
|
||||
apply_channel_response_hook,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -1896,17 +1894,15 @@ class AgentFrameworkHost:
|
||||
contract; richer surfaces stay attribute-level so adding hook
|
||||
support to a new channel does not require updating the Protocol.
|
||||
"""
|
||||
shaped: HostedRunResult[Any] = payload.replace()
|
||||
hook = cast(ChannelResponseHook | None, getattr(channel, "response_hook", None))
|
||||
if callable(hook):
|
||||
ctx = ChannelResponseContext(
|
||||
request=request,
|
||||
channel_name=channel.name,
|
||||
destination_identity=identity,
|
||||
originating=False,
|
||||
is_echo=is_echo,
|
||||
)
|
||||
shaped = await apply_response_hook(hook, shaped, context=ctx)
|
||||
shaped = await apply_channel_response_hook(
|
||||
channel,
|
||||
payload,
|
||||
request=request,
|
||||
destination_identity=identity,
|
||||
originating=False,
|
||||
is_echo=is_echo,
|
||||
clone=True,
|
||||
)
|
||||
await channel.push(identity, shaped)
|
||||
return shaped
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ import os
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypedDict, TypeVar, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypedDict, TypeVar, cast, runtime_checkable
|
||||
|
||||
from agent_framework import (
|
||||
AgentResponse,
|
||||
@@ -760,6 +760,52 @@ class ChannelPush(Protocol):
|
||||
async def push(self, identity: ChannelIdentity, payload: HostedRunResult[Any]) -> None: ...
|
||||
|
||||
|
||||
async def apply_channel_response_hook(
|
||||
channel: Channel | ChannelPush,
|
||||
result: HostedRunResult[Any],
|
||||
*,
|
||||
request: ChannelRequest,
|
||||
originating: bool,
|
||||
destination_identity: ChannelIdentity | None = None,
|
||||
is_echo: bool = False,
|
||||
clone: bool = False,
|
||||
) -> HostedRunResult[Any]:
|
||||
"""Apply a channel's optional response hook with the standard context.
|
||||
|
||||
Channels and the host call this helper when they need to shape a
|
||||
:class:`HostedRunResult` for one destination. The helper centralizes the
|
||||
response-hook convention: hooks are discovered from a duck-typed
|
||||
``response_hook`` attribute, called through :func:`apply_response_hook`,
|
||||
and receive a :class:`ChannelResponseContext` that identifies the channel,
|
||||
destination identity, originating-vs-push phase, and echo phase.
|
||||
|
||||
Args:
|
||||
channel: Channel whose ``response_hook`` attribute may shape the payload.
|
||||
result: Hosted run result to pass to the hook.
|
||||
request: Originating channel request.
|
||||
originating: Whether this is the originating channel's synchronous reply.
|
||||
destination_identity: Destination identity for non-originating pushes, or
|
||||
``None`` for originating replies.
|
||||
is_echo: Whether the payload is an echo of the user input.
|
||||
clone: Whether to shallow-clone ``result`` before applying the hook.
|
||||
|
||||
Returns:
|
||||
The original, cloned, or hook-shaped hosted run result.
|
||||
"""
|
||||
shaped = result.replace() if clone else result
|
||||
hook = cast(ChannelResponseHook | None, getattr(channel, "response_hook", None))
|
||||
if not callable(hook):
|
||||
return shaped
|
||||
context = ChannelResponseContext(
|
||||
request=request,
|
||||
channel_name=channel.name,
|
||||
destination_identity=destination_identity,
|
||||
originating=originating,
|
||||
is_echo=is_echo,
|
||||
)
|
||||
return await apply_response_hook(hook, shaped, context=context)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Durable task runner — pluggable seam for non-originating push fan-out and
|
||||
# (in v1 fast-follow) background runs. See spec §"Durable task runner".
|
||||
@@ -910,6 +956,7 @@ __all__ = [
|
||||
"RetryPolicy",
|
||||
"TaskHandle",
|
||||
"TaskStatus",
|
||||
"apply_channel_response_hook",
|
||||
"apply_response_hook",
|
||||
"apply_run_hook",
|
||||
]
|
||||
|
||||
@@ -7,12 +7,16 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from agent_framework_hosting import (
|
||||
ChannelContribution,
|
||||
ChannelIdentity,
|
||||
ChannelRequest,
|
||||
ChannelResponseContext,
|
||||
ChannelSession,
|
||||
DurableTaskPayloadMode,
|
||||
HostedRunResult,
|
||||
ResponseTarget,
|
||||
ResponseTargetKind,
|
||||
apply_channel_response_hook,
|
||||
apply_run_hook,
|
||||
)
|
||||
|
||||
@@ -117,6 +121,88 @@ class _DummyTarget:
|
||||
"""
|
||||
|
||||
|
||||
class _DummyChannel:
|
||||
name = "dummy"
|
||||
path = "/dummy"
|
||||
|
||||
def contribute(self, _context: Any) -> ChannelContribution:
|
||||
return ChannelContribution()
|
||||
|
||||
|
||||
class TestApplyChannelResponseHook:
|
||||
async def test_originating_hook_receives_standard_context(self) -> None:
|
||||
request = ChannelRequest(channel="discord", operation="message.create", input="hi")
|
||||
payload = HostedRunResult("original")
|
||||
captured: list[ChannelResponseContext] = []
|
||||
|
||||
async def hook(
|
||||
result: HostedRunResult[Any],
|
||||
*,
|
||||
context: ChannelResponseContext,
|
||||
) -> HostedRunResult[Any]:
|
||||
captured.append(context)
|
||||
return result.replace(result="hooked")
|
||||
|
||||
channel = _DummyChannel()
|
||||
channel.response_hook = hook # type: ignore[attr-defined]
|
||||
|
||||
shaped = await apply_channel_response_hook(channel, payload, request=request, originating=True)
|
||||
|
||||
assert shaped.result == "hooked"
|
||||
assert captured[0].request is request
|
||||
assert captured[0].channel_name == "dummy"
|
||||
assert captured[0].destination_identity is None
|
||||
assert captured[0].originating is True
|
||||
assert captured[0].is_echo is False
|
||||
|
||||
async def test_non_originating_hook_can_clone_before_shaping(self) -> None:
|
||||
request = ChannelRequest(channel="responses", operation="message.create", input="hi")
|
||||
identity = ChannelIdentity(channel="dummy", native_id="user-1")
|
||||
payload = HostedRunResult("original")
|
||||
seen_payloads: list[HostedRunResult[Any]] = []
|
||||
seen_contexts: list[ChannelResponseContext] = []
|
||||
|
||||
def hook(
|
||||
result: HostedRunResult[Any],
|
||||
*,
|
||||
context: ChannelResponseContext,
|
||||
) -> HostedRunResult[Any]:
|
||||
seen_payloads.append(result)
|
||||
seen_contexts.append(context)
|
||||
return result.replace(result="hooked")
|
||||
|
||||
channel = _DummyChannel()
|
||||
channel.response_hook = hook # type: ignore[attr-defined]
|
||||
|
||||
shaped = await apply_channel_response_hook(
|
||||
channel,
|
||||
payload,
|
||||
request=request,
|
||||
destination_identity=identity,
|
||||
originating=False,
|
||||
is_echo=True,
|
||||
clone=True,
|
||||
)
|
||||
|
||||
assert seen_payloads[0] is not payload
|
||||
assert shaped.result == "hooked"
|
||||
assert seen_contexts[0].destination_identity is identity
|
||||
assert seen_contexts[0].originating is False
|
||||
assert seen_contexts[0].is_echo is True
|
||||
|
||||
async def test_missing_hook_returns_payload_or_clone(self) -> None:
|
||||
request = ChannelRequest(channel="responses", operation="message.create", input="hi")
|
||||
payload = HostedRunResult("original")
|
||||
channel = _DummyChannel()
|
||||
|
||||
same = await apply_channel_response_hook(channel, payload, request=request, originating=True)
|
||||
cloned = await apply_channel_response_hook(channel, payload, request=request, originating=True, clone=True)
|
||||
|
||||
assert same is payload
|
||||
assert cloned is not payload
|
||||
assert cloned.result == payload.result
|
||||
|
||||
|
||||
class TestApplyRunHook:
|
||||
"""`apply_run_hook` is the channel-side helper that invokes a
|
||||
`ChannelRunHook` with the standard kwargs (`request` positional,
|
||||
|
||||
@@ -90,6 +90,7 @@ agent-framework-hosting-invocations = { workspace = true }
|
||||
agent-framework-hosting-telegram = { workspace = true }
|
||||
agent-framework-hosting-activity-protocol = { workspace = true }
|
||||
agent-framework-hosting-entra = { workspace = true }
|
||||
agent-framework-hosting-discord = { workspace = true }
|
||||
agent-framework-hyperlight = { workspace = true }
|
||||
agent-framework-lab = { workspace = true }
|
||||
agent-framework-mem0 = { workspace = true }
|
||||
|
||||
Generated
+76
-31
@@ -48,11 +48,12 @@ members = [
|
||||
"agent-framework-gemini",
|
||||
"agent-framework-github-copilot",
|
||||
"agent-framework-hosting",
|
||||
"agent-framework-hosting-responses",
|
||||
"agent-framework-hosting-invocations",
|
||||
"agent-framework-hosting-telegram",
|
||||
"agent-framework-hosting-activity-protocol",
|
||||
"agent-framework-hosting-discord",
|
||||
"agent-framework-hosting-entra",
|
||||
"agent-framework-hosting-invocations",
|
||||
"agent-framework-hosting-responses",
|
||||
"agent-framework-hosting-telegram",
|
||||
"agent-framework-hyperlight",
|
||||
"agent-framework-lab",
|
||||
"agent-framework-mem0",
|
||||
@@ -648,13 +649,41 @@ provides-extras = ["serve", "disk"]
|
||||
dev = [{ name = "httpx", specifier = ">=0.28.1" }]
|
||||
|
||||
[[package]]
|
||||
name = "agent-framework-hosting-telegram"
|
||||
name = "agent-framework-hosting-activity-protocol"
|
||||
version = "1.0.0a260424"
|
||||
source = { editable = "packages/hosting-telegram" }
|
||||
source = { editable = "packages/hosting-activity-protocol" }
|
||||
dependencies = [
|
||||
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "azure-identity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "agent-framework-core", editable = "packages/core" },
|
||||
{ name = "agent-framework-hosting", editable = "packages/hosting" },
|
||||
{ name = "azure-identity", specifier = ">=1.20,<2" },
|
||||
{ name = "httpx", specifier = ">=0.27,<1" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent-framework-hosting-discord"
|
||||
version = "1.0.0a260526"
|
||||
source = { editable = "packages/hosting-discord" }
|
||||
dependencies = [
|
||||
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "pynacl", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "agent-framework-core", editable = "packages/core" },
|
||||
{ name = "agent-framework-hosting", editable = "packages/hosting" },
|
||||
{ name = "httpx", specifier = ">=0.27,<1" },
|
||||
{ name = "pynacl", specifier = ">=1.2.0,<2" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -669,32 +698,13 @@ dependencies = [
|
||||
{ name = "msal", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent-framework-hosting-responses"
|
||||
version = "1.0.0a260424"
|
||||
source = { editable = "packages/hosting-responses" }
|
||||
dependencies = [
|
||||
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "agent-framework-core", editable = "packages/core" },
|
||||
{ name = "agent-framework-hosting", editable = "packages/hosting" },
|
||||
{ name = "openai", specifier = ">=1.99.0,<3" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent-framework-hosting-activity-protocol"
|
||||
version = "1.0.0a260424"
|
||||
source = { editable = "packages/hosting-activity-protocol" }
|
||||
dependencies = [
|
||||
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "azure-identity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "cryptography", specifier = ">=42" },
|
||||
{ name = "httpx", specifier = ">=0.27,<1" },
|
||||
{ name = "msal", specifier = ">=1.28,<2" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -704,11 +714,46 @@ source = { editable = "packages/hosting-invocations" }
|
||||
dependencies = [
|
||||
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "azure-identity", specifier = ">=1.20,<2" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "agent-framework-core", editable = "packages/core" },
|
||||
{ name = "agent-framework-hosting", editable = "packages/hosting" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent-framework-hosting-responses"
|
||||
version = "1.0.0a260424"
|
||||
source = { editable = "packages/hosting-responses" }
|
||||
dependencies = [
|
||||
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "agent-framework-core", editable = "packages/core" },
|
||||
{ name = "agent-framework-hosting", editable = "packages/hosting" },
|
||||
{ name = "openai", specifier = ">=1.99.0,<3" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent-framework-hosting-telegram"
|
||||
version = "1.0.0a260424"
|
||||
source = { editable = "packages/hosting-telegram" }
|
||||
dependencies = [
|
||||
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "agent-framework-core", editable = "packages/core" },
|
||||
{ name = "agent-framework-hosting", editable = "packages/hosting" },
|
||||
{ name = "httpx", specifier = ">=0.27,<1" },
|
||||
{ name = "cryptography", specifier = ">=42" },
|
||||
{ name = "httpx", specifier = ">=0.27,<1" },
|
||||
{ name = "msal", specifier = ">=1.28,<2" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user