diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 18cddfe84..8cb094fb2 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; use std::collections::HashMap; use std::sync::Arc; use std::sync::Mutex as StdMutex; +use std::sync::OnceLock; use std::time::Duration; use arc_swap::ArcSwap; @@ -140,6 +141,10 @@ struct Inner { // need serialization so concurrent register/remove operations do not // overwrite each other's copy-on-write updates. sessions_write_lock: Mutex<()>, + // Once the transport closes, every executor operation should fail quickly + // with the same canonical message. This client never reconnects, so the + // latch only moves from unset to set once. + disconnected: OnceLock, session_id: std::sync::RwLock>, reader_task: tokio::task::JoinHandle<()>, } @@ -171,6 +176,8 @@ pub enum ExecServerError { InitializeTimedOut { timeout: Duration }, #[error("exec-server transport closed")] Closed, + #[error("{0}")] + Disconnected(String), #[error("failed to serialize or deserialize exec-server JSON: {0}")] Json(#[from] serde_json::Error), #[error("exec-server protocol error: {0}")] @@ -246,19 +253,11 @@ impl ExecServerClient { } pub async fn exec(&self, params: ExecParams) -> Result { - self.inner - .client - .call(EXEC_METHOD, &params) - .await - .map_err(Into::into) + self.call(EXEC_METHOD, &params).await } pub async fn read(&self, params: ReadParams) -> Result { - self.inner - .client - .call(EXEC_READ_METHOD, &params) - .await - .map_err(Into::into) + self.call(EXEC_READ_METHOD, &params).await } pub async fn write( @@ -266,107 +265,73 @@ impl ExecServerClient { process_id: &ProcessId, chunk: Vec, ) -> Result { - self.inner - .client - .call( - EXEC_WRITE_METHOD, - &WriteParams { - process_id: process_id.clone(), - chunk: chunk.into(), - }, - ) - .await - .map_err(Into::into) + self.call( + EXEC_WRITE_METHOD, + &WriteParams { + process_id:
process_id.clone(), + chunk: chunk.into(), + }, + ) + .await } pub async fn terminate( &self, process_id: &ProcessId, ) -> Result { - self.inner - .client - .call( - EXEC_TERMINATE_METHOD, - &TerminateParams { - process_id: process_id.clone(), - }, - ) - .await - .map_err(Into::into) + self.call( + EXEC_TERMINATE_METHOD, + &TerminateParams { + process_id: process_id.clone(), + }, + ) + .await } pub async fn fs_read_file( &self, params: FsReadFileParams, ) -> Result { - self.inner - .client - .call(FS_READ_FILE_METHOD, &params) - .await - .map_err(Into::into) + self.call(FS_READ_FILE_METHOD, &params).await } pub async fn fs_write_file( &self, params: FsWriteFileParams, ) -> Result { - self.inner - .client - .call(FS_WRITE_FILE_METHOD, &params) - .await - .map_err(Into::into) + self.call(FS_WRITE_FILE_METHOD, &params).await } pub async fn fs_create_directory( &self, params: FsCreateDirectoryParams, ) -> Result { - self.inner - .client - .call(FS_CREATE_DIRECTORY_METHOD, &params) - .await - .map_err(Into::into) + self.call(FS_CREATE_DIRECTORY_METHOD, &params).await } pub async fn fs_get_metadata( &self, params: FsGetMetadataParams, ) -> Result { - self.inner - .client - .call(FS_GET_METADATA_METHOD, &params) - .await - .map_err(Into::into) + self.call(FS_GET_METADATA_METHOD, &params).await } pub async fn fs_read_directory( &self, params: FsReadDirectoryParams, ) -> Result { - self.inner - .client - .call(FS_READ_DIRECTORY_METHOD, &params) - .await - .map_err(Into::into) + self.call(FS_READ_DIRECTORY_METHOD, &params).await } pub async fn fs_remove( &self, params: FsRemoveParams, ) -> Result { - self.inner - .client - .call(FS_REMOVE_METHOD, &params) - .await - .map_err(Into::into) + self.call(FS_REMOVE_METHOD, &params).await } pub async fn fs_copy(&self, params: FsCopyParams) -> Result { - self.inner - .client - .call(FS_COPY_METHOD, &params) - .await - .map_err(Into::into) + self.call(FS_COPY_METHOD, &params).await } pub(crate) async fn register_session( @@ -411,18 +376,21 @@ impl ExecServerClient { && let Err(err) =
handle_server_notification(&inner, notification).await { - fail_all_sessions( + let message = record_disconnected( &inner, format!("exec-server notification handling failed: {err}"), - ) - .await; + ); + fail_all_sessions(&inner, message).await; return; } } RpcClientEvent::Disconnected { reason } => { if let Some(inner) = weak.upgrade() { - fail_all_sessions(&inner, disconnected_message(reason.as_deref())) - .await; + let message = record_disconnected( + &inner, + disconnected_message(reason.as_deref()), + ); + fail_all_sessions(&inner, message).await; } return; } @@ -434,6 +402,7 @@ impl ExecServerClient { client: rpc_client, sessions: ArcSwap::from_pointee(HashMap::new()), sessions_write_lock: Mutex::new(()), + disconnected: OnceLock::new(), session_id: std::sync::RwLock::new(None), reader_task, } @@ -451,6 +420,36 @@ impl ExecServerClient { .await .map_err(ExecServerError::Json) } + + async fn call(&self, method: &str, params: &P) -> Result + where + P: serde::Serialize, + T: serde::de::DeserializeOwned, + { + // Reject new work before allocating a JSON-RPC request id. MCP tool + // calls, process writes, and fs operations all pass through here, so + // this is the shared low-level failure path after executor disconnect. + if let Some(error) = self.inner.disconnected_error() { + return Err(error); + } + + match self.inner.client.call(method, params).await { + Ok(response) => Ok(response), + Err(error) => { + let error = ExecServerError::from(error); + if is_transport_closed_error(&error) { + // A call can race with disconnect after the preflight + // check. Only the reader task drains sessions so queued + // process notifications stay ordered before disconnect. 
+ let message = disconnected_message(/*reason*/ None); + let message = record_disconnected(&self.inner, message); + Err(ExecServerError::Disconnected(message)) + } else { + Err(error) + } + } + } + } } impl From for ExecServerError { @@ -630,6 +629,20 @@ impl Session { } impl Inner { + fn disconnected_error(&self) -> Option { + self.disconnected + .get() + .cloned() + .map(ExecServerError::Disconnected) + } + + fn set_disconnected(&self, message: String) -> Option { + match self.disconnected.set(message.clone()) { + Ok(()) => Some(message), + Err(_) => None, + } + } + fn get_session(&self, process_id: &ProcessId) -> Option> { self.sessions.load().get(process_id).cloned() } @@ -640,6 +653,12 @@ impl Inner { session: Arc, ) -> Result<(), ExecServerError> { let _sessions_write_guard = self.sessions_write_lock.lock().await; + // Do not register a process session that can never receive executor + // notifications. Without this check, remote MCP startup could create a + // dead session and wait for process output that will never arrive. + if let Some(error) = self.disconnected_error() { + return Err(error); + } let sessions = self.sessions.load(); if sessions.contains_key(process_id) { return Err(ExecServerError::Protocol(format!( @@ -680,20 +699,36 @@ fn disconnected_message(reason: Option<&str>) -> String { } fn is_transport_closed_error(error: &ExecServerError) -> bool { - matches!(error, ExecServerError::Closed) - || matches!( - error, - ExecServerError::Server { - code: -32000, - message, - } if message == "JSON-RPC transport closed" - ) + matches!( + error, + ExecServerError::Closed | ExecServerError::Disconnected(_) + ) || matches!( + error, + ExecServerError::Server { + code: -32000, + message, + } if message == "JSON-RPC transport closed" + ) +} + +fn record_disconnected(inner: &Arc, message: String) -> String { + // The first observer records the canonical disconnect reason. 
Session + // draining stays with the reader task so it can preserve notification + // ordering before publishing the terminal failure. + if let Some(message) = inner.set_disconnected(message.clone()) { + message + } else { + inner.disconnected.get().cloned().unwrap_or(message) + } } async fn fail_all_sessions(inner: &Arc, message: String) { let sessions = inner.take_all_sessions().await; for (_, session) in sessions { + // Sessions synthesize a closed read response and emit a pushed Failed + // event. That covers both polling consumers and streaming consumers + // such as executor-backed MCP stdio. session.set_failure(message.clone()).await; } } diff --git a/codex-rs/exec-server/src/remote_file_system.rs b/codex-rs/exec-server/src/remote_file_system.rs index 111e8d603..d6a32ba4d 100644 --- a/codex-rs/exec-server/src/remote_file_system.rs +++ b/codex-rs/exec-server/src/remote_file_system.rs @@ -195,9 +195,46 @@ fn map_remote_error(error: ExecServerError) -> io::Error { io::Error::new(io::ErrorKind::InvalidInput, message) } ExecServerError::Server { message, .. 
} => io::Error::other(message), - ExecServerError::Closed => { + ExecServerError::Closed | ExecServerError::Disconnected(_) => { io::Error::new(io::ErrorKind::BrokenPipe, "exec-server transport closed") } _ => io::Error::other(error.to_string()), } } + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn transport_errors_map_to_broken_pipe() { + let errors = [ + ExecServerError::Closed, + ExecServerError::Disconnected("exec-server transport disconnected".to_string()), + ]; + + let mapped_errors = errors + .into_iter() + .map(|error| { + let error = map_remote_error(error); + (error.kind(), error.to_string()) + }) + .collect::>(); + + assert_eq!( + mapped_errors, + vec![ + ( + io::ErrorKind::BrokenPipe, + "exec-server transport closed".to_string() + ), + ( + io::ErrorKind::BrokenPipe, + "exec-server transport closed".to_string() + ), + ] + ); + } +} diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index 30a3f70bb..e82b4a0ea 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -18,12 +18,23 @@ use serde_json::Value; use tokio::sync::Mutex; use tokio::sync::mpsc; use tokio::sync::oneshot; +use tokio::sync::watch; use tokio::task::JoinHandle; use crate::connection::JsonRpcConnection; use crate::connection::JsonRpcConnectionEvent; -type PendingRequest = oneshot::Sender>; +#[derive(Debug)] +pub(crate) enum RpcCallError { + /// The underlying JSON-RPC transport closed before this call completed. + Closed, + /// The response bytes were valid JSON-RPC but not the expected result type. + Json(serde_json::Error), + /// The executor returned a JSON-RPC error response for this call. 
+ Server(JSONRPCErrorError), +} + +type PendingRequest = oneshot::Sender>; type BoxFuture = Pin + Send + 'static>>; type RequestRoute = Box, JSONRPCRequest) -> BoxFuture + Send + Sync>; @@ -172,6 +183,10 @@ where pub(crate) struct RpcClient { write_tx: mpsc::Sender, pending: Arc>>, + // Shared transport state from `JsonRpcConnection`. Calls use this to fail + // immediately when the socket closes, even if no JSON-RPC error response + // can be delivered for their request id. + disconnected_rx: watch::Receiver, next_request_id: AtomicI64, transport_tasks: Vec>, reader_task: JoinHandle<()>, @@ -179,8 +194,7 @@ pub(crate) struct RpcClient { impl RpcClient { pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver) { - let (write_tx, mut incoming_rx, _disconnected_rx, transport_tasks) = - connection.into_parts(); + let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) = connection.into_parts(); let pending = Arc::new(Mutex::new(HashMap::::new())); let (event_tx, event_rx) = mpsc::channel(128); @@ -218,6 +232,7 @@ impl RpcClient { Self { write_tx, pending, + disconnected_rx, next_request_id: AtomicI64::new(1), transport_tasks, reader_task, @@ -253,10 +268,16 @@ impl RpcClient { { let request_id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::SeqCst)); let (response_tx, response_rx) = oneshot::channel(); - self.pending - .lock() - .await - .insert(request_id.clone(), response_tx); + { + let mut pending = self.pending.lock().await; + // Registering the pending request and checking disconnect must be + // atomic with the reader's drain_pending path. Otherwise a call + // can sneak in after the drain and wait forever. 
+ if *self.disconnected_rx.borrow() { + return Err(RpcCallError::Closed); + } + pending.insert(request_id.clone(), response_tx); + } let params = match serde_json::to_value(params) { Ok(params) => params, @@ -280,10 +301,17 @@ impl RpcClient { return Err(RpcCallError::Closed); } - let result = response_rx.await.map_err(|_| RpcCallError::Closed)?; + // Do not race in-flight requests directly against the transport-close + // watch value. The connection reader receives JSON-RPC messages and + // the terminal disconnect event on one ordered queue, then drains any + // still-pending requests. Awaiting this receiver preserves that order: + // responses already read before EOF still win, and truly pending calls + // are failed once the reader observes the disconnect. + let result: Result = + response_rx.await.map_err(|_| RpcCallError::Closed)?; let response = match result { Ok(response) => response, - Err(error) => return Err(RpcCallError::Server(error)), + Err(error) => return Err(error), }; serde_json::from_value(response).map_err(RpcCallError::Json) } @@ -304,13 +332,6 @@ impl Drop for RpcClient { } } -#[derive(Debug)] -pub(crate) enum RpcCallError { - Closed, - Json(serde_json::Error), - Server(JSONRPCErrorError), -} - pub(crate) fn encode_server_message( message: RpcServerOutboundMessage, ) -> Result { @@ -417,7 +438,7 @@ async fn handle_server_message( } JSONRPCMessage::Error(JSONRPCError { id, error }) => { if let Some(pending) = pending.lock().await.remove(&id) { - let _ = pending.send(Err(error)); + let _ = pending.send(Err(RpcCallError::Server(error))); } } JSONRPCMessage::Notification(notification) => { @@ -445,11 +466,7 @@ async fn drain_pending(pending: &Mutex>) { .collect::>() }; for pending in pending { - let _ = pending.send(Err(JSONRPCErrorError { - code: -32000, - data: None, - message: "JSON-RPC transport closed".to_string(), - })); + let _ = pending.send(Err(RpcCallError::Closed)); } } diff --git a/codex-rs/exec-server/tests/exec_process.rs 
b/codex-rs/exec-server/tests/exec_process.rs index 94e78b42b..d449315c8 100644 --- a/codex-rs/exec-server/tests/exec_process.rs +++ b/codex-rs/exec-server/tests/exec_process.rs @@ -4,6 +4,7 @@ mod common; use std::sync::Arc; +use anyhow::Context; use anyhow::Result; use codex_exec_server::Environment; use codex_exec_server::ExecBackend; @@ -484,6 +485,16 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> { let process = Arc::clone(&session.process); let mut events = process.subscribe_events(); + let process_for_pending_read = Arc::clone(&process); + let pending_read = tokio::spawn(async move { + process_for_pending_read + .read( + /*after_seq*/ None, + /*max_bytes*/ None, + /*wait_ms*/ Some(60_000), + ) + .await + }); let server = context .server .as_mut() @@ -499,6 +510,15 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> { "unexpected failure event: {event_message}" ); + let pending_response = timeout(Duration::from_secs(2), pending_read).await???; + let pending_message = pending_response + .failure + .expect("pending read should surface disconnect as a failure"); + assert!( + pending_message.starts_with("exec-server transport disconnected"), + "unexpected pending failure message: {pending_message}" + ); + let mut wake_rx = process.subscribe_wake(); let response = read_process_until_change(process, &mut wake_rx, /*after_seq*/ None).await?; let message = response @@ -513,6 +533,20 @@ async fn remote_exec_process_reports_transport_disconnect() -> Result<()> { "disconnect should close the process session" ); + let write_result = timeout( + Duration::from_secs(2), + session.process.write(b"hello".to_vec()), + ) + .await + .context("timed out waiting for write after disconnect")?; + let write_error = write_result.expect_err("write after disconnect should fail"); + assert!( + write_error + .to_string() + .starts_with("exec-server transport disconnected"), + "unexpected write error: {write_error}" + ); + Ok(()) }