From 41e171fcf2c4664136ee3bee9e6c3a2aeca4e140 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 1 May 2026 09:23:47 -0700 Subject: [PATCH] app-server: move transport into dedicated crate (#20545) ## Why `codex-app-server` currently owns both request-processing code and transport implementation details. Splitting the transport layer into its own crate makes that boundary explicit, reduces the amount of transport-specific dependency surface carried by `codex-app-server`, and gives future transport work a narrower place to evolve. ## What changed - Added `codex-app-server-transport` and moved the existing transport tree into it, including stdio, unix socket, websocket, remote-control transport, and websocket auth. - Moved shared transport-facing message types into the new crate so both the transport implementation and `codex-app-server` use the same definitions. - Kept processor-facing connection state and outbound routing in `codex-app-server`, with the routing tests moved next to that local wrapper. - Updated workspace metadata, Bazel crate metadata, and `codex-app-server` dependencies for the new crate boundary. ## Validation - `cargo metadata --locked --no-deps` - `git diff --check` - Attempted `cargo test -p codex-app-server-transport`, `cargo test -p codex-app-server`, `just fix -p codex-app-server-transport`, and `just fix -p codex-app-server`; all were blocked before compilation by the existing `packageproxy` resolution failure for locked `rustls-webpki = 0.103.13`. - Attempted Bazel build / lockfile validation; those were blocked by external fetch failures against BuildBuddy / GitHub while resolving `v8`. --- codex-rs/Cargo.lock | 47 +- codex-rs/Cargo.toml | 2 + codex-rs/app-server-transport/BUILD.bazel | 6 + codex-rs/app-server-transport/Cargo.toml | 58 + codex-rs/app-server-transport/src/lib.rs | 20 + .../src/outgoing_message.rs | 58 + .../src/transport/auth.rs | 4 +- .../app-server-transport/src/transport/mod.rs | 478 +++++++ .../remote_control/client_tracker.rs | 0 .../src/transport/remote_control/enroll.rs | 0 .../src/transport/remote_control/mod.rs | 10 +- .../src/transport/remote_control/protocol.rs | 0 .../src/transport/remote_control/segment.rs | 0 .../transport/remote_control/segment_tests.rs | 0 .../src/transport/remote_control/tests.rs | 0 .../src/transport/remote_control/websocket.rs | 0 .../src/transport/stdio.rs | 2 +- .../src/transport/unix_socket.rs | 2 +- .../src/transport/unix_socket_tests.rs | 0 .../src/transport/websocket.rs | 2 +- codex-rs/app-server/Cargo.toml | 11 +- codex-rs/app-server/src/outgoing_message.rs | 56 +- codex-rs/app-server/src/transport.rs | 232 ++++ codex-rs/app-server/src/transport/mod.rs | 1210 ----------------- codex-rs/app-server/src/transport_tests.rs | 532 ++++++++ 25 files changed, 1442 insertions(+), 1288 deletions(-) create mode 100644 codex-rs/app-server-transport/BUILD.bazel create mode 100644 codex-rs/app-server-transport/Cargo.toml create mode 100644 codex-rs/app-server-transport/src/lib.rs create mode 100644 codex-rs/app-server-transport/src/outgoing_message.rs rename codex-rs/{app-server => app-server-transport}/src/transport/auth.rs (99%) create mode 100644 codex-rs/app-server-transport/src/transport/mod.rs rename codex-rs/{app-server => app-server-transport}/src/transport/remote_control/client_tracker.rs (100%) rename codex-rs/{app-server => app-server-transport}/src/transport/remote_control/enroll.rs (100%) rename codex-rs/{app-server => app-server-transport}/src/transport/remote_control/mod.rs (93%) rename codex-rs/{app-server => app-server-transport}/src/transport/remote_control/protocol.rs (100%) rename codex-rs/{app-server => app-server-transport}/src/transport/remote_control/segment.rs (100%) rename codex-rs/{app-server => app-server-transport}/src/transport/remote_control/segment_tests.rs (100%) rename codex-rs/{app-server => app-server-transport}/src/transport/remote_control/tests.rs (100%) rename codex-rs/{app-server => app-server-transport}/src/transport/remote_control/websocket.rs (100%) rename codex-rs/{app-server => app-server-transport}/src/transport/stdio.rs (98%) rename codex-rs/{app-server => app-server-transport}/src/transport/unix_socket.rs (99%) rename codex-rs/{app-server => app-server-transport}/src/transport/unix_socket_tests.rs (100%) rename codex-rs/{app-server => app-server-transport}/src/transport/websocket.rs (99%) create mode 100644 codex-rs/app-server/src/transport.rs delete mode 100644 codex-rs/app-server/src/transport/mod.rs create mode 100644 codex-rs/app-server/src/transport_tests.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 056bae406..2c18d3e57 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1857,8 +1857,8 @@ dependencies = [ "chrono", "clap", "codex-analytics", - "codex-api", "codex-app-server-protocol", + "codex-app-server-transport", "codex-arg0", "codex-backend-client", "codex-chatgpt", @@ -1891,23 +1891,17 @@ dependencies = [ "codex-state", "codex-thread-store", "codex-tools", - "codex-uds", "codex-utils-absolute-path", "codex-utils-cargo-bin", "codex-utils-cli", "codex-utils-json-to-toml", "codex-utils-pty", - "codex-utils-rustls-provider", - "constant_time_eq 0.3.1", "core_test_support", "flate2", "futures", - "gethostname", "hmac", - "jsonwebtoken", "opentelemetry", "opentelemetry_sdk", - "owo-colors", "pretty_assertions", "reqwest", "rmcp", @@ -2005,6 +1999,45 @@ dependencies = [ "uuid", ] +[[package]] +name = "codex-app-server-transport" +version = "0.0.0" +dependencies = [ + "anyhow", + "axum", + "base64 0.22.1", + "chrono", + "clap", + "codex-api", + "codex-app-server-protocol", + "codex-config", + "codex-core", + "codex-login", + "codex-model-provider", + "codex-state", + "codex-uds", + "codex-utils-absolute-path", + "codex-utils-rustls-provider", + "constant_time_eq 0.3.1", + "futures", + "gethostname", + "hmac", + "jsonwebtoken", + "owo-colors", + "pretty_assertions", + "serde", + "serde_json", + "sha2", + "tempfile", + "time", + "tokio", + "tokio-tungstenite", + "tokio-util", + "tracing", + "url", + "uuid", +] + [[package]] name = "codex-apply-patch" version = "0.0.0" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 79d932c8b..2efba8b63 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -8,6 +8,7 @@ members = [ "ansi-escape", "async-utils", "app-server", + "app-server-transport", "app-server-client", "app-server-protocol", "app-server-test-client", @@ -127,6 +128,7 @@ codex-ansi-escape = { path = "ansi-escape" } codex-api = { path = "codex-api" } codex-aws-auth = { path = "aws-auth" } codex-app-server = { path = "app-server" } +codex-app-server-transport = { path = "app-server-transport" } codex-app-server-client = { path = "app-server-client" } codex-app-server-protocol = { path = "app-server-protocol" } codex-app-server-test-client = { path = "app-server-test-client" } diff --git a/codex-rs/app-server-transport/BUILD.bazel b/codex-rs/app-server-transport/BUILD.bazel new file mode 100644 index 000000000..f6ecba680 --- /dev/null +++ b/codex-rs/app-server-transport/BUILD.bazel @@ -0,0 +1,6 @@ +load("//:defs.bzl", "codex_rust_crate") + +codex_rust_crate( + name = "app-server-transport", + crate_name = "codex_app_server_transport", +) diff --git a/codex-rs/app-server-transport/Cargo.toml b/codex-rs/app-server-transport/Cargo.toml new file mode 100644 index 000000000..d1f89c5b5 --- /dev/null +++ b/codex-rs/app-server-transport/Cargo.toml @@ -0,0 +1,58 @@ +[package] +name = "codex-app-server-transport" +version.workspace = true +edition.workspace = true +license.workspace = true + +[lib] +name = "codex_app_server_transport" +path = "src/lib.rs" + +[lints] +workspace = true + +[dependencies] +anyhow = { workspace = true } +axum = { workspace = true, default-features = false, features = [ + "http1", + "json", + "tokio", + "ws", +] } +base64 = { workspace = true } +clap = { workspace = true, features = ["derive"] } +codex-api = { workspace = true } +codex-app-server-protocol = { workspace = true } +codex-core = { workspace = true } +codex-login = { workspace = true } +codex-model-provider = { workspace = true } +codex-state = { workspace = true } +codex-uds = { workspace = true } +codex-utils-absolute-path = { workspace = true } +codex-utils-rustls-provider = { workspace = true } +constant_time_eq = { workspace = true } +futures = { workspace = true } +gethostname = { workspace = true } +hmac = { workspace = true } +jsonwebtoken = { workspace = true } +owo-colors = { workspace = true, features = ["supports-colors"] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +sha2 = { workspace = true } +time = { workspace = true } +tokio = { workspace = true, features = [ + "io-std", + "macros", + "rt-multi-thread", +] } +tokio-tungstenite = { workspace = true } +tokio-util = { workspace = true } +tracing = { workspace = true, features = ["log"] } +url = { workspace = true } +uuid = { workspace = true, features = ["serde", "v7"] } + +[dev-dependencies] +chrono = { workspace = true } +codex-config = { workspace = true } +pretty_assertions = { workspace = true } +tempfile = { workspace = true } diff --git a/codex-rs/app-server-transport/src/lib.rs b/codex-rs/app-server-transport/src/lib.rs new file mode 100644 index 000000000..0a5c080ac --- /dev/null +++ b/codex-rs/app-server-transport/src/lib.rs @@ -0,0 +1,20 @@ +mod outgoing_message; +mod transport; + +pub use outgoing_message::ConnectionId; +pub use outgoing_message::OutgoingError; +pub use outgoing_message::OutgoingMessage; +pub use outgoing_message::OutgoingResponse; +pub use outgoing_message::QueuedOutgoingMessage; +pub use transport::AppServerTransport; +pub use transport::AppServerTransportParseError; +pub use transport::CHANNEL_CAPACITY; +pub use transport::ConnectionOrigin; +pub use transport::RemoteControlHandle; +pub use transport::TransportEvent; +pub use transport::app_server_control_socket_path; +pub use transport::auth; +pub use transport::start_control_socket_acceptor; +pub use transport::start_remote_control; +pub use transport::start_stdio_connection; +pub use transport::start_websocket_acceptor; diff --git a/codex-rs/app-server-transport/src/outgoing_message.rs b/codex-rs/app-server-transport/src/outgoing_message.rs new file mode 100644 index 000000000..ff56b9fef --- /dev/null +++ b/codex-rs/app-server-transport/src/outgoing_message.rs @@ -0,0 +1,58 @@ +use std::fmt; + +use codex_app_server_protocol::JSONRPCErrorError; +use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::Result; +use codex_app_server_protocol::ServerNotification; +use codex_app_server_protocol::ServerRequest; +use serde::Serialize; +use tokio::sync::oneshot; + +/// Stable identifier for a transport connection. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub struct ConnectionId(pub u64); + +impl fmt::Display for ConnectionId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Outgoing message from the server to the client. +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] +pub enum OutgoingMessage { + Request(ServerRequest), + /// AppServerNotification is specific to the case where this is run as an + /// "app server" as opposed to an MCP server. + AppServerNotification(ServerNotification), + Response(OutgoingResponse), + Error(OutgoingError), +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub struct OutgoingResponse { + pub id: RequestId, + pub result: Result, +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub struct OutgoingError { + pub error: JSONRPCErrorError, + pub id: RequestId, +} + +#[derive(Debug)] +pub struct QueuedOutgoingMessage { + pub message: OutgoingMessage, + pub write_complete_tx: Option>, +} + +impl QueuedOutgoingMessage { + pub fn new(message: OutgoingMessage) -> Self { + Self { + message, + write_complete_tx: None, + } + } +} diff --git a/codex-rs/app-server/src/transport/auth.rs b/codex-rs/app-server-transport/src/transport/auth.rs similarity index 99% rename from codex-rs/app-server/src/transport/auth.rs rename to codex-rs/app-server-transport/src/transport/auth.rs index 45f44a36c..9ec025f66 100644 --- a/codex-rs/app-server/src/transport/auth.rs +++ b/codex-rs/app-server-transport/src/transport/auth.rs @@ -86,7 +86,7 @@ pub enum AppServerWebsocketCapabilityTokenSource { } #[derive(Clone, Debug, Default)] -pub(crate) struct WebsocketAuthPolicy { +pub struct WebsocketAuthPolicy { pub(crate) mode: Option, } @@ -219,7 +219,7 @@ impl AppServerWebsocketAuthArgs { } } -pub(crate) fn policy_from_settings( +pub fn policy_from_settings( settings: &AppServerWebsocketAuthSettings, ) -> io::Result { let mode = match settings.config.as_ref() { diff --git a/codex-rs/app-server-transport/src/transport/mod.rs b/codex-rs/app-server-transport/src/transport/mod.rs new file mode 100644 index 000000000..e1590ab43 --- /dev/null +++ b/codex-rs/app-server-transport/src/transport/mod.rs @@ -0,0 +1,478 @@ +pub mod auth; + +use crate::outgoing_message::ConnectionId; +use crate::outgoing_message::OutgoingError; +use crate::outgoing_message::OutgoingMessage; +use crate::outgoing_message::QueuedOutgoingMessage; +use codex_app_server_protocol::JSONRPCErrorError; +use codex_app_server_protocol::JSONRPCMessage; +use codex_core::config::find_codex_home; +use codex_utils_absolute_path::AbsolutePathBuf; +use std::net::SocketAddr; +use std::path::Path; +use std::str::FromStr; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; +use tracing::error; +use tracing::warn; + +/// Size of the bounded channels used to communicate between tasks. The value +/// is a balance between throughput and memory usage - 128 messages should be +/// plenty for an interactive CLI. +pub const CHANNEL_CAPACITY: usize = 128; + +mod remote_control; +mod stdio; +mod unix_socket; +#[cfg(test)] +mod unix_socket_tests; +mod websocket; + +pub use remote_control::RemoteControlHandle; +pub use remote_control::start_remote_control; +pub use stdio::start_stdio_connection; +pub use unix_socket::start_control_socket_acceptor; +pub use websocket::start_websocket_acceptor; + +const OVERLOADED_ERROR_CODE: i64 = -32001; + +const APP_SERVER_CONTROL_SOCKET_DIR_NAME: &str = "app-server-control"; +const APP_SERVER_CONTROL_SOCKET_FILE_NAME: &str = "app-server-control.sock"; + +pub fn app_server_control_socket_path(codex_home: &Path) -> std::io::Result { + AbsolutePathBuf::from_absolute_path( + codex_home + .join(APP_SERVER_CONTROL_SOCKET_DIR_NAME) + .join(APP_SERVER_CONTROL_SOCKET_FILE_NAME), + ) +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum AppServerTransport { + Stdio, + UnixSocket { socket_path: AbsolutePathBuf }, + WebSocket { bind_address: SocketAddr }, + Off, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum AppServerTransportParseError { + UnsupportedListenUrl(String), + InvalidUnixSocketPath { listen_url: String, message: String }, + InvalidWebSocketListenUrl(String), +} + +impl std::fmt::Display for AppServerTransportParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!( + f, + "unsupported --listen URL `{listen_url}`; expected `stdio://`, `unix://`, `unix://PATH`, `ws://IP:PORT`, or `off`" + ), + AppServerTransportParseError::InvalidUnixSocketPath { + listen_url, + message, + } => write!( + f, + "invalid unix socket --listen URL `{listen_url}`; failed to resolve socket path: {message}" + ), + AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!( + f, + "invalid websocket --listen URL `{listen_url}`; expected `ws://IP:PORT`" + ), + } + } +} + +impl std::error::Error for AppServerTransportParseError {} + +impl AppServerTransport { + pub const DEFAULT_LISTEN_URL: &'static str = "stdio://"; + + pub fn from_listen_url(listen_url: &str) -> Result { + if listen_url == Self::DEFAULT_LISTEN_URL { + return Ok(Self::Stdio); + } + + if let Some(raw_socket_path) = listen_url.strip_prefix("unix://") { + let socket_path = if raw_socket_path.is_empty() { + let codex_home = find_codex_home().map_err(|err| { + AppServerTransportParseError::InvalidUnixSocketPath { + listen_url: listen_url.to_string(), + message: format!("failed to resolve CODEX_HOME: {err}"), + } + })?; + app_server_control_socket_path(&codex_home).map_err(|err| { + AppServerTransportParseError::InvalidUnixSocketPath { + listen_url: listen_url.to_string(), + message: err.to_string(), + } + })? + } else { + AbsolutePathBuf::relative_to_current_dir(raw_socket_path).map_err(|err| { + AppServerTransportParseError::InvalidUnixSocketPath { + listen_url: listen_url.to_string(), + message: err.to_string(), + } + })? + }; + return Ok(Self::UnixSocket { socket_path }); + } + + if listen_url == "off" { + return Ok(Self::Off); + } + + if let Some(socket_addr) = listen_url.strip_prefix("ws://") { + let bind_address = socket_addr.parse::().map_err(|_| { + AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string()) + })?; + return Ok(Self::WebSocket { bind_address }); + } + + Err(AppServerTransportParseError::UnsupportedListenUrl( + listen_url.to_string(), + )) + } +} + +impl FromStr for AppServerTransport { + type Err = AppServerTransportParseError; + + fn from_str(s: &str) -> Result { + Self::from_listen_url(s) + } +} + +#[derive(Debug)] +pub enum TransportEvent { + ConnectionOpened { + connection_id: ConnectionId, + origin: ConnectionOrigin, + writer: mpsc::Sender, + disconnect_sender: Option, + }, + ConnectionClosed { + connection_id: ConnectionId, + }, + IncomingMessage { + connection_id: ConnectionId, + message: JSONRPCMessage, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectionOrigin { + Stdio, + InProcess, + WebSocket, + RemoteControl, +} + +impl ConnectionOrigin { + pub fn allows_device_key_requests(self) -> bool { + // Device-key endpoints are only for local connections that own the app-server instance. + // Do not include remote transports such as SSH or remote-control websocket connections. + matches!(self, Self::Stdio | Self::InProcess) + } +} + +static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(0); + +fn next_connection_id() -> ConnectionId { + ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed)) +} + +async fn forward_incoming_message( + transport_event_tx: &mpsc::Sender, + writer: &mpsc::Sender, + connection_id: ConnectionId, + payload: &str, +) -> bool { + match serde_json::from_str::(payload) { + Ok(message) => { + enqueue_incoming_message(transport_event_tx, writer, connection_id, message).await + } + Err(err) => { + error!("Failed to deserialize JSONRPCMessage: {err}"); + true + } + } +} + +async fn enqueue_incoming_message( + transport_event_tx: &mpsc::Sender, + writer: &mpsc::Sender, + connection_id: ConnectionId, + message: JSONRPCMessage, +) -> bool { + let event = TransportEvent::IncomingMessage { + connection_id, + message, + }; + match transport_event_tx.try_send(event) { + Ok(()) => true, + Err(mpsc::error::TrySendError::Closed(_)) => false, + Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage { + connection_id, + message: JSONRPCMessage::Request(request), + })) => { + let overload_error = OutgoingMessage::Error(OutgoingError { + id: request.id, + error: JSONRPCErrorError { + code: OVERLOADED_ERROR_CODE, + message: "Server overloaded; retry later.".to_string(), + data: None, + }, + }); + match writer.try_send(QueuedOutgoingMessage::new(overload_error)) { + Ok(()) => true, + Err(mpsc::error::TrySendError::Closed(_)) => false, + Err(mpsc::error::TrySendError::Full(_overload_error)) => { + warn!( + "dropping overload response for connection {:?}: outbound queue is full", + connection_id + ); + true + } + } + } + Err(mpsc::error::TrySendError::Full(event)) => transport_event_tx.send(event).await.is_ok(), + } +} + +fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option { + let value = match serde_json::to_value(outgoing_message) { + Ok(value) => value, + Err(err) => { + error!("Failed to convert OutgoingMessage to JSON value: {err}"); + return None; + } + }; + match serde_json::to_string(&value) { + Ok(json) => Some(json), + Err(err) => { + error!("Failed to serialize JSONRPCMessage: {err}"); + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use codex_app_server_protocol::ConfigWarningNotification; + use codex_app_server_protocol::JSONRPCNotification; + use codex_app_server_protocol::JSONRPCRequest; + use codex_app_server_protocol::JSONRPCResponse; + use codex_app_server_protocol::RequestId; + use codex_app_server_protocol::ServerNotification; + use pretty_assertions::assert_eq; + use serde_json::json; + use tokio::time::Duration; + use tokio::time::timeout; + + #[test] + fn listen_off_parses_as_off_transport() { + assert_eq!( + AppServerTransport::from_listen_url("off"), + Ok(AppServerTransport::Off) + ); + } + + #[tokio::test] + async fn enqueue_incoming_request_returns_overload_error_when_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + let first_message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message: first_message.clone(), + }) + .await + .expect("queue should accept first message"); + + let request = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(7), + method: "config/read".to_string(), + params: Some(json!({ "includeLayers": false })), + trace: None, + }); + assert!( + enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request).await + ); + + let queued_event = transport_event_rx + .recv() + .await + .expect("first event should stay queued"); + match queued_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message, + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(message, first_message); + } + _ => panic!("expected queued incoming message"), + } + + let overload = writer_rx + .recv() + .await + .expect("request should receive overload error"); + let overload_json = + serde_json::to_value(overload.message).expect("serialize overload error"); + assert_eq!( + overload_json, + json!({ + "id": 7, + "error": { + "code": OVERLOADED_ERROR_CODE, + "message": "Server overloaded; retry later." + } + }) + ); + } + + #[tokio::test] + async fn enqueue_incoming_response_waits_instead_of_dropping_when_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); + let (writer_tx, _writer_rx) = mpsc::channel(1); + + let first_message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message: first_message.clone(), + }) + .await + .expect("queue should accept first message"); + + let response = JSONRPCMessage::Response(JSONRPCResponse { + id: RequestId::Integer(7), + result: json!({"ok": true}), + }); + let transport_event_tx_for_enqueue = transport_event_tx.clone(); + let writer_tx_for_enqueue = writer_tx.clone(); + let enqueue_handle = tokio::spawn(async move { + enqueue_incoming_message( + &transport_event_tx_for_enqueue, + &writer_tx_for_enqueue, + connection_id, + response, + ) + .await + }); + + let queued_event = transport_event_rx + .recv() + .await + .expect("first event should be dequeued"); + match queued_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message, + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(message, first_message); + } + _ => panic!("expected queued incoming message"), + } + + let enqueue_result = enqueue_handle.await.expect("enqueue task should not panic"); + assert!(enqueue_result); + + let forwarded_event = transport_event_rx + .recv() + .await + .expect("response should be forwarded instead of dropped"); + match forwarded_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message: JSONRPCMessage::Response(JSONRPCResponse { id, result }), + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(id, RequestId::Integer(7)); + assert_eq!(result, json!({"ok": true})); + } + _ => panic!("expected forwarded response message"), + } + } + + #[tokio::test] + async fn enqueue_incoming_request_does_not_block_when_writer_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, _transport_event_rx) = mpsc::channel(1); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message: JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }), + }) + .await + .expect("transport queue should accept first message"); + + writer_tx + .send(QueuedOutgoingMessage::new( + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "queued".to_string(), + details: None, + path: None, + range: None, + }, + )), + )) + .await + .expect("writer queue should accept first message"); + + let request = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(7), + method: "config/read".to_string(), + params: Some(json!({ "includeLayers": false })), + trace: None, + }); + + let enqueue_result = timeout( + Duration::from_millis(100), + enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request), + ) + .await + .expect("enqueue should not block while writer queue is full"); + assert!(enqueue_result); + + let queued_outgoing = writer_rx + .recv() + .await + .expect("writer queue should still contain original message"); + let queued_json = + serde_json::to_value(queued_outgoing.message).expect("serialize queued message"); + assert_eq!( + queued_json, + json!({ + "method": "configWarning", + "params": { + "summary": "queued", + "details": null, + }, + }) + ); + } +} diff --git a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs b/codex-rs/app-server-transport/src/transport/remote_control/client_tracker.rs similarity index 100% rename from codex-rs/app-server/src/transport/remote_control/client_tracker.rs rename to codex-rs/app-server-transport/src/transport/remote_control/client_tracker.rs diff --git a/codex-rs/app-server/src/transport/remote_control/enroll.rs b/codex-rs/app-server-transport/src/transport/remote_control/enroll.rs similarity index 100% rename from codex-rs/app-server/src/transport/remote_control/enroll.rs rename to codex-rs/app-server-transport/src/transport/remote_control/enroll.rs diff --git a/codex-rs/app-server/src/transport/remote_control/mod.rs b/codex-rs/app-server-transport/src/transport/remote_control/mod.rs similarity index 93% rename from codex-rs/app-server/src/transport/remote_control/mod.rs rename to codex-rs/app-server-transport/src/transport/remote_control/mod.rs index 2d0eb7dfb..87405efa4 100644 --- a/codex-rs/app-server/src/transport/remote_control/mod.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/mod.rs @@ -36,14 +36,14 @@ pub(super) struct QueuedServerEnvelope { } #[derive(Clone)] -pub(crate) struct RemoteControlHandle { +pub struct RemoteControlHandle { enabled_tx: Arc>, status_tx: Arc>, state_db_available: bool, } impl RemoteControlHandle { - pub(crate) fn set_enabled(&self, enabled: bool) { + pub fn set_enabled(&self, enabled: bool) { let requested_enabled = enabled; let enabled = enabled && self.state_db_available; if requested_enabled && !self.state_db_available { @@ -56,14 +56,12 @@ impl RemoteControlHandle { }); } - pub(crate) fn status_receiver( - &self, - ) -> watch::Receiver { + pub fn status_receiver(&self) -> watch::Receiver { self.status_tx.subscribe() } } -pub(crate) async fn start_remote_control( +pub async fn start_remote_control( remote_control_url: String, state_db: Option>, auth_manager: Arc, diff --git a/codex-rs/app-server/src/transport/remote_control/protocol.rs b/codex-rs/app-server-transport/src/transport/remote_control/protocol.rs similarity index 100% rename from codex-rs/app-server/src/transport/remote_control/protocol.rs rename to codex-rs/app-server-transport/src/transport/remote_control/protocol.rs diff --git a/codex-rs/app-server/src/transport/remote_control/segment.rs b/codex-rs/app-server-transport/src/transport/remote_control/segment.rs similarity index 100% rename from codex-rs/app-server/src/transport/remote_control/segment.rs rename to codex-rs/app-server-transport/src/transport/remote_control/segment.rs diff --git a/codex-rs/app-server/src/transport/remote_control/segment_tests.rs b/codex-rs/app-server-transport/src/transport/remote_control/segment_tests.rs similarity index 100% rename from codex-rs/app-server/src/transport/remote_control/segment_tests.rs rename to codex-rs/app-server-transport/src/transport/remote_control/segment_tests.rs diff --git a/codex-rs/app-server/src/transport/remote_control/tests.rs b/codex-rs/app-server-transport/src/transport/remote_control/tests.rs similarity index 100% rename from codex-rs/app-server/src/transport/remote_control/tests.rs rename to codex-rs/app-server-transport/src/transport/remote_control/tests.rs diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server-transport/src/transport/remote_control/websocket.rs similarity index 100% rename from codex-rs/app-server/src/transport/remote_control/websocket.rs rename to codex-rs/app-server-transport/src/transport/remote_control/websocket.rs diff --git a/codex-rs/app-server/src/transport/stdio.rs b/codex-rs/app-server-transport/src/transport/stdio.rs similarity index 98% rename from codex-rs/app-server/src/transport/stdio.rs rename to codex-rs/app-server-transport/src/transport/stdio.rs index 14466c86c..2d30296cd 100644 --- a/codex-rs/app-server/src/transport/stdio.rs +++ b/codex-rs/app-server-transport/src/transport/stdio.rs @@ -21,7 +21,7 @@ use tracing::debug; use tracing::error; use tracing::info; -pub(crate) async fn start_stdio_connection( +pub async fn start_stdio_connection( transport_event_tx: mpsc::Sender, stdio_handles: &mut Vec>, initialize_client_name_tx: oneshot::Sender, diff --git a/codex-rs/app-server/src/transport/unix_socket.rs b/codex-rs/app-server-transport/src/transport/unix_socket.rs similarity index 99% rename from codex-rs/app-server/src/transport/unix_socket.rs rename to codex-rs/app-server-transport/src/transport/unix_socket.rs index 5ab1377fb..f75d3fe99 100644 --- a/codex-rs/app-server/src/transport/unix_socket.rs +++ b/codex-rs/app-server-transport/src/transport/unix_socket.rs @@ -20,7 +20,7 @@ use tracing::warn; #[cfg(unix)] const CONTROL_SOCKET_MODE: u32 = 0o600; -pub(crate) async fn start_control_socket_acceptor( +pub async fn start_control_socket_acceptor( socket_path: AbsolutePathBuf, transport_event_tx: mpsc::Sender, shutdown_token: CancellationToken, diff --git a/codex-rs/app-server/src/transport/unix_socket_tests.rs b/codex-rs/app-server-transport/src/transport/unix_socket_tests.rs similarity index 100% rename from codex-rs/app-server/src/transport/unix_socket_tests.rs rename to codex-rs/app-server-transport/src/transport/unix_socket_tests.rs diff --git a/codex-rs/app-server/src/transport/websocket.rs b/codex-rs/app-server-transport/src/transport/websocket.rs similarity index 99% rename from codex-rs/app-server/src/transport/websocket.rs rename to codex-rs/app-server-transport/src/transport/websocket.rs index 783018946..627197c29 100644 --- a/codex-rs/app-server/src/transport/websocket.rs +++ b/codex-rs/app-server-transport/src/transport/websocket.rs @@ -128,7 +128,7 @@ async fn websocket_upgrade_handler( .into_response() } -pub(crate) async fn start_websocket_acceptor( +pub async fn start_websocket_acceptor( bind_address: SocketAddr, transport_event_tx: mpsc::Sender, shutdown_token: CancellationToken, diff --git a/codex-rs/app-server/Cargo.toml b/codex-rs/app-server/Cargo.toml index 5d73f97c2..6d201bdee 100644 --- a/codex-rs/app-server/Cargo.toml +++ b/codex-rs/app-server/Cargo.toml @@ -30,7 +30,6 @@ axum = { workspace = true, default-features = false, features = [ "ws", ] } codex-analytics = { workspace = true } -codex-api = { workspace = true } codex-arg0 = { workspace = true } codex-cloud-requirements = { workspace = true } codex-config = { workspace = true } @@ -58,6 +57,7 @@ codex-model-provider = { workspace = true } codex-models-manager = { workspace = true } codex-protocol = { workspace = true } codex-app-server-protocol = { workspace = true } +codex-app-server-transport = { workspace = true } codex-feedback = { workspace = true } codex-rmcp-client = { workspace = true } codex-rollout = { workspace = true } @@ -65,18 +65,11 @@ codex-sandboxing = { workspace = true } codex-state = { workspace = true } codex-thread-store = { workspace = true } codex-tools = { workspace = true } -codex-uds = { workspace = true } codex-utils-absolute-path = { workspace = true } codex-utils-json-to-toml = { workspace = true } -codex-utils-rustls-provider = { workspace = true } chrono = { workspace = true } clap = { workspace = true, features = ["derive"] } -constant_time_eq = { workspace = true } futures = { workspace = true } -gethostname = { workspace = true } -hmac = { workspace = true } -jsonwebtoken = { workspace = true } -owo-colors = { workspace = true, features = ["supports-colors"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } sha2 = { workspace = true } @@ -93,7 +86,6 @@ tokio = { workspace = true, features = [ "signal", ] } tokio-util = { workspace = true } -tokio-tungstenite = { workspace = true } tracing = { workspace = true, features = ["log"] } tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] } url = { workspace = true } @@ -111,6 +103,7 @@ core_test_support = { workspace = true } codex-model-provider-info = { workspace = true } codex-utils-cargo-bin = { workspace = true } flate2 = { workspace = true } +hmac = { workspace = true } opentelemetry = { workspace = true } opentelemetry_sdk = { workspace = true } pretty_assertions = { workspace = true } diff --git a/codex-rs/app-server/src/outgoing_message.rs b/codex-rs/app-server/src/outgoing_message.rs index 34441f83a..f7a90538c 100644 --- a/codex-rs/app-server/src/outgoing_message.rs +++ b/codex-rs/app-server/src/outgoing_message.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::fmt; use std::sync::Arc; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -15,7 +14,6 @@ use codex_app_server_protocol::ServerRequestPayload; use codex_otel::span_w3c_trace_context; use codex_protocol::ThreadId; use codex_protocol::protocol::W3cTraceContext; -use serde::Serialize; use tokio::sync::Mutex; use tokio::sync::mpsc; use tokio::sync::oneshot; @@ -26,22 +24,17 @@ use tracing::warn; use crate::error_code::INTERNAL_ERROR_CODE; use crate::error_code::internal_error; use crate::server_request_error::TURN_TRANSITION_PENDING_REQUEST_ERROR_REASON; +pub(crate) use codex_app_server_transport::ConnectionId; +pub(crate) use codex_app_server_transport::OutgoingError; +pub(crate) use codex_app_server_transport::OutgoingMessage; +pub(crate) use codex_app_server_transport::OutgoingResponse; +pub(crate) use codex_app_server_transport::QueuedOutgoingMessage; #[cfg(test)] use codex_protocol::account::PlanType; pub(crate) type ClientRequestResult = std::result::Result; -/// Stable identifier for a transport connection. -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] -pub(crate) struct ConnectionId(pub(crate) u64); - -impl fmt::Display for ConnectionId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - /// Stable identifier for a client request scoped to a transport connection. #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub(crate) struct ConnectionRequestId { @@ -96,21 +89,6 @@ pub(crate) enum OutgoingEnvelope { }, } -#[derive(Debug)] -pub(crate) struct QueuedOutgoingMessage { - pub(crate) message: OutgoingMessage, - pub(crate) write_complete_tx: Option>, -} - -impl QueuedOutgoingMessage { - pub(crate) fn new(message: OutgoingMessage) -> Self { - Self { - message, - write_complete_tx: None, - } - } -} - /// Sends messages to the client and manages request callbacks. pub(crate) struct OutgoingMessageSender { next_server_request_id: AtomicI64, @@ -665,30 +643,6 @@ impl OutgoingMessageSender { } } -/// Outgoing message from the server to the client. -#[derive(Debug, Clone, Serialize)] -#[serde(untagged)] -pub(crate) enum OutgoingMessage { - Request(ServerRequest), - /// AppServerNotification is specific to the case where this is run as an - /// "app server" as opposed to an MCP server. - AppServerNotification(ServerNotification), - Response(OutgoingResponse), - Error(OutgoingError), -} - -#[derive(Debug, Clone, PartialEq, Serialize)] -pub(crate) struct OutgoingResponse { - pub id: RequestId, - pub result: Result, -} - -#[derive(Debug, Clone, PartialEq, Serialize)] -pub(crate) struct OutgoingError { - pub error: JSONRPCErrorError, - pub id: RequestId, -} - #[cfg(test)] mod tests { use std::time::Duration; diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs new file mode 100644 index 000000000..9c16f8a39 --- /dev/null +++ b/codex-rs/app-server/src/transport.rs @@ -0,0 +1,232 @@ +use crate::message_processor::ConnectionSessionState; +use crate::outgoing_message::OutgoingEnvelope; +use codex_app_server_protocol::ExperimentalApi; +use codex_app_server_protocol::ServerRequest; +use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Arc; +use std::sync::RwLock; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; +use tracing::warn; + +pub use codex_app_server_transport::AppServerTransport; +pub(crate) use codex_app_server_transport::CHANNEL_CAPACITY; +pub(crate) use codex_app_server_transport::ConnectionId; +pub(crate) use codex_app_server_transport::ConnectionOrigin; +pub(crate) use codex_app_server_transport::OutgoingMessage; +pub(crate) use codex_app_server_transport::QueuedOutgoingMessage; +pub(crate) use codex_app_server_transport::RemoteControlHandle; +pub(crate) use codex_app_server_transport::TransportEvent; +pub use codex_app_server_transport::app_server_control_socket_path; +pub use codex_app_server_transport::auth; +pub(crate) use codex_app_server_transport::start_control_socket_acceptor; +pub(crate) use codex_app_server_transport::start_remote_control; +pub(crate) use codex_app_server_transport::start_stdio_connection; +pub(crate) use codex_app_server_transport::start_websocket_acceptor; + +pub(crate) struct ConnectionState { + pub(crate) outbound_initialized: Arc, + pub(crate) outbound_experimental_api_enabled: Arc, + pub(crate) outbound_opted_out_notification_methods: Arc>>, + pub(crate) session: Arc, +} + +impl ConnectionState { + pub(crate) fn new( + origin: ConnectionOrigin, + outbound_initialized: Arc, + outbound_experimental_api_enabled: Arc, + outbound_opted_out_notification_methods: Arc>>, + ) -> Self { + Self { + outbound_initialized, + outbound_experimental_api_enabled, + outbound_opted_out_notification_methods, + session: Arc::new(ConnectionSessionState::new(origin)), + } + } +} + +pub(crate) struct OutboundConnectionState { + pub(crate) initialized: Arc, + pub(crate) experimental_api_enabled: Arc, + pub(crate) opted_out_notification_methods: Arc>>, + pub(crate) writer: mpsc::Sender, + disconnect_sender: Option, +} + +impl OutboundConnectionState { + pub(crate) fn new( + writer: mpsc::Sender, + initialized: Arc, + experimental_api_enabled: Arc, + opted_out_notification_methods: Arc>>, + disconnect_sender: Option, + ) -> Self { + Self { + initialized, + experimental_api_enabled, + opted_out_notification_methods, + writer, + disconnect_sender, + } + } + + fn can_disconnect(&self) -> bool { + self.disconnect_sender.is_some() + } + + pub(crate) fn request_disconnect(&self) { + if let Some(disconnect_sender) = &self.disconnect_sender { + disconnect_sender.cancel(); + } + } +} + +fn should_skip_notification_for_connection( + connection_state: &OutboundConnectionState, + message: &OutgoingMessage, +) -> bool { + let Ok(opted_out_notification_methods) = connection_state.opted_out_notification_methods.read() + else { + warn!("failed to read outbound opted-out notifications"); + return false; + }; + match message { + OutgoingMessage::AppServerNotification(notification) => { + if notification.experimental_reason().is_some() + && !connection_state + .experimental_api_enabled + .load(Ordering::Acquire) + { + return true; + } + let method = notification.to_string(); + opted_out_notification_methods.contains(method.as_str()) + } + _ => false, + } +} + +fn disconnect_connection( + connections: &mut HashMap, + connection_id: ConnectionId, +) -> bool { + if let Some(connection_state) = connections.remove(&connection_id) { + connection_state.request_disconnect(); + return true; + } + false +} + +async fn send_message_to_connection( + connections: &mut HashMap, + connection_id: ConnectionId, + message: OutgoingMessage, + write_complete_tx: Option>, +) -> bool { + let Some(connection_state) = connections.get(&connection_id) else { + warn!("dropping message for disconnected connection: {connection_id:?}"); + return false; + }; + let message = filter_outgoing_message_for_connection(connection_state, message); + if should_skip_notification_for_connection(connection_state, &message) { + return false; + } + + let writer = connection_state.writer.clone(); + let queued_message = QueuedOutgoingMessage { + message, + write_complete_tx, + }; + if connection_state.can_disconnect() { + match writer.try_send(queued_message) { + Ok(()) => false, + Err(mpsc::error::TrySendError::Full(_)) => { + warn!( + "disconnecting slow connection after outbound queue filled: {connection_id:?}" + ); + disconnect_connection(connections, connection_id) + } + Err(mpsc::error::TrySendError::Closed(_)) => { + disconnect_connection(connections, connection_id) + } + } + } else if writer.send(queued_message).await.is_err() { + disconnect_connection(connections, connection_id) + } else { + false + } +} + +fn filter_outgoing_message_for_connection( + connection_state: &OutboundConnectionState, + message: OutgoingMessage, +) -> OutgoingMessage { + let experimental_api_enabled = connection_state + .experimental_api_enabled + .load(Ordering::Acquire); + match message { + OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { + request_id, + mut params, + }) => { + if !experimental_api_enabled { + params.strip_experimental_fields(); + } + OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { + request_id, + params, + }) + } + _ => message, + } +} + +pub(crate) async fn route_outgoing_envelope( + connections: &mut HashMap, + envelope: OutgoingEnvelope, +) { + match envelope { + OutgoingEnvelope::ToConnection { + connection_id, + message, + write_complete_tx, + } => { + let _ = + send_message_to_connection(connections, connection_id, message, write_complete_tx) + .await; + } + OutgoingEnvelope::Broadcast { message } => { + let target_connections: Vec = connections + .iter() + .filter_map(|(connection_id, connection_state)| { + if connection_state.initialized.load(Ordering::Acquire) + && !should_skip_notification_for_connection(connection_state, &message) + { + Some(*connection_id) + } else { + None + } + }) + .collect(); + + for connection_id in target_connections { + let _ = send_message_to_connection( + connections, + connection_id, + message.clone(), + /*write_complete_tx*/ None, + ) + .await; + } + } + } +} + +#[cfg(test)] +#[path = "transport_tests.rs"] +mod tests; diff --git a/codex-rs/app-server/src/transport/mod.rs b/codex-rs/app-server/src/transport/mod.rs deleted file mode 100644 index b610f099a..000000000 --- a/codex-rs/app-server/src/transport/mod.rs +++ /dev/null @@ -1,1210 +0,0 @@ -pub(crate) mod auth; - -use crate::error_code::OVERLOADED_ERROR_CODE; -use crate::message_processor::ConnectionSessionState; -use crate::outgoing_message::ConnectionId; -use crate::outgoing_message::OutgoingEnvelope; -use crate::outgoing_message::OutgoingError; -use crate::outgoing_message::OutgoingMessage; -use crate::outgoing_message::QueuedOutgoingMessage; -use codex_app_server_protocol::ExperimentalApi; -use codex_app_server_protocol::JSONRPCErrorError; -use codex_app_server_protocol::JSONRPCMessage; -use codex_app_server_protocol::ServerRequest; -use codex_core::config::find_codex_home; -use codex_utils_absolute_path::AbsolutePathBuf; -use std::collections::HashMap; -use std::collections::HashSet; -use std::net::SocketAddr; -use std::path::Path; -use std::str::FromStr; -use std::sync::Arc; -use std::sync::RwLock; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::AtomicU64; -use std::sync::atomic::Ordering; -use tokio::sync::mpsc; -use tokio_util::sync::CancellationToken; -use tracing::error; -use tracing::warn; - -/// Size of the bounded channels used to communicate between tasks. The value -/// is a balance between throughput and memory usage - 128 messages should be -/// plenty for an interactive CLI. -pub(crate) const CHANNEL_CAPACITY: usize = 128; - -mod remote_control; -mod stdio; -mod unix_socket; -#[cfg(test)] -mod unix_socket_tests; -mod websocket; - -pub(crate) use remote_control::RemoteControlHandle; -pub(crate) use remote_control::start_remote_control; -pub(crate) use stdio::start_stdio_connection; -pub(crate) use unix_socket::start_control_socket_acceptor; -pub(crate) use websocket::start_websocket_acceptor; - -const APP_SERVER_CONTROL_SOCKET_DIR_NAME: &str = "app-server-control"; -const APP_SERVER_CONTROL_SOCKET_FILE_NAME: &str = "app-server-control.sock"; - -pub fn app_server_control_socket_path(codex_home: &Path) -> std::io::Result { - AbsolutePathBuf::from_absolute_path( - codex_home - .join(APP_SERVER_CONTROL_SOCKET_DIR_NAME) - .join(APP_SERVER_CONTROL_SOCKET_FILE_NAME), - ) -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum AppServerTransport { - Stdio, - UnixSocket { socket_path: AbsolutePathBuf }, - WebSocket { bind_address: SocketAddr }, - Off, -} - -#[derive(Debug, Clone, Eq, PartialEq)] -pub enum AppServerTransportParseError { - UnsupportedListenUrl(String), - InvalidUnixSocketPath { listen_url: String, message: String }, - InvalidWebSocketListenUrl(String), -} - -impl std::fmt::Display for AppServerTransportParseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!( - f, - "unsupported --listen URL `{listen_url}`; expected `stdio://`, `unix://`, `unix://PATH`, `ws://IP:PORT`, or `off`" - ), - AppServerTransportParseError::InvalidUnixSocketPath { - listen_url, - message, - } => write!( - f, - "invalid unix socket --listen URL `{listen_url}`; failed to resolve socket path: {message}" - ), - AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!( - f, - "invalid websocket --listen URL `{listen_url}`; expected `ws://IP:PORT`" - ), - } - } -} - -impl std::error::Error for AppServerTransportParseError {} - -impl AppServerTransport { - pub const DEFAULT_LISTEN_URL: &'static str = "stdio://"; - - pub fn from_listen_url(listen_url: &str) -> Result { - if listen_url == Self::DEFAULT_LISTEN_URL { - return Ok(Self::Stdio); - } - - if let Some(raw_socket_path) = listen_url.strip_prefix("unix://") { - let socket_path = if raw_socket_path.is_empty() { - let codex_home = find_codex_home().map_err(|err| { - AppServerTransportParseError::InvalidUnixSocketPath { - listen_url: listen_url.to_string(), - message: format!("failed to resolve CODEX_HOME: {err}"), - } - })?; - app_server_control_socket_path(&codex_home).map_err(|err| { - AppServerTransportParseError::InvalidUnixSocketPath { - listen_url: listen_url.to_string(), - message: err.to_string(), - } - })? - } else { - AbsolutePathBuf::relative_to_current_dir(raw_socket_path).map_err(|err| { - AppServerTransportParseError::InvalidUnixSocketPath { - listen_url: listen_url.to_string(), - message: err.to_string(), - } - })? - }; - return Ok(Self::UnixSocket { socket_path }); - } - - if listen_url == "off" { - return Ok(Self::Off); - } - - if let Some(socket_addr) = listen_url.strip_prefix("ws://") { - let bind_address = socket_addr.parse::().map_err(|_| { - AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string()) - })?; - return Ok(Self::WebSocket { bind_address }); - } - - Err(AppServerTransportParseError::UnsupportedListenUrl( - listen_url.to_string(), - )) - } -} - -impl FromStr for AppServerTransport { - type Err = AppServerTransportParseError; - - fn from_str(s: &str) -> Result { - Self::from_listen_url(s) - } -} - -#[derive(Debug)] -pub(crate) enum TransportEvent { - ConnectionOpened { - connection_id: ConnectionId, - origin: ConnectionOrigin, - writer: mpsc::Sender, - disconnect_sender: Option, - }, - ConnectionClosed { - connection_id: ConnectionId, - }, - IncomingMessage { - connection_id: ConnectionId, - message: JSONRPCMessage, - }, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) enum ConnectionOrigin { - Stdio, - InProcess, - WebSocket, - RemoteControl, -} - -impl ConnectionOrigin { - pub(crate) fn allows_device_key_requests(self) -> bool { - // Device-key endpoints are only for local connections that own the app-server instance. - // Do not include remote transports such as SSH or remote-control websocket connections. - matches!(self, Self::Stdio | Self::InProcess) - } -} - -pub(crate) struct ConnectionState { - pub(crate) outbound_initialized: Arc, - pub(crate) outbound_experimental_api_enabled: Arc, - pub(crate) outbound_opted_out_notification_methods: Arc>>, - pub(crate) session: Arc, -} - -impl ConnectionState { - pub(crate) fn new( - origin: ConnectionOrigin, - outbound_initialized: Arc, - outbound_experimental_api_enabled: Arc, - outbound_opted_out_notification_methods: Arc>>, - ) -> Self { - Self { - outbound_initialized, - outbound_experimental_api_enabled, - outbound_opted_out_notification_methods, - session: Arc::new(ConnectionSessionState::new(origin)), - } - } -} - -pub(crate) struct OutboundConnectionState { - pub(crate) initialized: Arc, - pub(crate) experimental_api_enabled: Arc, - pub(crate) opted_out_notification_methods: Arc>>, - pub(crate) writer: mpsc::Sender, - disconnect_sender: Option, -} - -impl OutboundConnectionState { - pub(crate) fn new( - writer: mpsc::Sender, - initialized: Arc, - experimental_api_enabled: Arc, - opted_out_notification_methods: Arc>>, - disconnect_sender: Option, - ) -> Self { - Self { - initialized, - experimental_api_enabled, - opted_out_notification_methods, - writer, - disconnect_sender, - } - } - - fn can_disconnect(&self) -> bool { - self.disconnect_sender.is_some() - } - - pub(crate) fn request_disconnect(&self) { - if let Some(disconnect_sender) = &self.disconnect_sender { - disconnect_sender.cancel(); - } - } -} - -static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(0); - -fn next_connection_id() -> ConnectionId { - ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed)) -} - -async fn forward_incoming_message( - transport_event_tx: &mpsc::Sender, - writer: &mpsc::Sender, - connection_id: ConnectionId, - payload: &str, -) -> bool { - match serde_json::from_str::(payload) { - Ok(message) => { - enqueue_incoming_message(transport_event_tx, writer, connection_id, message).await - } - Err(err) => { - error!("Failed to deserialize JSONRPCMessage: {err}"); - true - } - } -} - -async fn enqueue_incoming_message( - transport_event_tx: &mpsc::Sender, - writer: &mpsc::Sender, - connection_id: ConnectionId, - message: JSONRPCMessage, -) -> bool { - let event = TransportEvent::IncomingMessage { - connection_id, - message, - }; - match transport_event_tx.try_send(event) { - Ok(()) => true, - Err(mpsc::error::TrySendError::Closed(_)) => false, - Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage { - connection_id, - message: JSONRPCMessage::Request(request), - })) => { - let overload_error = OutgoingMessage::Error(OutgoingError { - id: request.id, - error: JSONRPCErrorError { - code: OVERLOADED_ERROR_CODE, - message: "Server overloaded; retry later.".to_string(), - data: None, - }, - }); - match writer.try_send(QueuedOutgoingMessage::new(overload_error)) { - Ok(()) => true, - Err(mpsc::error::TrySendError::Closed(_)) => false, - Err(mpsc::error::TrySendError::Full(_overload_error)) => { - warn!( - "dropping overload response for connection {:?}: outbound queue is full", - connection_id - ); - true - } - } - } - Err(mpsc::error::TrySendError::Full(event)) => transport_event_tx.send(event).await.is_ok(), - } -} - -fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option { - let value = match serde_json::to_value(outgoing_message) { - Ok(value) => value, - Err(err) => { - error!("Failed to convert OutgoingMessage to JSON value: {err}"); - return None; - } - }; - match serde_json::to_string(&value) { - Ok(json) => Some(json), - Err(err) => { - error!("Failed to serialize JSONRPCMessage: {err}"); - None - } - } -} - -fn should_skip_notification_for_connection( - connection_state: &OutboundConnectionState, - message: &OutgoingMessage, -) -> bool { - let Ok(opted_out_notification_methods) = connection_state.opted_out_notification_methods.read() - else { - warn!("failed to read outbound opted-out notifications"); - return false; - }; - match message { - OutgoingMessage::AppServerNotification(notification) => { - if notification.experimental_reason().is_some() - && !connection_state - .experimental_api_enabled - .load(Ordering::Acquire) - { - return true; - } - let method = notification.to_string(); - opted_out_notification_methods.contains(method.as_str()) - } - _ => false, - } -} - -fn disconnect_connection( - connections: &mut HashMap, - connection_id: ConnectionId, -) -> bool { - if let Some(connection_state) = connections.remove(&connection_id) { - connection_state.request_disconnect(); - return true; - } - false -} - -async fn send_message_to_connection( - connections: &mut HashMap, - connection_id: ConnectionId, - message: OutgoingMessage, - write_complete_tx: Option>, -) -> bool { - let Some(connection_state) = connections.get(&connection_id) else { - warn!("dropping message for disconnected connection: {connection_id:?}"); - return false; - }; - let message = filter_outgoing_message_for_connection(connection_state, message); - if should_skip_notification_for_connection(connection_state, &message) { - return false; - } - - let writer = connection_state.writer.clone(); - let queued_message = QueuedOutgoingMessage { - message, - write_complete_tx, - }; - if connection_state.can_disconnect() { - match writer.try_send(queued_message) { - Ok(()) => false, - Err(mpsc::error::TrySendError::Full(_)) => { - warn!( - "disconnecting slow connection after outbound queue filled: {connection_id:?}" - ); - disconnect_connection(connections, connection_id) - } - Err(mpsc::error::TrySendError::Closed(_)) => { - disconnect_connection(connections, connection_id) - } - } - } else if writer.send(queued_message).await.is_err() { - disconnect_connection(connections, connection_id) - } else { - false - } -} - -fn filter_outgoing_message_for_connection( - connection_state: &OutboundConnectionState, - message: OutgoingMessage, -) -> OutgoingMessage { - let experimental_api_enabled = connection_state - .experimental_api_enabled - .load(Ordering::Acquire); - match message { - OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { - request_id, - mut params, - }) => { - if !experimental_api_enabled { - params.strip_experimental_fields(); - } - OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { - request_id, - params, - }) - } - _ => message, - } -} - -pub(crate) async fn route_outgoing_envelope( - connections: &mut HashMap, - envelope: OutgoingEnvelope, -) { - match envelope { - OutgoingEnvelope::ToConnection { - connection_id, - message, - write_complete_tx, - } => { - let _ = - send_message_to_connection(connections, connection_id, message, write_complete_tx) - .await; - } - OutgoingEnvelope::Broadcast { message } => { - let target_connections: Vec = connections - .iter() - .filter_map(|(connection_id, connection_state)| { - if connection_state.initialized.load(Ordering::Acquire) - && !should_skip_notification_for_connection(connection_state, &message) - { - Some(*connection_id) - } else { - None - } - }) - .collect(); - - for connection_id in target_connections { - let _ = send_message_to_connection( - connections, - connection_id, - message.clone(), - /*write_complete_tx*/ None, - ) - .await; - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use codex_app_server_protocol::ConfigWarningNotification; - use codex_app_server_protocol::JSONRPCNotification; - use codex_app_server_protocol::JSONRPCRequest; - use codex_app_server_protocol::JSONRPCResponse; - use codex_app_server_protocol::RequestId; - use codex_app_server_protocol::ServerNotification; - use codex_app_server_protocol::ThreadGoal; - use codex_app_server_protocol::ThreadGoalStatus; - use codex_app_server_protocol::ThreadGoalUpdatedNotification; - use codex_utils_absolute_path::AbsolutePathBuf; - use pretty_assertions::assert_eq; - use serde_json::json; - use tokio::time::Duration; - use tokio::time::timeout; - - fn absolute_path(path: &str) -> AbsolutePathBuf { - AbsolutePathBuf::from_absolute_path(path).expect("absolute path") - } - - fn thread_goal_updated_notification() -> ServerNotification { - ServerNotification::ThreadGoalUpdated(ThreadGoalUpdatedNotification { - thread_id: "thread-1".to_string(), - turn_id: None, - goal: ThreadGoal { - thread_id: "thread-1".to_string(), - objective: "ship goal mode".to_string(), - status: ThreadGoalStatus::Active, - token_budget: None, - tokens_used: 0, - time_used_seconds: 0, - created_at: 1, - updated_at: 1, - }, - }) - } - - #[test] - fn listen_off_parses_as_off_transport() { - assert_eq!( - AppServerTransport::from_listen_url("off"), - Ok(AppServerTransport::Off) - ); - } - - #[tokio::test] - async fn enqueue_incoming_request_returns_overload_error_when_queue_is_full() { - let connection_id = ConnectionId(42); - let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); - let (writer_tx, mut writer_rx) = mpsc::channel(1); - - let first_message = JSONRPCMessage::Notification(JSONRPCNotification { - method: "initialized".to_string(), - params: None, - }); - transport_event_tx - .send(TransportEvent::IncomingMessage { - connection_id, - message: first_message.clone(), - }) - .await - .expect("queue should accept first message"); - - let request = JSONRPCMessage::Request(JSONRPCRequest { - id: RequestId::Integer(7), - method: "config/read".to_string(), - params: Some(json!({ "includeLayers": false })), - trace: None, - }); - assert!( - enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request).await - ); - - let queued_event = transport_event_rx - .recv() - .await - .expect("first event should stay queued"); - match queued_event { - TransportEvent::IncomingMessage { - connection_id: queued_connection_id, - message, - } => { - assert_eq!(queued_connection_id, connection_id); - assert_eq!(message, first_message); - } - _ => panic!("expected queued incoming message"), - } - - let overload = writer_rx - .recv() - .await - .expect("request should receive overload error"); - let overload_json = - serde_json::to_value(overload.message).expect("serialize overload error"); - assert_eq!( - overload_json, - json!({ - "id": 7, - "error": { - "code": OVERLOADED_ERROR_CODE, - "message": "Server overloaded; retry later." - } - }) - ); - } - - #[tokio::test] - async fn enqueue_incoming_response_waits_instead_of_dropping_when_queue_is_full() { - let connection_id = ConnectionId(42); - let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); - let (writer_tx, _writer_rx) = mpsc::channel(1); - - let first_message = JSONRPCMessage::Notification(JSONRPCNotification { - method: "initialized".to_string(), - params: None, - }); - transport_event_tx - .send(TransportEvent::IncomingMessage { - connection_id, - message: first_message.clone(), - }) - .await - .expect("queue should accept first message"); - - let response = JSONRPCMessage::Response(JSONRPCResponse { - id: RequestId::Integer(7), - result: json!({"ok": true}), - }); - let transport_event_tx_for_enqueue = transport_event_tx.clone(); - let writer_tx_for_enqueue = writer_tx.clone(); - let enqueue_handle = tokio::spawn(async move { - enqueue_incoming_message( - &transport_event_tx_for_enqueue, - &writer_tx_for_enqueue, - connection_id, - response, - ) - .await - }); - - let queued_event = transport_event_rx - .recv() - .await - .expect("first event should be dequeued"); - match queued_event { - TransportEvent::IncomingMessage { - connection_id: queued_connection_id, - message, - } => { - assert_eq!(queued_connection_id, connection_id); - assert_eq!(message, first_message); - } - _ => panic!("expected queued incoming message"), - } - - let enqueue_result = enqueue_handle.await.expect("enqueue task should not panic"); - assert!(enqueue_result); - - let forwarded_event = transport_event_rx - .recv() - .await - .expect("response should be forwarded instead of dropped"); - match forwarded_event { - TransportEvent::IncomingMessage { - connection_id: queued_connection_id, - message: JSONRPCMessage::Response(JSONRPCResponse { id, result }), - } => { - assert_eq!(queued_connection_id, connection_id); - assert_eq!(id, RequestId::Integer(7)); - assert_eq!(result, json!({"ok": true})); - } - _ => panic!("expected forwarded response message"), - } - } - - #[tokio::test] - async fn enqueue_incoming_request_does_not_block_when_writer_queue_is_full() { - let connection_id = ConnectionId(42); - let (transport_event_tx, _transport_event_rx) = mpsc::channel(1); - let (writer_tx, mut writer_rx) = mpsc::channel(1); - - transport_event_tx - .send(TransportEvent::IncomingMessage { - connection_id, - message: JSONRPCMessage::Notification(JSONRPCNotification { - method: "initialized".to_string(), - params: None, - }), - }) - .await - .expect("transport queue should accept first message"); - - writer_tx - .send(QueuedOutgoingMessage::new( - OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( - ConfigWarningNotification { - summary: "queued".to_string(), - details: None, - path: None, - range: None, - }, - )), - )) - .await - .expect("writer queue should accept first message"); - - let request = JSONRPCMessage::Request(JSONRPCRequest { - id: RequestId::Integer(7), - method: "config/read".to_string(), - params: Some(json!({ "includeLayers": false })), - trace: None, - }); - - let enqueue_result = timeout( - Duration::from_millis(100), - enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request), - ) - .await - .expect("enqueue should not block while writer queue is full"); - assert!(enqueue_result); - - let queued_outgoing = writer_rx - .recv() - .await - .expect("writer queue should still contain original message"); - let queued_json = - serde_json::to_value(queued_outgoing.message).expect("serialize queued message"); - assert_eq!( - queued_json, - json!({ - "method": "configWarning", - "params": { - "summary": "queued", - "details": null, - }, - }) - ); - } - - #[tokio::test] - async fn to_connection_notification_respects_opt_out_filters() { - let connection_id = ConnectionId(7); - let (writer_tx, mut writer_rx) = mpsc::channel(1); - let initialized = Arc::new(AtomicBool::new(true)); - let opted_out_notification_methods = - Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()]))); - - let mut connections = HashMap::new(); - connections.insert( - connection_id, - OutboundConnectionState::new( - writer_tx, - initialized, - Arc::new(AtomicBool::new(true)), - opted_out_notification_methods, - /*disconnect_sender*/ None, - ), - ); - - route_outgoing_envelope( - &mut connections, - OutgoingEnvelope::ToConnection { - connection_id, - message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( - ConfigWarningNotification { - summary: "task_started".to_string(), - details: None, - path: None, - range: None, - }, - )), - write_complete_tx: None, - }, - ) - .await; - - assert!( - writer_rx.try_recv().is_err(), - "opted-out notification should be dropped" - ); - } - - #[tokio::test] - async fn to_connection_notifications_are_dropped_for_opted_out_clients() { - let connection_id = ConnectionId(10); - let (writer_tx, mut writer_rx) = mpsc::channel(1); - - let mut connections = HashMap::new(); - connections.insert( - connection_id, - OutboundConnectionState::new( - writer_tx, - Arc::new(AtomicBool::new(true)), - Arc::new(AtomicBool::new(true)), - Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()]))), - /*disconnect_sender*/ None, - ), - ); - - route_outgoing_envelope( - &mut connections, - OutgoingEnvelope::ToConnection { - connection_id, - message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( - ConfigWarningNotification { - summary: "task_started".to_string(), - details: None, - path: None, - range: None, - }, - )), - write_complete_tx: None, - }, - ) - .await; - - assert!( - writer_rx.try_recv().is_err(), - "opted-out notifications should not reach clients" - ); - } - - #[tokio::test] - async fn to_connection_notifications_are_preserved_for_non_opted_out_clients() { - let connection_id = ConnectionId(11); - let (writer_tx, mut writer_rx) = mpsc::channel(1); - - let mut connections = HashMap::new(); - connections.insert( - connection_id, - OutboundConnectionState::new( - writer_tx, - Arc::new(AtomicBool::new(true)), - Arc::new(AtomicBool::new(true)), - Arc::new(RwLock::new(HashSet::new())), - /*disconnect_sender*/ None, - ), - ); - - route_outgoing_envelope( - &mut connections, - OutgoingEnvelope::ToConnection { - connection_id, - message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( - ConfigWarningNotification { - summary: "task_started".to_string(), - details: None, - path: None, - range: None, - }, - )), - write_complete_tx: None, - }, - ) - .await; - - let message = writer_rx - .recv() - .await - .expect("notification should reach non-opted-out clients"); - assert!(matches!( - message.message, - OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( - ConfigWarningNotification { summary, .. } - )) if summary == "task_started" - )); - } - - #[tokio::test] - async fn experimental_notifications_are_dropped_without_capability() { - let connection_id = ConnectionId(12); - let (writer_tx, mut writer_rx) = mpsc::channel(1); - - let mut connections = HashMap::new(); - connections.insert( - connection_id, - OutboundConnectionState::new( - writer_tx, - Arc::new(AtomicBool::new(true)), - Arc::new(AtomicBool::new(false)), - Arc::new(RwLock::new(HashSet::new())), - /*disconnect_sender*/ None, - ), - ); - - route_outgoing_envelope( - &mut connections, - OutgoingEnvelope::ToConnection { - connection_id, - message: OutgoingMessage::AppServerNotification(thread_goal_updated_notification()), - write_complete_tx: None, - }, - ) - .await; - - assert!( - writer_rx.try_recv().is_err(), - "experimental notifications should not reach clients without capability" - ); - } - - #[tokio::test] - async fn experimental_notifications_are_preserved_with_capability() { - let connection_id = ConnectionId(13); - let (writer_tx, mut writer_rx) = mpsc::channel(1); - - let mut connections = HashMap::new(); - connections.insert( - connection_id, - OutboundConnectionState::new( - writer_tx, - Arc::new(AtomicBool::new(true)), - Arc::new(AtomicBool::new(true)), - Arc::new(RwLock::new(HashSet::new())), - /*disconnect_sender*/ None, - ), - ); - - route_outgoing_envelope( - &mut connections, - OutgoingEnvelope::ToConnection { - connection_id, - message: OutgoingMessage::AppServerNotification(thread_goal_updated_notification()), - write_complete_tx: None, - }, - ) - .await; - - let message = writer_rx - .recv() - .await - .expect("experimental notification should reach opted-in client"); - assert!(matches!( - message.message, - OutgoingMessage::AppServerNotification(ServerNotification::ThreadGoalUpdated(_)) - )); - } - - #[tokio::test] - async fn command_execution_request_approval_strips_additional_permissions_without_capability() { - let connection_id = ConnectionId(8); - let (writer_tx, mut writer_rx) = mpsc::channel(1); - - let mut connections = HashMap::new(); - connections.insert( - connection_id, - OutboundConnectionState::new( - writer_tx, - Arc::new(AtomicBool::new(true)), - Arc::new(AtomicBool::new(false)), - Arc::new(RwLock::new(HashSet::new())), - /*disconnect_sender*/ None, - ), - ); - - route_outgoing_envelope( - &mut connections, - OutgoingEnvelope::ToConnection { - connection_id, - message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { - request_id: RequestId::Integer(1), - params: codex_app_server_protocol::CommandExecutionRequestApprovalParams { - thread_id: "thr_123".to_string(), - turn_id: "turn_123".to_string(), - item_id: "call_123".to_string(), - approval_id: None, - reason: Some("Need extra read access".to_string()), - network_approval_context: None, - command: Some("cat file".to_string()), - cwd: Some(absolute_path("/tmp")), - command_actions: None, - additional_permissions: Some( - codex_app_server_protocol::AdditionalPermissionProfile { - network: None, - file_system: Some( - codex_app_server_protocol::AdditionalFileSystemPermissions { - read: Some(vec![absolute_path("/tmp/allowed")]), - write: None, - glob_scan_max_depth: None, - entries: None, - }, - ), - }, - ), - proposed_execpolicy_amendment: None, - proposed_network_policy_amendments: None, - available_decisions: None, - }, - }), - write_complete_tx: None, - }, - ) - .await; - - let message = writer_rx - .recv() - .await - .expect("request should be delivered to the connection"); - let json = serde_json::to_value(message.message).expect("request should serialize"); - assert_eq!(json["params"].get("additionalPermissions"), None); - } - - #[tokio::test] - async fn command_execution_request_approval_keeps_additional_permissions_with_capability() { - let connection_id = ConnectionId(9); - let (writer_tx, mut writer_rx) = mpsc::channel(1); - - let mut connections = HashMap::new(); - connections.insert( - connection_id, - OutboundConnectionState::new( - writer_tx, - Arc::new(AtomicBool::new(true)), - Arc::new(AtomicBool::new(true)), - Arc::new(RwLock::new(HashSet::new())), - /*disconnect_sender*/ None, - ), - ); - - route_outgoing_envelope( - &mut connections, - OutgoingEnvelope::ToConnection { - connection_id, - message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { - request_id: RequestId::Integer(1), - params: codex_app_server_protocol::CommandExecutionRequestApprovalParams { - thread_id: "thr_123".to_string(), - turn_id: "turn_123".to_string(), - item_id: "call_123".to_string(), - approval_id: None, - reason: Some("Need extra read access".to_string()), - network_approval_context: None, - command: Some("cat file".to_string()), - cwd: Some(absolute_path("/tmp")), - command_actions: None, - additional_permissions: Some( - codex_app_server_protocol::AdditionalPermissionProfile { - network: None, - file_system: Some( - codex_app_server_protocol::AdditionalFileSystemPermissions { - read: Some(vec![absolute_path("/tmp/allowed")]), - write: None, - glob_scan_max_depth: None, - entries: None, - }, - ), - }, - ), - proposed_execpolicy_amendment: None, - proposed_network_policy_amendments: None, - available_decisions: None, - }, - }), - write_complete_tx: None, - }, - ) - .await; - - let message = writer_rx - .recv() - .await - .expect("request should be delivered to the connection"); - let json = serde_json::to_value(message.message).expect("request should serialize"); - let allowed_path = absolute_path("/tmp/allowed").to_string_lossy().into_owned(); - assert_eq!( - json["params"]["additionalPermissions"], - json!({ - "network": null, - "fileSystem": { - "read": [allowed_path], - "write": null, - }, - }) - ); - } - - #[tokio::test] - async fn broadcast_does_not_block_on_slow_connection() { - let fast_connection_id = ConnectionId(1); - let slow_connection_id = ConnectionId(2); - - let (fast_writer_tx, mut fast_writer_rx) = mpsc::channel(1); - let (slow_writer_tx, mut slow_writer_rx) = mpsc::channel(1); - let fast_disconnect_token = CancellationToken::new(); - let slow_disconnect_token = CancellationToken::new(); - - let mut connections = HashMap::new(); - connections.insert( - fast_connection_id, - OutboundConnectionState::new( - fast_writer_tx, - Arc::new(AtomicBool::new(true)), - Arc::new(AtomicBool::new(true)), - Arc::new(RwLock::new(HashSet::new())), - Some(fast_disconnect_token.clone()), - ), - ); - connections.insert( - slow_connection_id, - OutboundConnectionState::new( - slow_writer_tx.clone(), - Arc::new(AtomicBool::new(true)), - Arc::new(AtomicBool::new(true)), - Arc::new(RwLock::new(HashSet::new())), - Some(slow_disconnect_token.clone()), - ), - ); - - let queued_message = OutgoingMessage::AppServerNotification( - ServerNotification::ConfigWarning(ConfigWarningNotification { - summary: "already-buffered".to_string(), - details: None, - path: None, - range: None, - }), - ); - slow_writer_tx - .try_send(QueuedOutgoingMessage::new(queued_message)) - .expect("channel should have room"); - - let broadcast_message = OutgoingMessage::AppServerNotification( - ServerNotification::ConfigWarning(ConfigWarningNotification { - summary: "test".to_string(), - details: None, - path: None, - range: None, - }), - ); - timeout( - Duration::from_millis(100), - route_outgoing_envelope( - &mut connections, - OutgoingEnvelope::Broadcast { - message: broadcast_message, - }, - ), - ) - .await - .expect("broadcast should return even when one connection is slow"); - assert!(!connections.contains_key(&slow_connection_id)); - assert!(slow_disconnect_token.is_cancelled()); - assert!(!fast_disconnect_token.is_cancelled()); - let fast_message = fast_writer_rx - .try_recv() - .expect("fast connection should receive the broadcast notification"); - assert!(matches!( - fast_message.message, - OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( - ConfigWarningNotification { summary, .. } - )) if summary == "test" - )); - - let slow_message = slow_writer_rx - .try_recv() - .expect("slow connection should retain its original buffered message"); - assert!(matches!( - slow_message.message, - OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( - ConfigWarningNotification { summary, .. } - )) if summary == "already-buffered" - )); - } - - #[tokio::test] - async fn to_connection_stdio_waits_instead_of_disconnecting_when_writer_queue_is_full() { - let connection_id = ConnectionId(3); - let (writer_tx, mut writer_rx) = mpsc::channel(1); - writer_tx - .send(QueuedOutgoingMessage::new( - OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( - ConfigWarningNotification { - summary: "queued".to_string(), - details: None, - path: None, - range: None, - }, - )), - )) - .await - .expect("channel should accept the first queued message"); - - let mut connections = HashMap::new(); - connections.insert( - connection_id, - OutboundConnectionState::new( - writer_tx, - Arc::new(AtomicBool::new(true)), - Arc::new(AtomicBool::new(true)), - Arc::new(RwLock::new(HashSet::new())), - /*disconnect_sender*/ None, - ), - ); - - let route_task = tokio::spawn(async move { - route_outgoing_envelope( - &mut connections, - OutgoingEnvelope::ToConnection { - connection_id, - message: OutgoingMessage::AppServerNotification( - ServerNotification::ConfigWarning(ConfigWarningNotification { - summary: "second".to_string(), - details: None, - path: None, - range: None, - }), - ), - write_complete_tx: None, - }, - ) - .await - }); - - let first = timeout(Duration::from_millis(100), writer_rx.recv()) - .await - .expect("first queued message should be readable") - .expect("first queued message should exist"); - timeout(Duration::from_millis(100), route_task) - .await - .expect("routing should finish after the first queued message is drained") - .expect("routing task should succeed"); - - assert!(matches!( - first.message, - OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( - ConfigWarningNotification { summary, .. } - )) if summary == "queued" - )); - let second = writer_rx - .try_recv() - .expect("second notification should be delivered once the queue has room"); - assert!(matches!( - second.message, - OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( - ConfigWarningNotification { summary, .. } - )) if summary == "second" - )); - } -} diff --git a/codex-rs/app-server/src/transport_tests.rs b/codex-rs/app-server/src/transport_tests.rs new file mode 100644 index 000000000..1600b8be8 --- /dev/null +++ b/codex-rs/app-server/src/transport_tests.rs @@ -0,0 +1,532 @@ +use super::*; +use codex_app_server_protocol::ConfigWarningNotification; +use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::ServerNotification; +use codex_app_server_protocol::ThreadGoal; +use codex_app_server_protocol::ThreadGoalStatus; +use codex_app_server_protocol::ThreadGoalUpdatedNotification; +use codex_utils_absolute_path::AbsolutePathBuf; +use pretty_assertions::assert_eq; +use serde_json::json; +use tokio::time::Duration; +use tokio::time::timeout; + +fn absolute_path(path: &str) -> AbsolutePathBuf { + AbsolutePathBuf::from_absolute_path(path).expect("absolute path") +} + +fn thread_goal_updated_notification() -> ServerNotification { + ServerNotification::ThreadGoalUpdated(ThreadGoalUpdatedNotification { + thread_id: "thread-1".to_string(), + turn_id: None, + goal: ThreadGoal { + thread_id: "thread-1".to_string(), + objective: "ship goal mode".to_string(), + status: ThreadGoalStatus::Active, + token_budget: None, + tokens_used: 0, + time_used_seconds: 0, + created_at: 1, + updated_at: 1, + }, + }) +} + +#[tokio::test] +async fn to_connection_notification_respects_opt_out_filters() { + let connection_id = ConnectionId(7); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + let initialized = Arc::new(AtomicBool::new(true)); + let opted_out_notification_methods = + Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()]))); + + let mut connections = HashMap::new(); + connections.insert( + connection_id, + OutboundConnectionState::new( + writer_tx, + initialized, + Arc::new(AtomicBool::new(true)), + opted_out_notification_methods, + /*disconnect_sender*/ None, + ), + ); + + route_outgoing_envelope( + &mut connections, + OutgoingEnvelope::ToConnection { + connection_id, + message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "task_started".to_string(), + details: None, + path: None, + range: None, + }, + )), + write_complete_tx: None, + }, + ) + .await; + + assert!( + writer_rx.try_recv().is_err(), + "opted-out notification should be dropped" + ); +} + +#[tokio::test] +async fn to_connection_notifications_are_dropped_for_opted_out_clients() { + let connection_id = ConnectionId(10); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + let mut connections = HashMap::new(); + connections.insert( + connection_id, + OutboundConnectionState::new( + writer_tx, + Arc::new(AtomicBool::new(true)), + Arc::new(AtomicBool::new(true)), + Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()]))), + /*disconnect_sender*/ None, + ), + ); + + route_outgoing_envelope( + &mut connections, + OutgoingEnvelope::ToConnection { + connection_id, + message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "task_started".to_string(), + details: None, + path: None, + range: None, + }, + )), + write_complete_tx: None, + }, + ) + .await; + + assert!( + writer_rx.try_recv().is_err(), + "opted-out notifications should not reach clients" + ); +} + +#[tokio::test] +async fn to_connection_notifications_are_preserved_for_non_opted_out_clients() { + let connection_id = ConnectionId(11); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + let mut connections = HashMap::new(); + connections.insert( + connection_id, + OutboundConnectionState::new( + writer_tx, + Arc::new(AtomicBool::new(true)), + Arc::new(AtomicBool::new(true)), + Arc::new(RwLock::new(HashSet::new())), + /*disconnect_sender*/ None, + ), + ); + + route_outgoing_envelope( + &mut connections, + OutgoingEnvelope::ToConnection { + connection_id, + message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "task_started".to_string(), + details: None, + path: None, + range: None, + }, + )), + write_complete_tx: None, + }, + ) + .await; + + let message = writer_rx + .recv() + .await + .expect("notification should reach non-opted-out clients"); + assert!(matches!( + message.message, + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { summary, .. } + )) if summary == "task_started" + )); +} + +#[tokio::test] +async fn experimental_notifications_are_dropped_without_capability() { + let connection_id = ConnectionId(12); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + let mut connections = HashMap::new(); + connections.insert( + connection_id, + OutboundConnectionState::new( + writer_tx, + Arc::new(AtomicBool::new(true)), + Arc::new(AtomicBool::new(false)), + Arc::new(RwLock::new(HashSet::new())), + /*disconnect_sender*/ None, + ), + ); + + route_outgoing_envelope( + &mut connections, + OutgoingEnvelope::ToConnection { + connection_id, + message: OutgoingMessage::AppServerNotification(thread_goal_updated_notification()), + write_complete_tx: None, + }, + ) + .await; + + assert!( + writer_rx.try_recv().is_err(), + "experimental notifications should not reach clients without capability" + ); +} + +#[tokio::test] +async fn experimental_notifications_are_preserved_with_capability() { + let connection_id = ConnectionId(13); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + let mut connections = HashMap::new(); + connections.insert( + connection_id, + OutboundConnectionState::new( + writer_tx, + Arc::new(AtomicBool::new(true)), + Arc::new(AtomicBool::new(true)), + Arc::new(RwLock::new(HashSet::new())), + /*disconnect_sender*/ None, + ), + ); + + route_outgoing_envelope( + &mut connections, + OutgoingEnvelope::ToConnection { + connection_id, + message: OutgoingMessage::AppServerNotification(thread_goal_updated_notification()), + write_complete_tx: None, + }, + ) + .await; + + let message = writer_rx + .recv() + .await + .expect("experimental notification should reach opted-in client"); + assert!(matches!( + message.message, + OutgoingMessage::AppServerNotification(ServerNotification::ThreadGoalUpdated(_)) + )); +} + +#[tokio::test] +async fn command_execution_request_approval_strips_additional_permissions_without_capability() { + let connection_id = ConnectionId(8); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + let mut connections = HashMap::new(); + connections.insert( + connection_id, + OutboundConnectionState::new( + writer_tx, + Arc::new(AtomicBool::new(true)), + Arc::new(AtomicBool::new(false)), + Arc::new(RwLock::new(HashSet::new())), + /*disconnect_sender*/ None, + ), + ); + + route_outgoing_envelope( + &mut connections, + OutgoingEnvelope::ToConnection { + connection_id, + message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { + request_id: RequestId::Integer(1), + params: codex_app_server_protocol::CommandExecutionRequestApprovalParams { + thread_id: "thr_123".to_string(), + turn_id: "turn_123".to_string(), + item_id: "call_123".to_string(), + approval_id: None, + reason: Some("Need extra read access".to_string()), + network_approval_context: None, + command: Some("cat file".to_string()), + cwd: Some(absolute_path("/tmp")), + command_actions: None, + additional_permissions: Some( + codex_app_server_protocol::AdditionalPermissionProfile { + network: None, + file_system: Some( + codex_app_server_protocol::AdditionalFileSystemPermissions { + read: Some(vec![absolute_path("/tmp/allowed")]), + write: None, + glob_scan_max_depth: None, + entries: None, + }, + ), + }, + ), + proposed_execpolicy_amendment: None, + proposed_network_policy_amendments: None, + available_decisions: None, + }, + }), + write_complete_tx: None, + }, + ) + .await; + + let message = writer_rx + .recv() + .await + .expect("request should be delivered to the connection"); + let json = serde_json::to_value(message.message).expect("request should serialize"); + assert_eq!(json["params"].get("additionalPermissions"), None); +} + +#[tokio::test] +async fn command_execution_request_approval_keeps_additional_permissions_with_capability() { + let connection_id = ConnectionId(9); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + let mut connections = HashMap::new(); + connections.insert( + connection_id, + OutboundConnectionState::new( + writer_tx, + Arc::new(AtomicBool::new(true)), + Arc::new(AtomicBool::new(true)), + Arc::new(RwLock::new(HashSet::new())), + /*disconnect_sender*/ None, + ), + ); + + route_outgoing_envelope( + &mut connections, + OutgoingEnvelope::ToConnection { + connection_id, + message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { + request_id: RequestId::Integer(1), + params: codex_app_server_protocol::CommandExecutionRequestApprovalParams { + thread_id: "thr_123".to_string(), + turn_id: "turn_123".to_string(), + item_id: "call_123".to_string(), + approval_id: None, + reason: Some("Need extra read access".to_string()), + network_approval_context: None, + command: Some("cat file".to_string()), + cwd: Some(absolute_path("/tmp")), + command_actions: None, + additional_permissions: Some( + codex_app_server_protocol::AdditionalPermissionProfile { + network: None, + file_system: Some( + codex_app_server_protocol::AdditionalFileSystemPermissions { + read: Some(vec![absolute_path("/tmp/allowed")]), + write: None, + glob_scan_max_depth: None, + entries: None, + }, + ), + }, + ), + proposed_execpolicy_amendment: None, + proposed_network_policy_amendments: None, + available_decisions: None, + }, + }), + write_complete_tx: None, + }, + ) + .await; + + let message = writer_rx + .recv() + .await + .expect("request should be delivered to the connection"); + let json = serde_json::to_value(message.message).expect("request should serialize"); + let allowed_path = absolute_path("/tmp/allowed").to_string_lossy().into_owned(); + assert_eq!( + json["params"]["additionalPermissions"], + json!({ + "network": null, + "fileSystem": { + "read": [allowed_path], + "write": null, + }, + }) + ); +} + +#[tokio::test] +async fn broadcast_does_not_block_on_slow_connection() { + let fast_connection_id = ConnectionId(1); + let slow_connection_id = ConnectionId(2); + + let (fast_writer_tx, mut fast_writer_rx) = mpsc::channel(1); + let (slow_writer_tx, mut slow_writer_rx) = mpsc::channel(1); + let fast_disconnect_token = CancellationToken::new(); + let slow_disconnect_token = CancellationToken::new(); + + let mut connections = HashMap::new(); + connections.insert( + fast_connection_id, + OutboundConnectionState::new( + fast_writer_tx, + Arc::new(AtomicBool::new(true)), + Arc::new(AtomicBool::new(true)), + Arc::new(RwLock::new(HashSet::new())), + Some(fast_disconnect_token.clone()), + ), + ); + connections.insert( + slow_connection_id, + OutboundConnectionState::new( + slow_writer_tx.clone(), + Arc::new(AtomicBool::new(true)), + Arc::new(AtomicBool::new(true)), + Arc::new(RwLock::new(HashSet::new())), + Some(slow_disconnect_token.clone()), + ), + ); + + let queued_message = OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "already-buffered".to_string(), + details: None, + path: None, + range: None, + }, + )); + slow_writer_tx + .try_send(QueuedOutgoingMessage::new(queued_message)) + .expect("channel should have room"); + + let broadcast_message = OutgoingMessage::AppServerNotification( + ServerNotification::ConfigWarning(ConfigWarningNotification { + summary: "test".to_string(), + details: None, + path: None, + range: None, + }), + ); + timeout( + Duration::from_millis(100), + route_outgoing_envelope( + &mut connections, + OutgoingEnvelope::Broadcast { + message: broadcast_message, + }, + ), + ) + .await + .expect("broadcast should return even when one connection is slow"); + assert!(!connections.contains_key(&slow_connection_id)); + assert!(slow_disconnect_token.is_cancelled()); + assert!(!fast_disconnect_token.is_cancelled()); + let fast_message = fast_writer_rx + .try_recv() + .expect("fast connection should receive the broadcast notification"); + assert!(matches!( + fast_message.message, + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { summary, .. } + )) if summary == "test" + )); + + let slow_message = slow_writer_rx + .try_recv() + .expect("slow connection should retain its original buffered message"); + assert!(matches!( + slow_message.message, + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { summary, .. } + )) if summary == "already-buffered" + )); +} + +#[tokio::test] +async fn to_connection_stdio_waits_instead_of_disconnecting_when_writer_queue_is_full() { + let connection_id = ConnectionId(3); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + writer_tx + .send(QueuedOutgoingMessage::new( + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "queued".to_string(), + details: None, + path: None, + range: None, + }, + )), + )) + .await + .expect("channel should accept the first queued message"); + + let mut connections = HashMap::new(); + connections.insert( + connection_id, + OutboundConnectionState::new( + writer_tx, + Arc::new(AtomicBool::new(true)), + Arc::new(AtomicBool::new(true)), + Arc::new(RwLock::new(HashSet::new())), + /*disconnect_sender*/ None, + ), + ); + + let route_task = tokio::spawn(async move { + route_outgoing_envelope( + &mut connections, + OutgoingEnvelope::ToConnection { + connection_id, + message: OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "second".to_string(), + details: None, + path: None, + range: None, + }, + )), + write_complete_tx: None, + }, + ) + .await + }); + + let first = timeout(Duration::from_millis(100), writer_rx.recv()) + .await + .expect("first queued message should be readable") + .expect("first queued message should exist"); + timeout(Duration::from_millis(100), route_task) + .await + .expect("routing should finish after the first queued message is drained") + .expect("routing task should succeed"); + + assert!(matches!( + first.message, + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { summary, .. } + )) if summary == "queued" + )); + let second = writer_rx + .try_recv() + .expect("second notification should be delivered once the queue has room"); + assert!(matches!( + second.message, + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { summary, .. } + )) if summary == "second" + )); +}