mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
app-server: move transport into dedicated crate (#20545)
## Why `codex-app-server` currently owns both request-processing code and transport implementation details. Splitting the transport layer into its own crate makes that boundary explicit, reduces the amount of transport-specific dependency surface carried by `codex-app-server`, and gives future transport work a narrower place to evolve. ## What changed - Added `codex-app-server-transport` and moved the existing transport tree into it, including stdio, unix socket, websocket, remote-control transport, and websocket auth. - Moved shared transport-facing message types into the new crate so both the transport implementation and `codex-app-server` use the same definitions. - Kept processor-facing connection state and outbound routing in `codex-app-server`, with the routing tests moved next to that local wrapper. - Updated workspace metadata, Bazel crate metadata, and `codex-app-server` dependencies for the new crate boundary. ## Validation - `cargo metadata --locked --no-deps` - `git diff --check` - Attempted `cargo test -p codex-app-server-transport`, `cargo test -p codex-app-server`, `just fix -p codex-app-server-transport`, and `just fix -p codex-app-server`; all were blocked before compilation by the existing `packageproxy` resolution failure for locked `rustls-webpki = 0.103.13`. - Attempted Bazel build / lockfile validation; those were blocked by external fetch failures against BuildBuddy / GitHub while resolving `v8`.
This commit is contained in:
committed by
GitHub
Unverified
parent
5744b85b9a
commit
41e171fcf2
Generated
+40
-7
@@ -1857,8 +1857,8 @@ dependencies = [
|
||||
"chrono",
|
||||
"clap",
|
||||
"codex-analytics",
|
||||
"codex-api",
|
||||
"codex-app-server-protocol",
|
||||
"codex-app-server-transport",
|
||||
"codex-arg0",
|
||||
"codex-backend-client",
|
||||
"codex-chatgpt",
|
||||
@@ -1891,23 +1891,17 @@ dependencies = [
|
||||
"codex-state",
|
||||
"codex-thread-store",
|
||||
"codex-tools",
|
||||
"codex-uds",
|
||||
"codex-utils-absolute-path",
|
||||
"codex-utils-cargo-bin",
|
||||
"codex-utils-cli",
|
||||
"codex-utils-json-to-toml",
|
||||
"codex-utils-pty",
|
||||
"codex-utils-rustls-provider",
|
||||
"constant_time_eq 0.3.1",
|
||||
"core_test_support",
|
||||
"flate2",
|
||||
"futures",
|
||||
"gethostname",
|
||||
"hmac",
|
||||
"jsonwebtoken",
|
||||
"opentelemetry",
|
||||
"opentelemetry_sdk",
|
||||
"owo-colors",
|
||||
"pretty_assertions",
|
||||
"reqwest",
|
||||
"rmcp",
|
||||
@@ -2005,6 +1999,45 @@ dependencies = [
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-app-server-transport"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"clap",
|
||||
"codex-api",
|
||||
"codex-app-server-protocol",
|
||||
"codex-config",
|
||||
"codex-core",
|
||||
"codex-login",
|
||||
"codex-model-provider",
|
||||
"codex-state",
|
||||
"codex-uds",
|
||||
"codex-utils-absolute-path",
|
||||
"codex-utils-rustls-provider",
|
||||
"constant_time_eq 0.3.1",
|
||||
"futures",
|
||||
"gethostname",
|
||||
"hmac",
|
||||
"jsonwebtoken",
|
||||
"owo-colors",
|
||||
"pretty_assertions",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"tempfile",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"url",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-apply-patch"
|
||||
version = "0.0.0"
|
||||
|
||||
@@ -8,6 +8,7 @@ members = [
|
||||
"ansi-escape",
|
||||
"async-utils",
|
||||
"app-server",
|
||||
"app-server-transport",
|
||||
"app-server-client",
|
||||
"app-server-protocol",
|
||||
"app-server-test-client",
|
||||
@@ -127,6 +128,7 @@ codex-ansi-escape = { path = "ansi-escape" }
|
||||
codex-api = { path = "codex-api" }
|
||||
codex-aws-auth = { path = "aws-auth" }
|
||||
codex-app-server = { path = "app-server" }
|
||||
codex-app-server-transport = { path = "app-server-transport" }
|
||||
codex-app-server-client = { path = "app-server-client" }
|
||||
codex-app-server-protocol = { path = "app-server-protocol" }
|
||||
codex-app-server-test-client = { path = "app-server-test-client" }
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
load("//:defs.bzl", "codex_rust_crate")
|
||||
|
||||
codex_rust_crate(
|
||||
name = "app-server-transport",
|
||||
crate_name = "codex_app_server_transport",
|
||||
)
|
||||
@@ -0,0 +1,58 @@
|
||||
[package]
|
||||
name = "codex-app-server-transport"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "codex_app_server_transport"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true, default-features = false, features = [
|
||||
"http1",
|
||||
"json",
|
||||
"tokio",
|
||||
"ws",
|
||||
] }
|
||||
base64 = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
codex-api = { workspace = true }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-core = { workspace = true }
|
||||
codex-login = { workspace = true }
|
||||
codex-model-provider = { workspace = true }
|
||||
codex-state = { workspace = true }
|
||||
codex-uds = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
codex-utils-rustls-provider = { workspace = true }
|
||||
constant_time_eq = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
gethostname = { workspace = true }
|
||||
hmac = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
owo-colors = { workspace = true, features = ["supports-colors"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
time = { workspace = true }
|
||||
tokio = { workspace = true, features = [
|
||||
"io-std",
|
||||
"macros",
|
||||
"rt-multi-thread",
|
||||
] }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
url = { workspace = true }
|
||||
uuid = { workspace = true, features = ["serde", "v7"] }
|
||||
|
||||
[dev-dependencies]
|
||||
chrono = { workspace = true }
|
||||
codex-config = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
@@ -0,0 +1,20 @@
|
||||
mod outgoing_message;
|
||||
mod transport;
|
||||
|
||||
pub use outgoing_message::ConnectionId;
|
||||
pub use outgoing_message::OutgoingError;
|
||||
pub use outgoing_message::OutgoingMessage;
|
||||
pub use outgoing_message::OutgoingResponse;
|
||||
pub use outgoing_message::QueuedOutgoingMessage;
|
||||
pub use transport::AppServerTransport;
|
||||
pub use transport::AppServerTransportParseError;
|
||||
pub use transport::CHANNEL_CAPACITY;
|
||||
pub use transport::ConnectionOrigin;
|
||||
pub use transport::RemoteControlHandle;
|
||||
pub use transport::TransportEvent;
|
||||
pub use transport::app_server_control_socket_path;
|
||||
pub use transport::auth;
|
||||
pub use transport::start_control_socket_acceptor;
|
||||
pub use transport::start_remote_control;
|
||||
pub use transport::start_stdio_connection;
|
||||
pub use transport::start_websocket_acceptor;
|
||||
@@ -0,0 +1,58 @@
|
||||
use std::fmt;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::Result;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_app_server_protocol::ServerRequest;
|
||||
use serde::Serialize;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
/// Stable identifier for a transport connection.
|
||||
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
|
||||
pub struct ConnectionId(pub u64);
|
||||
|
||||
impl fmt::Display for ConnectionId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Outgoing message from the server to the client.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum OutgoingMessage {
|
||||
Request(ServerRequest),
|
||||
/// AppServerNotification is specific to the case where this is run as an
|
||||
/// "app server" as opposed to an MCP server.
|
||||
AppServerNotification(ServerNotification),
|
||||
Response(OutgoingResponse),
|
||||
Error(OutgoingError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub struct OutgoingResponse {
|
||||
pub id: RequestId,
|
||||
pub result: Result,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub struct OutgoingError {
|
||||
pub error: JSONRPCErrorError,
|
||||
pub id: RequestId,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct QueuedOutgoingMessage {
|
||||
pub message: OutgoingMessage,
|
||||
pub write_complete_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
impl QueuedOutgoingMessage {
|
||||
pub fn new(message: OutgoingMessage) -> Self {
|
||||
Self {
|
||||
message,
|
||||
write_complete_tx: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
+2
-2
@@ -86,7 +86,7 @@ pub enum AppServerWebsocketCapabilityTokenSource {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub(crate) struct WebsocketAuthPolicy {
|
||||
pub struct WebsocketAuthPolicy {
|
||||
pub(crate) mode: Option<WebsocketAuthMode>,
|
||||
}
|
||||
|
||||
@@ -219,7 +219,7 @@ impl AppServerWebsocketAuthArgs {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn policy_from_settings(
|
||||
pub fn policy_from_settings(
|
||||
settings: &AppServerWebsocketAuthSettings,
|
||||
) -> io::Result<WebsocketAuthPolicy> {
|
||||
let mode = match settings.config.as_ref() {
|
||||
@@ -0,0 +1,478 @@
|
||||
pub mod auth;
|
||||
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingError;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_core::config::find_codex_home;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::error;
|
||||
use tracing::warn;
|
||||
|
||||
/// Size of the bounded channels used to communicate between tasks. The value
|
||||
/// is a balance between throughput and memory usage - 128 messages should be
|
||||
/// plenty for an interactive CLI.
|
||||
pub const CHANNEL_CAPACITY: usize = 128;
|
||||
|
||||
mod remote_control;
|
||||
mod stdio;
|
||||
mod unix_socket;
|
||||
#[cfg(test)]
|
||||
mod unix_socket_tests;
|
||||
mod websocket;
|
||||
|
||||
pub use remote_control::RemoteControlHandle;
|
||||
pub use remote_control::start_remote_control;
|
||||
pub use stdio::start_stdio_connection;
|
||||
pub use unix_socket::start_control_socket_acceptor;
|
||||
pub use websocket::start_websocket_acceptor;
|
||||
|
||||
const OVERLOADED_ERROR_CODE: i64 = -32001;
|
||||
|
||||
const APP_SERVER_CONTROL_SOCKET_DIR_NAME: &str = "app-server-control";
|
||||
const APP_SERVER_CONTROL_SOCKET_FILE_NAME: &str = "app-server-control.sock";
|
||||
|
||||
pub fn app_server_control_socket_path(codex_home: &Path) -> std::io::Result<AbsolutePathBuf> {
|
||||
AbsolutePathBuf::from_absolute_path(
|
||||
codex_home
|
||||
.join(APP_SERVER_CONTROL_SOCKET_DIR_NAME)
|
||||
.join(APP_SERVER_CONTROL_SOCKET_FILE_NAME),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub enum AppServerTransport {
|
||||
Stdio,
|
||||
UnixSocket { socket_path: AbsolutePathBuf },
|
||||
WebSocket { bind_address: SocketAddr },
|
||||
Off,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub enum AppServerTransportParseError {
|
||||
UnsupportedListenUrl(String),
|
||||
InvalidUnixSocketPath { listen_url: String, message: String },
|
||||
InvalidWebSocketListenUrl(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AppServerTransportParseError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!(
|
||||
f,
|
||||
"unsupported --listen URL `{listen_url}`; expected `stdio://`, `unix://`, `unix://PATH`, `ws://IP:PORT`, or `off`"
|
||||
),
|
||||
AppServerTransportParseError::InvalidUnixSocketPath {
|
||||
listen_url,
|
||||
message,
|
||||
} => write!(
|
||||
f,
|
||||
"invalid unix socket --listen URL `{listen_url}`; failed to resolve socket path: {message}"
|
||||
),
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!(
|
||||
f,
|
||||
"invalid websocket --listen URL `{listen_url}`; expected `ws://IP:PORT`"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for AppServerTransportParseError {}
|
||||
|
||||
impl AppServerTransport {
|
||||
pub const DEFAULT_LISTEN_URL: &'static str = "stdio://";
|
||||
|
||||
pub fn from_listen_url(listen_url: &str) -> Result<Self, AppServerTransportParseError> {
|
||||
if listen_url == Self::DEFAULT_LISTEN_URL {
|
||||
return Ok(Self::Stdio);
|
||||
}
|
||||
|
||||
if let Some(raw_socket_path) = listen_url.strip_prefix("unix://") {
|
||||
let socket_path = if raw_socket_path.is_empty() {
|
||||
let codex_home = find_codex_home().map_err(|err| {
|
||||
AppServerTransportParseError::InvalidUnixSocketPath {
|
||||
listen_url: listen_url.to_string(),
|
||||
message: format!("failed to resolve CODEX_HOME: {err}"),
|
||||
}
|
||||
})?;
|
||||
app_server_control_socket_path(&codex_home).map_err(|err| {
|
||||
AppServerTransportParseError::InvalidUnixSocketPath {
|
||||
listen_url: listen_url.to_string(),
|
||||
message: err.to_string(),
|
||||
}
|
||||
})?
|
||||
} else {
|
||||
AbsolutePathBuf::relative_to_current_dir(raw_socket_path).map_err(|err| {
|
||||
AppServerTransportParseError::InvalidUnixSocketPath {
|
||||
listen_url: listen_url.to_string(),
|
||||
message: err.to_string(),
|
||||
}
|
||||
})?
|
||||
};
|
||||
return Ok(Self::UnixSocket { socket_path });
|
||||
}
|
||||
|
||||
if listen_url == "off" {
|
||||
return Ok(Self::Off);
|
||||
}
|
||||
|
||||
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
|
||||
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
|
||||
})?;
|
||||
return Ok(Self::WebSocket { bind_address });
|
||||
}
|
||||
|
||||
Err(AppServerTransportParseError::UnsupportedListenUrl(
|
||||
listen_url.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for AppServerTransport {
|
||||
type Err = AppServerTransportParseError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Self::from_listen_url(s)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TransportEvent {
|
||||
ConnectionOpened {
|
||||
connection_id: ConnectionId,
|
||||
origin: ConnectionOrigin,
|
||||
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
},
|
||||
ConnectionClosed {
|
||||
connection_id: ConnectionId,
|
||||
},
|
||||
IncomingMessage {
|
||||
connection_id: ConnectionId,
|
||||
message: JSONRPCMessage,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ConnectionOrigin {
|
||||
Stdio,
|
||||
InProcess,
|
||||
WebSocket,
|
||||
RemoteControl,
|
||||
}
|
||||
|
||||
impl ConnectionOrigin {
|
||||
pub fn allows_device_key_requests(self) -> bool {
|
||||
// Device-key endpoints are only for local connections that own the app-server instance.
|
||||
// Do not include remote transports such as SSH or remote-control websocket connections.
|
||||
matches!(self, Self::Stdio | Self::InProcess)
|
||||
}
|
||||
}
|
||||
|
||||
static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
fn next_connection_id() -> ConnectionId {
|
||||
ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed))
|
||||
}
|
||||
|
||||
async fn forward_incoming_message(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
writer: &mpsc::Sender<QueuedOutgoingMessage>,
|
||||
connection_id: ConnectionId,
|
||||
payload: &str,
|
||||
) -> bool {
|
||||
match serde_json::from_str::<JSONRPCMessage>(payload) {
|
||||
Ok(message) => {
|
||||
enqueue_incoming_message(transport_event_tx, writer, connection_id, message).await
|
||||
}
|
||||
Err(err) => {
|
||||
error!("Failed to deserialize JSONRPCMessage: {err}");
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn enqueue_incoming_message(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
writer: &mpsc::Sender<QueuedOutgoingMessage>,
|
||||
connection_id: ConnectionId,
|
||||
message: JSONRPCMessage,
|
||||
) -> bool {
|
||||
let event = TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
};
|
||||
match transport_event_tx.try_send(event) {
|
||||
Ok(()) => true,
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => false,
|
||||
Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message: JSONRPCMessage::Request(request),
|
||||
})) => {
|
||||
let overload_error = OutgoingMessage::Error(OutgoingError {
|
||||
id: request.id,
|
||||
error: JSONRPCErrorError {
|
||||
code: OVERLOADED_ERROR_CODE,
|
||||
message: "Server overloaded; retry later.".to_string(),
|
||||
data: None,
|
||||
},
|
||||
});
|
||||
match writer.try_send(QueuedOutgoingMessage::new(overload_error)) {
|
||||
Ok(()) => true,
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => false,
|
||||
Err(mpsc::error::TrySendError::Full(_overload_error)) => {
|
||||
warn!(
|
||||
"dropping overload response for connection {:?}: outbound queue is full",
|
||||
connection_id
|
||||
);
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(event)) => transport_event_tx.send(event).await.is_ok(),
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option<String> {
|
||||
let value = match serde_json::to_value(outgoing_message) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
error!("Failed to convert OutgoingMessage to JSON value: {err}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
match serde_json::to_string(&value) {
|
||||
Ok(json) => Some(json),
|
||||
Err(err) => {
|
||||
error!("Failed to serialize JSONRPCMessage: {err}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
#[test]
|
||||
fn listen_off_parses_as_off_transport() {
|
||||
assert_eq!(
|
||||
AppServerTransport::from_listen_url("off"),
|
||||
Ok(AppServerTransport::Off)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_incoming_request_returns_overload_error_when_queue_is_full() {
|
||||
let connection_id = ConnectionId(42);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message: first_message.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("queue should accept first message");
|
||||
|
||||
let request = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(7),
|
||||
method: "config/read".to_string(),
|
||||
params: Some(json!({ "includeLayers": false })),
|
||||
trace: None,
|
||||
});
|
||||
assert!(
|
||||
enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request).await
|
||||
);
|
||||
|
||||
let queued_event = transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("first event should stay queued");
|
||||
match queued_event {
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: queued_connection_id,
|
||||
message,
|
||||
} => {
|
||||
assert_eq!(queued_connection_id, connection_id);
|
||||
assert_eq!(message, first_message);
|
||||
}
|
||||
_ => panic!("expected queued incoming message"),
|
||||
}
|
||||
|
||||
let overload = writer_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("request should receive overload error");
|
||||
let overload_json =
|
||||
serde_json::to_value(overload.message).expect("serialize overload error");
|
||||
assert_eq!(
|
||||
overload_json,
|
||||
json!({
|
||||
"id": 7,
|
||||
"error": {
|
||||
"code": OVERLOADED_ERROR_CODE,
|
||||
"message": "Server overloaded; retry later."
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_incoming_response_waits_instead_of_dropping_when_queue_is_full() {
|
||||
let connection_id = ConnectionId(42);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
|
||||
let (writer_tx, _writer_rx) = mpsc::channel(1);
|
||||
|
||||
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message: first_message.clone(),
|
||||
})
|
||||
.await
|
||||
.expect("queue should accept first message");
|
||||
|
||||
let response = JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: RequestId::Integer(7),
|
||||
result: json!({"ok": true}),
|
||||
});
|
||||
let transport_event_tx_for_enqueue = transport_event_tx.clone();
|
||||
let writer_tx_for_enqueue = writer_tx.clone();
|
||||
let enqueue_handle = tokio::spawn(async move {
|
||||
enqueue_incoming_message(
|
||||
&transport_event_tx_for_enqueue,
|
||||
&writer_tx_for_enqueue,
|
||||
connection_id,
|
||||
response,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
let queued_event = transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("first event should be dequeued");
|
||||
match queued_event {
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: queued_connection_id,
|
||||
message,
|
||||
} => {
|
||||
assert_eq!(queued_connection_id, connection_id);
|
||||
assert_eq!(message, first_message);
|
||||
}
|
||||
_ => panic!("expected queued incoming message"),
|
||||
}
|
||||
|
||||
let enqueue_result = enqueue_handle.await.expect("enqueue task should not panic");
|
||||
assert!(enqueue_result);
|
||||
|
||||
let forwarded_event = transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("response should be forwarded instead of dropped");
|
||||
match forwarded_event {
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: queued_connection_id,
|
||||
message: JSONRPCMessage::Response(JSONRPCResponse { id, result }),
|
||||
} => {
|
||||
assert_eq!(queued_connection_id, connection_id);
|
||||
assert_eq!(id, RequestId::Integer(7));
|
||||
assert_eq!(result, json!({"ok": true}));
|
||||
}
|
||||
_ => panic!("expected forwarded response message"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_incoming_request_does_not_block_when_writer_queue_is_full() {
|
||||
let connection_id = ConnectionId(42);
|
||||
let (transport_event_tx, _transport_event_rx) = mpsc::channel(1);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message: JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
}),
|
||||
})
|
||||
.await
|
||||
.expect("transport queue should accept first message");
|
||||
|
||||
writer_tx
|
||||
.send(QueuedOutgoingMessage::new(
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "queued".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.expect("writer queue should accept first message");
|
||||
|
||||
let request = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(7),
|
||||
method: "config/read".to_string(),
|
||||
params: Some(json!({ "includeLayers": false })),
|
||||
trace: None,
|
||||
});
|
||||
|
||||
let enqueue_result = timeout(
|
||||
Duration::from_millis(100),
|
||||
enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request),
|
||||
)
|
||||
.await
|
||||
.expect("enqueue should not block while writer queue is full");
|
||||
assert!(enqueue_result);
|
||||
|
||||
let queued_outgoing = writer_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("writer queue should still contain original message");
|
||||
let queued_json =
|
||||
serde_json::to_value(queued_outgoing.message).expect("serialize queued message");
|
||||
assert_eq!(
|
||||
queued_json,
|
||||
json!({
|
||||
"method": "configWarning",
|
||||
"params": {
|
||||
"summary": "queued",
|
||||
"details": null,
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
+4
-6
@@ -36,14 +36,14 @@ pub(super) struct QueuedServerEnvelope {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RemoteControlHandle {
|
||||
pub struct RemoteControlHandle {
|
||||
enabled_tx: Arc<watch::Sender<bool>>,
|
||||
status_tx: Arc<watch::Sender<RemoteControlStatusChangedNotification>>,
|
||||
state_db_available: bool,
|
||||
}
|
||||
|
||||
impl RemoteControlHandle {
|
||||
pub(crate) fn set_enabled(&self, enabled: bool) {
|
||||
pub fn set_enabled(&self, enabled: bool) {
|
||||
let requested_enabled = enabled;
|
||||
let enabled = enabled && self.state_db_available;
|
||||
if requested_enabled && !self.state_db_available {
|
||||
@@ -56,14 +56,12 @@ impl RemoteControlHandle {
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn status_receiver(
|
||||
&self,
|
||||
) -> watch::Receiver<RemoteControlStatusChangedNotification> {
|
||||
pub fn status_receiver(&self) -> watch::Receiver<RemoteControlStatusChangedNotification> {
|
||||
self.status_tx.subscribe()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn start_remote_control(
|
||||
pub async fn start_remote_control(
|
||||
remote_control_url: String,
|
||||
state_db: Option<Arc<StateRuntime>>,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
+1
-1
@@ -21,7 +21,7 @@ use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
|
||||
pub(crate) async fn start_stdio_connection(
|
||||
pub async fn start_stdio_connection(
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
stdio_handles: &mut Vec<JoinHandle<()>>,
|
||||
initialize_client_name_tx: oneshot::Sender<String>,
|
||||
+1
-1
@@ -20,7 +20,7 @@ use tracing::warn;
|
||||
#[cfg(unix)]
|
||||
const CONTROL_SOCKET_MODE: u32 = 0o600;
|
||||
|
||||
pub(crate) async fn start_control_socket_acceptor(
|
||||
pub async fn start_control_socket_acceptor(
|
||||
socket_path: AbsolutePathBuf,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
+1
-1
@@ -128,7 +128,7 @@ async fn websocket_upgrade_handler(
|
||||
.into_response()
|
||||
}
|
||||
|
||||
pub(crate) async fn start_websocket_acceptor(
|
||||
pub async fn start_websocket_acceptor(
|
||||
bind_address: SocketAddr,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
@@ -30,7 +30,6 @@ axum = { workspace = true, default-features = false, features = [
|
||||
"ws",
|
||||
] }
|
||||
codex-analytics = { workspace = true }
|
||||
codex-api = { workspace = true }
|
||||
codex-arg0 = { workspace = true }
|
||||
codex-cloud-requirements = { workspace = true }
|
||||
codex-config = { workspace = true }
|
||||
@@ -58,6 +57,7 @@ codex-model-provider = { workspace = true }
|
||||
codex-models-manager = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-app-server-transport = { workspace = true }
|
||||
codex-feedback = { workspace = true }
|
||||
codex-rmcp-client = { workspace = true }
|
||||
codex-rollout = { workspace = true }
|
||||
@@ -65,18 +65,11 @@ codex-sandboxing = { workspace = true }
|
||||
codex-state = { workspace = true }
|
||||
codex-thread-store = { workspace = true }
|
||||
codex-tools = { workspace = true }
|
||||
codex-uds = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
codex-utils-json-to-toml = { workspace = true }
|
||||
codex-utils-rustls-provider = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
constant_time_eq = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
gethostname = { workspace = true }
|
||||
hmac = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
owo-colors = { workspace = true, features = ["supports-colors"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
@@ -93,7 +86,6 @@ tokio = { workspace = true, features = [
|
||||
"signal",
|
||||
] }
|
||||
tokio-util = { workspace = true }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] }
|
||||
url = { workspace = true }
|
||||
@@ -111,6 +103,7 @@ core_test_support = { workspace = true }
|
||||
codex-model-provider-info = { workspace = true }
|
||||
codex-utils-cargo-bin = { workspace = true }
|
||||
flate2 = { workspace = true }
|
||||
hmac = { workspace = true }
|
||||
opentelemetry = { workspace = true }
|
||||
opentelemetry_sdk = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
@@ -15,7 +14,6 @@ use codex_app_server_protocol::ServerRequestPayload;
|
||||
use codex_otel::span_w3c_trace_context;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::protocol::W3cTraceContext;
|
||||
use serde::Serialize;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
@@ -26,22 +24,17 @@ use tracing::warn;
|
||||
use crate::error_code::INTERNAL_ERROR_CODE;
|
||||
use crate::error_code::internal_error;
|
||||
use crate::server_request_error::TURN_TRANSITION_PENDING_REQUEST_ERROR_REASON;
|
||||
pub(crate) use codex_app_server_transport::ConnectionId;
|
||||
pub(crate) use codex_app_server_transport::OutgoingError;
|
||||
pub(crate) use codex_app_server_transport::OutgoingMessage;
|
||||
pub(crate) use codex_app_server_transport::OutgoingResponse;
|
||||
pub(crate) use codex_app_server_transport::QueuedOutgoingMessage;
|
||||
|
||||
#[cfg(test)]
|
||||
use codex_protocol::account::PlanType;
|
||||
|
||||
pub(crate) type ClientRequestResult = std::result::Result<Result, JSONRPCErrorError>;
|
||||
|
||||
/// Stable identifier for a transport connection.
|
||||
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
|
||||
pub(crate) struct ConnectionId(pub(crate) u64);
|
||||
|
||||
impl fmt::Display for ConnectionId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stable identifier for a client request scoped to a transport connection.
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||
pub(crate) struct ConnectionRequestId {
|
||||
@@ -96,21 +89,6 @@ pub(crate) enum OutgoingEnvelope {
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct QueuedOutgoingMessage {
|
||||
pub(crate) message: OutgoingMessage,
|
||||
pub(crate) write_complete_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
impl QueuedOutgoingMessage {
|
||||
pub(crate) fn new(message: OutgoingMessage) -> Self {
|
||||
Self {
|
||||
message,
|
||||
write_complete_tx: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends messages to the client and manages request callbacks.
|
||||
pub(crate) struct OutgoingMessageSender {
|
||||
next_server_request_id: AtomicI64,
|
||||
@@ -665,30 +643,6 @@ impl OutgoingMessageSender {
|
||||
}
|
||||
}
|
||||
|
||||
/// Outgoing message from the server to the client.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub(crate) enum OutgoingMessage {
|
||||
Request(ServerRequest),
|
||||
/// AppServerNotification is specific to the case where this is run as an
|
||||
/// "app server" as opposed to an MCP server.
|
||||
AppServerNotification(ServerNotification),
|
||||
Response(OutgoingResponse),
|
||||
Error(OutgoingError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingResponse {
|
||||
pub id: RequestId,
|
||||
pub result: Result,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub(crate) struct OutgoingError {
|
||||
pub error: JSONRPCErrorError,
|
||||
pub id: RequestId,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
use crate::message_processor::ConnectionSessionState;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use codex_app_server_protocol::ExperimentalApi;
|
||||
use codex_app_server_protocol::ServerRequest;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
|
||||
pub use codex_app_server_transport::AppServerTransport;
|
||||
pub(crate) use codex_app_server_transport::CHANNEL_CAPACITY;
|
||||
pub(crate) use codex_app_server_transport::ConnectionId;
|
||||
pub(crate) use codex_app_server_transport::ConnectionOrigin;
|
||||
pub(crate) use codex_app_server_transport::OutgoingMessage;
|
||||
pub(crate) use codex_app_server_transport::QueuedOutgoingMessage;
|
||||
pub(crate) use codex_app_server_transport::RemoteControlHandle;
|
||||
pub(crate) use codex_app_server_transport::TransportEvent;
|
||||
pub use codex_app_server_transport::app_server_control_socket_path;
|
||||
pub use codex_app_server_transport::auth;
|
||||
pub(crate) use codex_app_server_transport::start_control_socket_acceptor;
|
||||
pub(crate) use codex_app_server_transport::start_remote_control;
|
||||
pub(crate) use codex_app_server_transport::start_stdio_connection;
|
||||
pub(crate) use codex_app_server_transport::start_websocket_acceptor;
|
||||
|
||||
pub(crate) struct ConnectionState {
|
||||
pub(crate) outbound_initialized: Arc<AtomicBool>,
|
||||
pub(crate) outbound_experimental_api_enabled: Arc<AtomicBool>,
|
||||
pub(crate) outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
pub(crate) session: Arc<ConnectionSessionState>,
|
||||
}
|
||||
|
||||
impl ConnectionState {
|
||||
pub(crate) fn new(
|
||||
origin: ConnectionOrigin,
|
||||
outbound_initialized: Arc<AtomicBool>,
|
||||
outbound_experimental_api_enabled: Arc<AtomicBool>,
|
||||
outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
outbound_initialized,
|
||||
outbound_experimental_api_enabled,
|
||||
outbound_opted_out_notification_methods,
|
||||
session: Arc::new(ConnectionSessionState::new(origin)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct OutboundConnectionState {
|
||||
pub(crate) initialized: Arc<AtomicBool>,
|
||||
pub(crate) experimental_api_enabled: Arc<AtomicBool>,
|
||||
pub(crate) opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
pub(crate) writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
}
|
||||
|
||||
impl OutboundConnectionState {
|
||||
pub(crate) fn new(
|
||||
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
initialized: Arc<AtomicBool>,
|
||||
experimental_api_enabled: Arc<AtomicBool>,
|
||||
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
) -> Self {
|
||||
Self {
|
||||
initialized,
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
writer,
|
||||
disconnect_sender,
|
||||
}
|
||||
}
|
||||
|
||||
fn can_disconnect(&self) -> bool {
|
||||
self.disconnect_sender.is_some()
|
||||
}
|
||||
|
||||
pub(crate) fn request_disconnect(&self) {
|
||||
if let Some(disconnect_sender) = &self.disconnect_sender {
|
||||
disconnect_sender.cancel();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn should_skip_notification_for_connection(
|
||||
connection_state: &OutboundConnectionState,
|
||||
message: &OutgoingMessage,
|
||||
) -> bool {
|
||||
let Ok(opted_out_notification_methods) = connection_state.opted_out_notification_methods.read()
|
||||
else {
|
||||
warn!("failed to read outbound opted-out notifications");
|
||||
return false;
|
||||
};
|
||||
match message {
|
||||
OutgoingMessage::AppServerNotification(notification) => {
|
||||
if notification.experimental_reason().is_some()
|
||||
&& !connection_state
|
||||
.experimental_api_enabled
|
||||
.load(Ordering::Acquire)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
let method = notification.to_string();
|
||||
opted_out_notification_methods.contains(method.as_str())
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn disconnect_connection(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
connection_id: ConnectionId,
|
||||
) -> bool {
|
||||
if let Some(connection_state) = connections.remove(&connection_id) {
|
||||
connection_state.request_disconnect();
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
async fn send_message_to_connection(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
connection_id: ConnectionId,
|
||||
message: OutgoingMessage,
|
||||
write_complete_tx: Option<tokio::sync::oneshot::Sender<()>>,
|
||||
) -> bool {
|
||||
let Some(connection_state) = connections.get(&connection_id) else {
|
||||
warn!("dropping message for disconnected connection: {connection_id:?}");
|
||||
return false;
|
||||
};
|
||||
let message = filter_outgoing_message_for_connection(connection_state, message);
|
||||
if should_skip_notification_for_connection(connection_state, &message) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let writer = connection_state.writer.clone();
|
||||
let queued_message = QueuedOutgoingMessage {
|
||||
message,
|
||||
write_complete_tx,
|
||||
};
|
||||
if connection_state.can_disconnect() {
|
||||
match writer.try_send(queued_message) {
|
||||
Ok(()) => false,
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
warn!(
|
||||
"disconnecting slow connection after outbound queue filled: {connection_id:?}"
|
||||
);
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
disconnect_connection(connections, connection_id)
|
||||
}
|
||||
}
|
||||
} else if writer.send(queued_message).await.is_err() {
|
||||
disconnect_connection(connections, connection_id)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_outgoing_message_for_connection(
|
||||
connection_state: &OutboundConnectionState,
|
||||
message: OutgoingMessage,
|
||||
) -> OutgoingMessage {
|
||||
let experimental_api_enabled = connection_state
|
||||
.experimental_api_enabled
|
||||
.load(Ordering::Acquire);
|
||||
match message {
|
||||
OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||
request_id,
|
||||
mut params,
|
||||
}) => {
|
||||
if !experimental_api_enabled {
|
||||
params.strip_experimental_fields();
|
||||
}
|
||||
OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||
request_id,
|
||||
params,
|
||||
})
|
||||
}
|
||||
_ => message,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn route_outgoing_envelope(
|
||||
connections: &mut HashMap<ConnectionId, OutboundConnectionState>,
|
||||
envelope: OutgoingEnvelope,
|
||||
) {
|
||||
match envelope {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
write_complete_tx,
|
||||
} => {
|
||||
let _ =
|
||||
send_message_to_connection(connections, connection_id, message, write_complete_tx)
|
||||
.await;
|
||||
}
|
||||
OutgoingEnvelope::Broadcast { message } => {
|
||||
let target_connections: Vec<ConnectionId> = connections
|
||||
.iter()
|
||||
.filter_map(|(connection_id, connection_state)| {
|
||||
if connection_state.initialized.load(Ordering::Acquire)
|
||||
&& !should_skip_notification_for_connection(connection_state, &message)
|
||||
{
|
||||
Some(*connection_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
for connection_id in target_connections {
|
||||
let _ = send_message_to_connection(
|
||||
connections,
|
||||
connection_id,
|
||||
message.clone(),
|
||||
/*write_complete_tx*/ None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "transport_tests.rs"]
|
||||
mod tests;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,532 @@
|
||||
use super::*;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_app_server_protocol::ThreadGoal;
|
||||
use codex_app_server_protocol::ThreadGoalStatus;
|
||||
use codex_app_server_protocol::ThreadGoalUpdatedNotification;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
fn absolute_path(path: &str) -> AbsolutePathBuf {
|
||||
AbsolutePathBuf::from_absolute_path(path).expect("absolute path")
|
||||
}
|
||||
|
||||
fn thread_goal_updated_notification() -> ServerNotification {
|
||||
ServerNotification::ThreadGoalUpdated(ThreadGoalUpdatedNotification {
|
||||
thread_id: "thread-1".to_string(),
|
||||
turn_id: None,
|
||||
goal: ThreadGoal {
|
||||
thread_id: "thread-1".to_string(),
|
||||
objective: "ship goal mode".to_string(),
|
||||
status: ThreadGoalStatus::Active,
|
||||
token_budget: None,
|
||||
tokens_used: 0,
|
||||
time_used_seconds: 0,
|
||||
created_at: 1,
|
||||
updated_at: 1,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn to_connection_notification_respects_opt_out_filters() {
|
||||
let connection_id = ConnectionId(7);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
let initialized = Arc::new(AtomicBool::new(true));
|
||||
let opted_out_notification_methods =
|
||||
Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()])));
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
initialized,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
opted_out_notification_methods,
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "task_started".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
writer_rx.try_recv().is_err(),
|
||||
"opted-out notification should be dropped"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn to_connection_notifications_are_dropped_for_opted_out_clients() {
|
||||
let connection_id = ConnectionId(10);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()]))),
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "task_started".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
writer_rx.try_recv().is_err(),
|
||||
"opted-out notifications should not reach clients"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn to_connection_notifications_are_preserved_for_non_opted_out_clients() {
|
||||
let connection_id = ConnectionId(11);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "task_started".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let message = writer_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("notification should reach non-opted-out clients");
|
||||
assert!(matches!(
|
||||
message.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "task_started"
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn experimental_notifications_are_dropped_without_capability() {
|
||||
let connection_id = ConnectionId(12);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(thread_goal_updated_notification()),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
writer_rx.try_recv().is_err(),
|
||||
"experimental notifications should not reach clients without capability"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn experimental_notifications_are_preserved_with_capability() {
|
||||
let connection_id = ConnectionId(13);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(thread_goal_updated_notification()),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let message = writer_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("experimental notification should reach opted-in client");
|
||||
assert!(matches!(
|
||||
message.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ThreadGoalUpdated(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn command_execution_request_approval_strips_additional_permissions_without_capability() {
|
||||
let connection_id = ConnectionId(8);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||
request_id: RequestId::Integer(1),
|
||||
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
|
||||
thread_id: "thr_123".to_string(),
|
||||
turn_id: "turn_123".to_string(),
|
||||
item_id: "call_123".to_string(),
|
||||
approval_id: None,
|
||||
reason: Some("Need extra read access".to_string()),
|
||||
network_approval_context: None,
|
||||
command: Some("cat file".to_string()),
|
||||
cwd: Some(absolute_path("/tmp")),
|
||||
command_actions: None,
|
||||
additional_permissions: Some(
|
||||
codex_app_server_protocol::AdditionalPermissionProfile {
|
||||
network: None,
|
||||
file_system: Some(
|
||||
codex_app_server_protocol::AdditionalFileSystemPermissions {
|
||||
read: Some(vec![absolute_path("/tmp/allowed")]),
|
||||
write: None,
|
||||
glob_scan_max_depth: None,
|
||||
entries: None,
|
||||
},
|
||||
),
|
||||
},
|
||||
),
|
||||
proposed_execpolicy_amendment: None,
|
||||
proposed_network_policy_amendments: None,
|
||||
available_decisions: None,
|
||||
},
|
||||
}),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let message = writer_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("request should be delivered to the connection");
|
||||
let json = serde_json::to_value(message.message).expect("request should serialize");
|
||||
assert_eq!(json["params"].get("additionalPermissions"), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn command_execution_request_approval_keeps_additional_permissions_with_capability() {
|
||||
let connection_id = ConnectionId(9);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||
request_id: RequestId::Integer(1),
|
||||
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
|
||||
thread_id: "thr_123".to_string(),
|
||||
turn_id: "turn_123".to_string(),
|
||||
item_id: "call_123".to_string(),
|
||||
approval_id: None,
|
||||
reason: Some("Need extra read access".to_string()),
|
||||
network_approval_context: None,
|
||||
command: Some("cat file".to_string()),
|
||||
cwd: Some(absolute_path("/tmp")),
|
||||
command_actions: None,
|
||||
additional_permissions: Some(
|
||||
codex_app_server_protocol::AdditionalPermissionProfile {
|
||||
network: None,
|
||||
file_system: Some(
|
||||
codex_app_server_protocol::AdditionalFileSystemPermissions {
|
||||
read: Some(vec![absolute_path("/tmp/allowed")]),
|
||||
write: None,
|
||||
glob_scan_max_depth: None,
|
||||
entries: None,
|
||||
},
|
||||
),
|
||||
},
|
||||
),
|
||||
proposed_execpolicy_amendment: None,
|
||||
proposed_network_policy_amendments: None,
|
||||
available_decisions: None,
|
||||
},
|
||||
}),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let message = writer_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("request should be delivered to the connection");
|
||||
let json = serde_json::to_value(message.message).expect("request should serialize");
|
||||
let allowed_path = absolute_path("/tmp/allowed").to_string_lossy().into_owned();
|
||||
assert_eq!(
|
||||
json["params"]["additionalPermissions"],
|
||||
json!({
|
||||
"network": null,
|
||||
"fileSystem": {
|
||||
"read": [allowed_path],
|
||||
"write": null,
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn broadcast_does_not_block_on_slow_connection() {
|
||||
let fast_connection_id = ConnectionId(1);
|
||||
let slow_connection_id = ConnectionId(2);
|
||||
|
||||
let (fast_writer_tx, mut fast_writer_rx) = mpsc::channel(1);
|
||||
let (slow_writer_tx, mut slow_writer_rx) = mpsc::channel(1);
|
||||
let fast_disconnect_token = CancellationToken::new();
|
||||
let slow_disconnect_token = CancellationToken::new();
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
fast_connection_id,
|
||||
OutboundConnectionState::new(
|
||||
fast_writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
Some(fast_disconnect_token.clone()),
|
||||
),
|
||||
);
|
||||
connections.insert(
|
||||
slow_connection_id,
|
||||
OutboundConnectionState::new(
|
||||
slow_writer_tx.clone(),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
Some(slow_disconnect_token.clone()),
|
||||
),
|
||||
);
|
||||
|
||||
let queued_message = OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "already-buffered".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
));
|
||||
slow_writer_tx
|
||||
.try_send(QueuedOutgoingMessage::new(queued_message))
|
||||
.expect("channel should have room");
|
||||
|
||||
let broadcast_message = OutgoingMessage::AppServerNotification(
|
||||
ServerNotification::ConfigWarning(ConfigWarningNotification {
|
||||
summary: "test".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
}),
|
||||
);
|
||||
timeout(
|
||||
Duration::from_millis(100),
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::Broadcast {
|
||||
message: broadcast_message,
|
||||
},
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("broadcast should return even when one connection is slow");
|
||||
assert!(!connections.contains_key(&slow_connection_id));
|
||||
assert!(slow_disconnect_token.is_cancelled());
|
||||
assert!(!fast_disconnect_token.is_cancelled());
|
||||
let fast_message = fast_writer_rx
|
||||
.try_recv()
|
||||
.expect("fast connection should receive the broadcast notification");
|
||||
assert!(matches!(
|
||||
fast_message.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "test"
|
||||
));
|
||||
|
||||
let slow_message = slow_writer_rx
|
||||
.try_recv()
|
||||
.expect("slow connection should retain its original buffered message");
|
||||
assert!(matches!(
|
||||
slow_message.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "already-buffered"
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn to_connection_stdio_waits_instead_of_disconnecting_when_writer_queue_is_full() {
|
||||
let connection_id = ConnectionId(3);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
writer_tx
|
||||
.send(QueuedOutgoingMessage::new(
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "queued".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.expect("channel should accept the first queued message");
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
let route_task = tokio::spawn(async move {
|
||||
route_outgoing_envelope(
|
||||
&mut connections,
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "second".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
write_complete_tx: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
let first = timeout(Duration::from_millis(100), writer_rx.recv())
|
||||
.await
|
||||
.expect("first queued message should be readable")
|
||||
.expect("first queued message should exist");
|
||||
timeout(Duration::from_millis(100), route_task)
|
||||
.await
|
||||
.expect("routing should finish after the first queued message is drained")
|
||||
.expect("routing task should succeed");
|
||||
|
||||
assert!(matches!(
|
||||
first.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "queued"
|
||||
));
|
||||
let second = writer_rx
|
||||
.try_recv()
|
||||
.expect("second notification should be delivered once the queue has room");
|
||||
assert!(matches!(
|
||||
second.message,
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification { summary, .. }
|
||||
)) if summary == "second"
|
||||
));
|
||||
}
|
||||
Reference in New Issue
Block a user