diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index f6c531845..ae77dc7c4 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -3,7 +3,9 @@ use std::collections::HashMap; use std::sync::Arc; use std::sync::Mutex as StdMutex; use std::sync::OnceLock; +use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; use std::time::Duration; use arc_swap::ArcSwap; @@ -25,6 +27,7 @@ use crate::client_api::ExecServerTransportParams; use crate::client_api::HttpClient; use crate::client_api::RemoteExecServerConnectArgs; use crate::client_api::StdioExecServerConnectArgs; +use crate::client_transport::ExecServerReconnectStrategy; use crate::connection::JsonRpcConnection; use crate::process::ExecProcessEvent; use crate::process::ExecProcessEventLog; @@ -95,9 +98,10 @@ use crate::protocol::WriteParams; use crate::protocol::WriteResponse; use crate::rpc::RpcCallError; use crate::rpc::RpcClient; -use crate::rpc::RpcClientEvent; pub(crate) mod http_client; +#[path = "client_recovery.rs"] +mod recovery; const CONNECT_TIMEOUT: Duration = Duration::from_secs(10); const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10); @@ -150,16 +154,19 @@ pub(crate) struct SessionState { wake_tx: watch::Sender, events: ExecProcessEventLog, ordered_events: StdMutex, - failure: Mutex>, + recoverable: AtomicBool, } #[derive(Default)] struct OrderedSessionEvents { last_published_seq: u64, + exit_published: bool, + closed_published: bool, // Server-side output, exit, and closed notifications are emitted by // different tasks and can reach the client out of order. Keep future events // here until all lower sequence numbers have been published. pending: BTreeMap, + failure: Option, } #[derive(Clone)] @@ -170,7 +177,8 @@ pub(crate) struct Session { } struct Inner { - client: RpcClient, + connection: StdMutex, + connection_changed: watch::Sender<()>, // The remote transport delivers one shared notification stream for every // process on the connection. Keep a local process_id -> session registry so // we can turn those connection-global notifications into process wakeups @@ -179,11 +187,7 @@ struct Inner { // ArcSwap makes reads cheap on the hot notification path, but writes still // need serialization so concurrent register/remove operations do not // overwrite each other's copy-on-write updates. - sessions_write_lock: Mutex<()>, - // Once the transport closes, every environment operation should fail quickly - // with the same canonical message. This client never reconnects, so the - // latch only moves from unset to set once. - disconnected: OnceLock, + sessions_write_lock: StdMutex<()>, // Streaming HTTP responses are keyed by a client-generated request id // because they share the same connection-global notification channel as // process output. Keep the routing table local to the client so higher @@ -192,14 +196,19 @@ struct Inner { http_body_stream_failures: ArcSwap>, http_body_streams_write_lock: Mutex<()>, http_body_stream_next_id: AtomicU64, - session_id: std::sync::RwLock>, - reader_task: tokio::task::JoinHandle<()>, + session_id: OnceLock, + reconnect_strategy: Option, } -impl Drop for Inner { - fn drop(&mut self) { - self.reader_task.abort(); - } +struct ConnectionState { + status: ConnectionStatus, + active_process_starts: usize, +} + +enum ConnectionStatus { + Connected(Arc), + Recovering, + Failed(String), } #[derive(Clone)] @@ -207,6 +216,16 @@ pub struct ExecServerClient { inner: Arc, } +struct ActiveProcessStart { + inner: Arc, +} + +impl Drop for ActiveProcessStart { + fn drop(&mut self) { + self.inner.finish_process_start(); + } +} + #[derive(Clone)] pub(crate) struct LazyRemoteExecServerClient { transport_params: ExecServerTransportParams, @@ -339,6 +358,15 @@ impl ExecServerClient { pub async fn initialize( &self, options: ExecServerClientConnectOptions, + ) -> Result { + let rpc_client = self.inner.rpc_client().await?; + self.initialize_rpc(&rpc_client, options).await + } + + async fn initialize_rpc( + &self, + rpc_client: &RpcClient, + options: ExecServerClientConnectOptions, ) -> Result { let ExecServerClientConnectOptions { client_name, @@ -347,9 +375,7 @@ impl ExecServerClient { } = options; timeout(initialize_timeout, async { - let response: InitializeResponse = self - .inner - .client + let response: InitializeResponse = rpc_client .call( INITIALIZE_METHOD, &InitializeParams { @@ -358,15 +384,19 @@ impl ExecServerClient { }, ) .await?; - { - let mut session_id = self - .inner - .session_id - .write() - .unwrap_or_else(std::sync::PoisonError::into_inner); - *session_id = Some(response.session_id.clone()); + let session_id = self + .inner + .session_id + .get_or_init(|| response.session_id.clone()); + if session_id != &response.session_id { + return Err(ExecServerError::Protocol(format!( + "exec-server initialized an unexpected session {}", + response.session_id + ))); } - self.notify_initialized().await?; + rpc_client + .notify(INITIALIZED_METHOD, &serde_json::json!({})) + .await?; Ok(response) }) .await @@ -503,14 +533,72 @@ impl ExecServerClient { self.call(FS_COPY_METHOD, ¶ms).await } + pub(crate) async fn start_process( + &self, + params: ExecParams, + ) -> Result { + loop { + let rpc_client = self.inner.rpc_client().await?; + if !self.inner.begin_process_start(&rpc_client) { + continue; + } + + let process_id = params.process_id.clone(); + let state = Arc::new(SessionState::new(/*recoverable*/ false)); + if let Err(error) = self.inner.insert_session(&process_id, Arc::clone(&state)) { + self.inner.finish_process_start(); + return Err(error); + } + let active_start = ActiveProcessStart { + inner: Arc::clone(&self.inner), + }; + let client = self.clone(); + let (result_tx, result_rx) = tokio::sync::oneshot::channel(); + tokio::spawn(async move { + let _active_start = active_start; + match client + .call_rpc::<_, ExecResponse>(&rpc_client, EXEC_METHOD, ¶ms) + .await + { + Ok(_) => { + state.recoverable.store(true, Ordering::Release); + let session = Session { + client: client.clone(), + process_id: process_id.clone(), + state: Arc::clone(&state), + }; + if result_tx.send(Ok(session)).is_err() { + state.recoverable.store(false, Ordering::Release); + tokio::spawn(async move { + cleanup_process_start(&client, &process_id, &state).await; + }); + } + } + Err(error) => { + if is_transport_closed_error(&error) { + tokio::spawn(async move { + cleanup_process_start(&client, &process_id, &state).await; + }); + } else { + client.inner.remove_session_if(&process_id, &state); + } + let _ = result_tx.send(Err(error)); + } + } + }); + return result_rx.await.map_err(|_| { + ExecServerError::Protocol("process start task stopped unexpectedly".to_string()) + })?; + } + } + + #[cfg(test)] pub(crate) async fn register_session( &self, process_id: &ProcessId, ) -> Result { - let state = Arc::new(SessionState::new()); - self.inner - .insert_session(process_id, Arc::clone(&state)) - .await?; + let state = Arc::new(SessionState::new(/*recoverable*/ true)); + self.inner.insert_session(process_id, Arc::clone(&state))?; Ok(Session { client: self.clone(), process_id: process_id.clone(), @@ -518,84 +606,52 @@ impl ExecServerClient { }) } - pub(crate) async fn unregister_session(&self, process_id: &ProcessId) { - self.inner.remove_session(process_id).await; - } - pub fn session_id(&self) -> Option { - self.inner - .session_id - .read() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .clone() + self.inner.session_id.get().cloned() } fn is_disconnected(&self) -> bool { - self.inner.disconnected.get().is_some() || self.inner.client.is_disconnected() + self.inner.is_failed() } pub(crate) async fn connect( connection: JsonRpcConnection, options: ExecServerClientConnectOptions, ) -> Result { - let (rpc_client, mut events_rx) = RpcClient::new(connection); - let inner = Arc::new_cyclic(|weak| { - let weak = weak.clone(); - let reader_task = tokio::spawn(async move { - while let Some(event) = events_rx.recv().await { - match event { - RpcClientEvent::Notification(notification) => { - if let Some(inner) = weak.upgrade() - && let Err(err) = - handle_server_notification(&inner, notification).await - { - let message = record_disconnected( - &inner, - format!("exec-server notification handling failed: {err}"), - ); - fail_all_in_flight_work(&inner, message).await; - return; - } - } - RpcClientEvent::Disconnected { reason } => { - if let Some(inner) = weak.upgrade() { - let message = record_disconnected( - &inner, - disconnected_message(reason.as_deref()), - ); - fail_all_in_flight_work(&inner, message).await; - } - return; - } - } - } - }); - - Inner { - client: rpc_client, - sessions: ArcSwap::from_pointee(HashMap::new()), - sessions_write_lock: Mutex::new(()), - disconnected: OnceLock::new(), - http_body_streams: ArcSwap::from_pointee(HashMap::new()), - http_body_stream_failures: ArcSwap::from_pointee(HashMap::new()), - http_body_streams_write_lock: Mutex::new(()), - http_body_stream_next_id: AtomicU64::new(1), - session_id: std::sync::RwLock::new(None), - reader_task, - } - }); - - let client = Self { inner }; - client.initialize(options).await?; - Ok(client) + Self::connect_with_recovery(connection, options, /*reconnect_strategy*/ None).await } - async fn notify_initialized(&self) -> Result<(), ExecServerError> { - self.inner - .client - .notify(INITIALIZED_METHOD, &serde_json::json!({})) - .await - .map_err(ExecServerError::Json) + pub(crate) async fn connect_with_recovery( + connection: JsonRpcConnection, + options: ExecServerClientConnectOptions, + reconnect_strategy: Option, + ) -> Result { + let (rpc_client, events_rx) = RpcClient::new(connection); + let rpc_client = Arc::new(rpc_client); + let session_id = OnceLock::new(); + let (connection_changed, _connection_changed_rx) = watch::channel(()); + let inner = Arc::new(Inner { + connection: StdMutex::new(ConnectionState { + status: ConnectionStatus::Connected(Arc::clone(&rpc_client)), + active_process_starts: 0, + }), + connection_changed, + sessions: ArcSwap::from_pointee(HashMap::new()), + sessions_write_lock: StdMutex::new(()), + http_body_streams: ArcSwap::from_pointee(HashMap::new()), + http_body_stream_failures: ArcSwap::from_pointee(HashMap::new()), + http_body_streams_write_lock: Mutex::new(()), + http_body_stream_next_id: AtomicU64::new(1), + session_id, + reconnect_strategy, + }); + let client = Self { inner }; + // An explicit resume can redirect notifications from running processes + // before initialize returns. Drain them immediately so a burst cannot + // fill the bounded event channel and block the initialize response. + client.spawn_rpc_reader(&rpc_client, events_rx); + client.initialize_rpc(&rpc_client, options).await?; + Ok(client) } async fn call(&self, method: &str, params: &P) -> Result @@ -603,24 +659,28 @@ impl ExecServerClient { P: serde::Serialize, T: serde::de::DeserializeOwned, { - // Reject new work before allocating a JSON-RPC request id. MCP tool - // calls, process writes, and fs operations all pass through here, so - // this is the shared low-level failure path after environment disconnect. - if let Some(error) = self.inner.disconnected_error() { - return Err(error); - } + let rpc_client = self.inner.rpc_client().await?; + self.call_rpc(&rpc_client, method, params).await + } - match self.inner.client.call(method, params).await { + async fn call_rpc( + &self, + rpc_client: &Arc, + method: &str, + params: &P, + ) -> Result + where + P: serde::Serialize, + T: serde::de::DeserializeOwned, + { + match rpc_client.call(method, params).await { Ok(response) => Ok(response), Err(error) => { let error = ExecServerError::from(error); if is_transport_closed_error(&error) { - // A call can race with disconnect after the preflight - // check. Only the reader task drains sessions so queued - // process notifications stay ordered before disconnect. - let message = disconnected_message(/*reason*/ None); - let message = record_disconnected(&self.inner, message); - Err(ExecServerError::Disconnected(message)) + Err(ExecServerError::Disconnected(disconnected_message( + /*reason*/ None, + ))) } else { Err(error) } @@ -629,6 +689,23 @@ impl ExecServerClient { } } +async fn cleanup_process_start( + client: &ExecServerClient, + process_id: &ProcessId, + state: &Arc, +) { + loop { + match client.terminate(process_id).await { + Ok(_) => break, + Err(error) if is_transport_closed_error(&error) && !client.inner.is_failed() => { + continue; + } + Err(_) => break, + } + } + client.inner.remove_session_if(process_id, state); +} + impl From for ExecServerError { fn from(value: RpcCallError) -> Self { match value { @@ -643,7 +720,7 @@ impl From for ExecServerError { } impl SessionState { - fn new() -> Self { + fn new(recoverable: bool) -> Self { let (wake_tx, _wake_rx) = watch::channel(0); Self { wake_tx, @@ -652,7 +729,7 @@ impl SessionState { PROCESS_EVENT_RETAINED_BYTES, ), ordered_events: StdMutex::new(OrderedSessionEvents::default()), - failure: Mutex::new(None), + recoverable: AtomicBool::new(recoverable), } } @@ -665,8 +742,8 @@ impl SessionState { } fn note_change(&self, seq: u64) { - let next = (*self.wake_tx.borrow()).max(seq); - let _ = self.wake_tx.send(next); + self.wake_tx + .send_modify(|current| *current = (*current).max(seq)); } /// Publishes a process event only when all earlier sequenced events have @@ -682,55 +759,61 @@ impl SessionState { return false; }; - let mut ready = Vec::new(); + let mut ordered_events = self + .ordered_events + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + // We have already delivered this sequence number or moved past it, + // so accepting it again would duplicate output or lifecycle events. + if ordered_events.failure.is_some() + || ordered_events.closed_published + || seq <= ordered_events.last_published_seq { - let mut ordered_events = self - .ordered_events - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - // We have already delivered this sequence number or moved past it, - // so accepting it again would duplicate output or lifecycle events. - if seq <= ordered_events.last_published_seq { - return false; - } - - ordered_events.pending.entry(seq).or_insert(event); - loop { - let next_seq = ordered_events.last_published_seq + 1; - let Some(event) = ordered_events.pending.remove(&next_seq) else { - break; - }; - ordered_events.last_published_seq += 1; - ready.push(event); - } + return false; } + ordered_events.pending.entry(seq).or_insert(event); + self.publish_ready(&mut ordered_events) + } + + fn publish_ready(&self, ordered_events: &mut OrderedSessionEvents) -> bool { let mut published_closed = false; - for event in ready { - published_closed |= matches!(&event, ExecProcessEvent::Closed { .. }); + loop { + let next_seq = ordered_events.last_published_seq.saturating_add(1); + let Some(event) = ordered_events.pending.remove(&next_seq) else { + break; + }; + ordered_events.last_published_seq = next_seq; + ordered_events.exit_published |= matches!(&event, ExecProcessEvent::Exited { .. }); + let is_closed = matches!(&event, ExecProcessEvent::Closed { .. }); + ordered_events.closed_published |= is_closed; + published_closed |= is_closed; self.events.publish(event); } published_closed } - async fn set_failure(&self, message: String) { - let mut failure = self.failure.lock().await; - let should_publish = failure.is_none(); - if should_publish { - *failure = Some(message.clone()); - } - drop(failure); - let next = (*self.wake_tx.borrow()).saturating_add(1); - let _ = self.wake_tx.send(next); - if should_publish { - let _ = self.publish_ordered_event(ExecProcessEvent::Failed(message)); + fn set_failure(&self, message: String) { + let mut ordered_events = self + .ordered_events + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if ordered_events.failure.is_some() || ordered_events.closed_published { + return; } + ordered_events.failure = Some(message.clone()); + ordered_events.pending.clear(); + self.events.publish(ExecProcessEvent::Failed(message)); + drop(ordered_events); + self.wake_tx + .send_modify(|current| *current = current.saturating_add(1)); } - async fn failed_response(&self) -> Option { - self.failure + fn failed_response(&self) -> Option { + self.ordered_events .lock() - .await + .unwrap_or_else(std::sync::PoisonError::into_inner) + .failure .clone() .map(|message| self.synthesized_failure(message)) } @@ -767,27 +850,37 @@ impl Session { max_bytes: Option, wait_ms: Option, ) -> Result { - if let Some(response) = self.state.failed_response().await { - return Ok(response); - } - - match self - .client - .read(ReadParams { - process_id: self.process_id.clone(), - after_seq, - max_bytes, - wait_ms, - }) - .await - { - Ok(response) => Ok(response), - Err(err) if is_transport_closed_error(&err) => { - let message = disconnected_message(/*reason*/ None); - self.state.set_failure(message.clone()).await; - Ok(self.state.synthesized_failure(message)) + loop { + if let Some(response) = self.state.failed_response() { + return Ok(response); + } + + match self + .client + .read(ReadParams { + process_id: self.process_id.clone(), + after_seq, + max_bytes, + wait_ms, + }) + .await + { + Ok(response) => return Ok(response), + Err(error) + if is_transport_closed_error(&error) && !self.client.inner.is_failed() => + { + continue; + } + Err(error) if is_transport_closed_error(&error) => { + if let Some(response) = self.state.failed_response() { + return Ok(response); + } + let message = error.to_string(); + self.state.set_failure(message.clone()); + return Ok(self.state.synthesized_failure(message)); + } + Err(error) => return Err(error), } - Err(err) => Err(err), } } @@ -805,40 +898,31 @@ impl Session { } pub(crate) async fn unregister(&self) { - self.client.unregister_session(&self.process_id).await; + self.client + .inner + .remove_session_if(&self.process_id, &self.state); } } impl Inner { - fn disconnected_error(&self) -> Option { - self.disconnected - .get() - .cloned() - .map(ExecServerError::Disconnected) - } - - fn set_disconnected(&self, message: String) -> Option { - match self.disconnected.set(message.clone()) { - Ok(()) => Some(message), - Err(_) => None, - } - } - fn get_session(&self, process_id: &ProcessId) -> Option> { self.sessions.load().get(process_id).cloned() } - async fn insert_session( + fn insert_session( &self, process_id: &ProcessId, session: Arc, ) -> Result<(), ExecServerError> { - let _sessions_write_guard = self.sessions_write_lock.lock().await; + let _sessions_write_guard = self + .sessions_write_lock + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); // Do not register a process session that can never receive environment // notifications. Without this check, remote MCP startup could create a // dead session and wait for process output that will never arrive. - if let Some(error) = self.disconnected_error() { - return Err(error); + if let Some(message) = self.failure_message() { + return Err(ExecServerError::Disconnected(message)); } let sessions = self.sessions.load(); if sessions.contains_key(process_id) { @@ -852,19 +936,28 @@ impl Inner { Ok(()) } - async fn remove_session(&self, process_id: &ProcessId) -> Option> { - let _sessions_write_guard = self.sessions_write_lock.lock().await; + fn remove_session_if(&self, process_id: &ProcessId, expected: &Arc) { + let _sessions_write_guard = self + .sessions_write_lock + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); let sessions = self.sessions.load(); - let session = sessions.get(process_id).cloned(); - session.as_ref()?; + if !sessions + .get(process_id) + .is_some_and(|session| Arc::ptr_eq(session, expected)) + { + return; + } let mut next_sessions = sessions.as_ref().clone(); next_sessions.remove(process_id); self.sessions.store(Arc::new(next_sessions)); - session } - async fn take_all_sessions(&self) -> HashMap> { - let _sessions_write_guard = self.sessions_write_lock.lock().await; + fn take_all_sessions(&self) -> HashMap> { + let _sessions_write_guard = self + .sessions_write_lock + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); let sessions = self.sessions.load(); let drained_sessions = sessions.as_ref().clone(); self.sessions.store(Arc::new(HashMap::new())); @@ -892,31 +985,20 @@ fn is_transport_closed_error(error: &ExecServerError) -> bool { ) } -fn record_disconnected(inner: &Arc, message: String) -> String { - // The first observer records the canonical disconnect reason. Session - // draining stays with the reader task so it can preserve notification - // ordering before publishing the terminal failure. - if let Some(message) = inner.set_disconnected(message.clone()) { - message - } else { - inner.disconnected.get().cloned().unwrap_or(message) - } -} - -async fn fail_all_sessions(inner: &Arc, message: String) { - let sessions = inner.take_all_sessions().await; +fn fail_all_sessions(inner: &Arc, message: String) { + let sessions = inner.take_all_sessions(); for (_, session) in sessions { // Sessions synthesize a closed read response and emit a pushed Failed // event. That covers both polling consumers and streaming consumers // such as environment-backed MCP stdio. - session.set_failure(message.clone()).await; + session.set_failure(message.clone()); } } /// Fails all in-flight work that depends on the shared JSON-RPC transport. async fn fail_all_in_flight_work(inner: &Arc, message: String) { - fail_all_sessions(inner, message.clone()).await; + fail_all_sessions(inner, message.clone()); inner.fail_all_http_body_streams(message).await; } @@ -937,7 +1019,7 @@ async fn handle_server_notification( chunk: params.chunk, })); if published_closed { - inner.remove_session(¶ms.process_id).await; + inner.remove_session_if(¶ms.process_id, &session); } } } @@ -951,7 +1033,7 @@ async fn handle_server_notification( exit_code: params.exit_code, }); if published_closed { - inner.remove_session(¶ms.process_id).await; + inner.remove_session_if(¶ms.process_id, &session); } } } @@ -966,7 +1048,7 @@ async fn handle_server_notification( let published_closed = session.publish_ordered_event(ExecProcessEvent::Closed { seq: params.seq }); if published_closed { - inner.remove_session(¶ms.process_id).await; + inner.remove_session_if(¶ms.process_id, &session); } } } @@ -1020,6 +1102,7 @@ mod tests { #[cfg(not(windows))] use crate::client_api::DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT; use crate::client_api::ExecServerTransportParams; + use crate::client_api::RemoteExecServerConnectArgs; use crate::client_api::StdioExecServerCommand; use crate::client_api::StdioExecServerConnectArgs; use crate::connection::JsonRpcConnection; @@ -1136,19 +1219,6 @@ mod tests { } } - async fn wait_for_disconnect(client: &ExecServerClient) { - timeout(Duration::from_secs(1), async { - loop { - if client.is_disconnected() { - return; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("client should observe disconnect"); - } - #[cfg(not(windows))] #[tokio::test] async fn connect_stdio_command_initializes_json_rpc_client() { @@ -1567,7 +1637,7 @@ mod tests { } #[tokio::test] - async fn remote_websocket_client_replaces_disconnected_client_with_fresh_session() { + async fn remote_websocket_client_resumes_session() { let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); @@ -1575,28 +1645,27 @@ mod tests { "ws://{}", listener.local_addr().expect("listener should have address") ); - let server = tokio::spawn({ - async move { - let mut first = accept_websocket(&listener).await; - complete_websocket_initialize( - &mut first, - "session-1", - /*expected_resume_session_id*/ None, - ) - .await; - first - .close(None) - .await - .expect("first websocket should close"); + let (resumed_tx, resumed_rx) = oneshot::channel(); + let (finish_tx, finish_rx) = oneshot::channel(); + let server = tokio::spawn(async move { + let mut first = accept_websocket(&listener).await; + complete_websocket_initialize( + &mut first, + "session-1", + /*expected_resume_session_id*/ None, + ) + .await; + first.close(None).await.expect("websocket should close"); - let mut second = accept_websocket(&listener).await; - complete_websocket_initialize( - &mut second, - "session-2", - /*expected_resume_session_id*/ None, - ) - .await; - } + let mut resumed = accept_websocket(&listener).await; + complete_websocket_initialize( + &mut resumed, + "session-1", + /*expected_resume_session_id*/ Some("session-1"), + ) + .await; + resumed_tx.send(()).expect("resume should signal"); + finish_rx.await.expect("test should finish"); }); let client = LazyRemoteExecServerClient::new(ExecServerTransportParams::WebSocketUrl { @@ -1604,16 +1673,103 @@ mod tests { connect_timeout: Duration::from_secs(1), initialize_timeout: Duration::from_secs(1), }); - let first = client.get().await.expect("first client should connect"); - wait_for_disconnect(&first).await; + let stable_client = client.get().await.expect("client should connect"); + timeout(Duration::from_secs(1), resumed_rx) + .await + .expect("session resume should not time out") + .expect("session resume should signal"); + let reused_client = client.get().await.expect("client should stay connected"); + assert_eq!(stable_client.session_id().as_deref(), Some("session-1")); + assert!(Arc::ptr_eq(&stable_client.inner, &reused_client.inner)); + finish_tx.send(()).expect("test should finish"); + server.await.expect("server task should finish"); + } - let (replacement_a, replacement_b) = tokio::join!(client.get(), client.get()); - let replacement_a = replacement_a.expect("first replacement should connect"); - let replacement_b = replacement_b.expect("second replacement should reuse client"); - assert_eq!(replacement_a.session_id().as_deref(), Some("session-2")); - assert_eq!(replacement_b.session_id().as_deref(), Some("session-2")); - assert!(Arc::ptr_eq(&replacement_a.inner, &replacement_b.inner)); + #[tokio::test] + async fn explicit_resume_drains_notifications_before_initialize_response() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let websocket_url = format!( + "ws://{}", + listener.local_addr().expect("listener should have address") + ); + let (initialized_tx, initialized_rx) = oneshot::channel(); + let (finish_tx, finish_rx) = oneshot::channel(); + let server = tokio::spawn(async move { + let mut websocket = accept_websocket(&listener).await; + let initialize = read_jsonrpc_websocket(&mut websocket).await; + let request = match initialize { + JSONRPCMessage::Request(request) if request.method == INITIALIZE_METHOD => request, + other => panic!("expected initialize request, got {other:?}"), + }; + let params: crate::protocol::InitializeParams = + serde_json::from_value(request.params.expect("initialize params should exist")) + .expect("initialize params should deserialize"); + assert_eq!(params.resume_session_id.as_deref(), Some("session-1")); + for seq in 1..=256 { + write_jsonrpc_websocket( + &mut websocket, + JSONRPCMessage::Notification(JSONRPCNotification { + method: EXEC_OUTPUT_DELTA_METHOD.to_string(), + params: Some( + serde_json::to_value(ExecOutputDeltaNotification { + process_id: ProcessId::from("busy-process"), + seq, + stream: ExecOutputStream::Stdout, + chunk: b"output".to_vec().into(), + }) + .expect("output notification should serialize"), + ), + }), + ) + .await; + } + write_jsonrpc_websocket( + &mut websocket, + JSONRPCMessage::Response(JSONRPCResponse { + id: request.id, + result: serde_json::to_value(InitializeResponse { + session_id: "session-1".to_string(), + }) + .expect("initialize response should serialize"), + }), + ) + .await; + + let initialized = read_jsonrpc_websocket(&mut websocket).await; + match initialized { + JSONRPCMessage::Notification(notification) + if notification.method == INITIALIZED_METHOD => {} + other => panic!("expected initialized notification, got {other:?}"), + } + initialized_tx + .send(()) + .expect("initialized notification should signal"); + finish_rx.await.expect("test should finish"); + }); + + let client = timeout( + Duration::from_secs(1), + ExecServerClient::connect_websocket(RemoteExecServerConnectArgs { + websocket_url, + client_name: "test-client".to_string(), + connect_timeout: Duration::from_secs(1), + initialize_timeout: Duration::from_secs(1), + resume_session_id: Some("session-1".to_string()), + }), + ) + .await + .expect("explicit resume should not time out") + .expect("explicit resume should connect"); + assert_eq!(client.session_id().as_deref(), Some("session-1")); + + timeout(Duration::from_secs(1), initialized_rx) + .await + .expect("initialized notification should not time out") + .expect("initialized notification should signal"); + finish_tx.send(()).expect("test should finish"); server.await.expect("server task should finish"); } diff --git a/codex-rs/exec-server/src/client/rpc_http_client.rs b/codex-rs/exec-server/src/client/rpc_http_client.rs index d2ce842ca..d243a8d42 100644 --- a/codex-rs/exec-server/src/client/rpc_http_client.rs +++ b/codex-rs/exec-server/src/client/rpc_http_client.rs @@ -42,6 +42,7 @@ impl ExecServerClient { &self, mut params: HttpRequestParams, ) -> Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError> { + let rpc_client = self.inner.rpc_client().await?; params.stream_response = true; let request_id = self.inner.next_http_body_stream_request_id(); params.request_id = request_id.clone(); @@ -51,7 +52,10 @@ impl ExecServerClient { .await?; let mut registration = HttpBodyStreamRegistration::new(Arc::clone(&self.inner), request_id.clone()); - let response = match self.call(HTTP_REQUEST_METHOD, ¶ms).await { + let response = match self + .call_rpc(&rpc_client, HTTP_REQUEST_METHOD, ¶ms) + .await + { Ok(response) => response, Err(error) => { self.inner.remove_http_body_stream(&request_id).await; diff --git a/codex-rs/exec-server/src/client_recovery.rs b/codex-rs/exec-server/src/client_recovery.rs new file mode 100644 index 000000000..9a2f3c29a --- /dev/null +++ b/codex-rs/exec-server/src/client_recovery.rs @@ -0,0 +1,516 @@ +use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use tokio::sync::mpsc; +use tokio::time::Instant; +use tokio::time::sleep; +use tokio::time::timeout_at; + +use super::ConnectionStatus; +use super::ExecServerClient; +use super::ExecServerError; +use super::Inner; +use super::OrderedSessionEvents; +use super::SessionState; +use super::disconnected_message; +use super::fail_all_in_flight_work; +use super::handle_server_notification; +use super::is_transport_closed_error; +use crate::process::ExecProcessEvent; +use crate::protocol::EXEC_READ_METHOD; +use crate::protocol::EXEC_TERMINATE_METHOD; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::rpc::RpcClient; +use crate::rpc::RpcClientEvent; +use crate::rpc::SESSION_ALREADY_ATTACHED_ERROR_CODE; + +#[cfg(test)] +const SESSION_RECOVERY_TIMEOUT: Duration = Duration::from_millis(500); +#[cfg(not(test))] +// Leave margin inside the server's 30-second retention windows because the +// client and server start their disconnect clocks independently. +const SESSION_RECOVERY_TIMEOUT: Duration = Duration::from_secs(25); +const SESSION_RECOVERY_RETRY_INTERVAL: Duration = Duration::from_millis(100); + +impl SessionState { + fn last_published_seq(&self) -> u64 { + self.ordered_events + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .last_published_seq + } + + fn recover_events(&self, response: ReadResponse) -> Result { + let ReadResponse { + chunks, + next_seq, + exited, + exit_code, + closed, + failure, + } = response; + if let Some(message) = failure { + return Err(ExecServerError::Protocol(format!( + "process failed while recovering: {message}" + ))); + } + + let target_seq = next_seq.saturating_sub(1); + let published_closed = { + let mut ordered_events = self + .ordered_events + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if ordered_events.failure.is_some() + || ordered_events.closed_published + || target_seq <= ordered_events.last_published_seq + { + return Ok(false); + } + for chunk in chunks { + if chunk.seq > ordered_events.last_published_seq { + ordered_events + .pending + .entry(chunk.seq) + .or_insert(ExecProcessEvent::Output(chunk)); + } + } + if closed { + match ordered_events.pending.get(&target_seq) { + Some(ExecProcessEvent::Closed { .. }) => {} + Some(_) => { + return Err(ExecServerError::Protocol(format!( + "process close sequence {target_seq} conflicts with recovered output" + ))); + } + None => { + ordered_events + .pending + .insert(target_seq, ExecProcessEvent::Closed { seq: target_seq }); + } + } + } + + let exit_known = ordered_events.exit_published + || ordered_events + .pending + .range(..=target_seq) + .any(|(_, event)| matches!(event, ExecProcessEvent::Exited { .. })); + let event_count = target_seq - ordered_events.last_published_seq; + let retained_count = ordered_events + .pending + .range(ordered_events.last_published_seq.saturating_add(1)..=target_seq) + .count() as u64; + let missing_count = event_count.saturating_sub(retained_count); + if exited && !exit_known { + if missing_count != 1 { + return Err(recovery_gap_error(target_seq)); + } + let seq = first_missing_seq(&ordered_events, target_seq); + let exit_code = exit_code.ok_or_else(|| { + ExecServerError::Protocol( + "recovering exited process did not include its exit code".to_string(), + ) + })?; + ordered_events + .pending + .insert(seq, ExecProcessEvent::Exited { seq, exit_code }); + } else if missing_count != 0 { + return Err(recovery_gap_error(target_seq)); + } + self.publish_ready(&mut ordered_events) + }; + + self.note_change(target_seq); + Ok(published_closed) + } +} + +fn first_missing_seq(events: &OrderedSessionEvents, target_seq: u64) -> u64 { + let mut expected = events.last_published_seq.saturating_add(1); + for seq in events + .pending + .range(expected..=target_seq) + .map(|(seq, _)| *seq) + { + if seq != expected { + break; + } + expected = expected.saturating_add(1); + } + expected +} + +fn recovery_gap_error(target_seq: u64) -> ExecServerError { + ExecServerError::Protocol(format!( + "process events are no longer retained while recovering through sequence {target_seq}" + )) +} + +impl Inner { + pub(super) async fn rpc_client(self: &Arc) -> Result, ExecServerError> { + let mut connection_changed = self.connection_changed.subscribe(); + loop { + if let Some(message) = self.failure_message() { + return Err(ExecServerError::Disconnected(message)); + } + + let rpc_client = { + let connection = self + .connection + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + match &connection.status { + ConnectionStatus::Connected(rpc_client) => Some(Arc::clone(rpc_client)), + ConnectionStatus::Recovering | ConnectionStatus::Failed(_) => None, + } + }; + let Some(rpc_client) = rpc_client else { + let _ = connection_changed.changed().await; + continue; + }; + if !rpc_client.is_disconnected() { + return Ok(rpc_client); + } + + let _ = connection_changed.changed().await; + } + } + + pub(super) fn begin_process_start(&self, expected: &Arc) -> bool { + let mut connection = self + .connection + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let ConnectionStatus::Connected(current) = &connection.status else { + return false; + }; + if !Arc::ptr_eq(current, expected) || expected.is_disconnected() { + return false; + } + connection.active_process_starts += 1; + true + } + + pub(super) fn finish_process_start(&self) { + { + let mut connection = self + .connection + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if connection.active_process_starts == 0 { + tracing::error!("finished an exec-server process start that was not active"); + return; + } + connection.active_process_starts -= 1; + } + self.notify_connection_changed(); + } + + pub(super) fn is_failed(&self) -> bool { + self.failure_message().is_some() + } + + pub(super) fn failure_message(&self) -> Option { + let connection = self + .connection + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + match &connection.status { + ConnectionStatus::Failed(message) => Some(message.clone()), + ConnectionStatus::Connected(_) | ConnectionStatus::Recovering => None, + } + } + + fn request_recovery( + self: &Arc, + failed_rpc_client: Arc, + disconnect_message: String, + ) { + let should_recover = { + let mut connection = self + .connection + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + match &connection.status { + ConnectionStatus::Connected(current) + if Arc::ptr_eq(current, &failed_rpc_client) => + { + connection.status = ConnectionStatus::Recovering; + true + } + ConnectionStatus::Connected(_) + | ConnectionStatus::Recovering + | ConnectionStatus::Failed(_) => false, + } + }; + if !should_recover { + return; + } + + self.notify_connection_changed(); + let inner = Arc::clone(self); + tokio::spawn(async move { + inner.recover(disconnect_message).await; + }); + } + + async fn recover(self: &Arc, disconnect_message: String) { + let deadline = Instant::now() + SESSION_RECOVERY_TIMEOUT; + self.fail_all_http_body_streams(disconnect_message.clone()) + .await; + if timeout_at(deadline, self.wait_for_process_starts()) + .await + .is_err() + { + let message = format!( + "{disconnect_message}; failed to resume exec-server session: recovery timed out after {SESSION_RECOVERY_TIMEOUT:?}" + ); + self.fail(message).await; + return; + } + if self.reconnect_strategy.is_none() { + self.fail(disconnect_message).await; + return; + } + + let Some(session_id) = self.session_id.get().cloned() else { + let message = format!( + "{disconnect_message}; failed to resume exec-server session: missing session id" + ); + self.fail(message).await; + return; + }; + let last_error = loop { + match timeout_at(deadline, self.resume_once(&session_id)).await { + Ok(Ok(candidate)) if !candidate.is_disconnected() => { + if self.install_recovered_client(candidate) { + return; + } + } + Ok(Ok(_)) => {} + Ok(Err(error)) if !is_retryable_recovery_error(&error) => { + break error.to_string(); + } + Ok(Err(_)) => {} + Err(_) => { + break format!("recovery timed out after {SESSION_RECOVERY_TIMEOUT:?}"); + } + } + + let now = Instant::now(); + if now >= deadline { + break format!("recovery timed out after {SESSION_RECOVERY_TIMEOUT:?}"); + } + sleep(SESSION_RECOVERY_RETRY_INTERVAL.min(deadline - now)).await; + }; + + let message = + format!("{disconnect_message}; failed to resume exec-server session: {last_error}"); + self.fail(message).await; + } + + async fn wait_for_process_starts(&self) { + let mut connection_changed = self.connection_changed.subscribe(); + loop { + let starts_are_done = self + .connection + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .active_process_starts + == 0; + if starts_are_done { + return; + } + let _ = connection_changed.changed().await; + } + } + + fn install_recovered_client(&self, rpc_client: Arc) -> bool { + let installed = { + let mut connection = self + .connection + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if !matches!(connection.status, ConnectionStatus::Recovering) + || rpc_client.is_disconnected() + { + false + } else { + connection.status = ConnectionStatus::Connected(rpc_client); + true + } + }; + if installed { + self.notify_connection_changed(); + } + installed + } + + fn notify_connection_changed(&self) { + self.connection_changed.send_replace(()); + } + + async fn resume_once( + self: &Arc, + session_id: &str, + ) -> Result, ExecServerError> { + let reconnect_strategy = self + .reconnect_strategy + .as_ref() + .ok_or_else(|| ExecServerError::Protocol("missing reconnect strategy".to_string()))?; + let (connection, options) = reconnect_strategy.resume(session_id).await?; + let (rpc_client, events_rx) = RpcClient::new(connection); + let rpc_client = Arc::new(rpc_client); + let client = ExecServerClient { + inner: Arc::clone(self), + }; + // Resuming a session redirects notifications from its running processes + // to this connection during initialize. Drain them immediately so a + // burst cannot fill the bounded event channel and block the initialize + // response behind it. + client.spawn_rpc_reader(&rpc_client, events_rx); + client.initialize_rpc(&rpc_client, options).await?; + + self.recover_processes(&rpc_client).await?; + Ok(rpc_client) + } + + async fn recover_processes( + self: &Arc, + rpc_client: &RpcClient, + ) -> Result<(), ExecServerError> { + let sessions = self.sessions.load_full(); + for (process_id, session) in sessions.iter() { + if !session.recoverable.load(Ordering::Acquire) { + continue; + } + let response = rpc_client + .call::<_, ReadResponse>( + EXEC_READ_METHOD, + &ReadParams { + process_id: process_id.clone(), + after_seq: Some(session.last_published_seq()), + max_bytes: None, + wait_ms: Some(0), + }, + ) + .await + .map_err(ExecServerError::from); + let recovered = match response { + Ok(response) => session.recover_events(response), + Err(error) if is_transport_closed_error(&error) => return Err(error), + Err(error) => Err(error), + }; + match recovered { + Ok(true) => self.remove_session_if(process_id, session), + Ok(false) => {} + Err(error) => { + let terminated: Result = rpc_client + .call( + EXEC_TERMINATE_METHOD, + &TerminateParams { + process_id: process_id.clone(), + }, + ) + .await + .map_err(ExecServerError::from); + if let Err(terminate_error) = terminated + && is_transport_closed_error(&terminate_error) + { + return Err(terminate_error); + } + self.remove_session_if(process_id, session); + session.set_failure(format!("failed to recover process {process_id}: {error}")); + } + } + } + Ok(()) + } + + async fn fail(self: &Arc, message: String) { + let (message, newly_failed) = { + let mut connection = self + .connection + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + match &connection.status { + ConnectionStatus::Failed(existing) => (existing.clone(), false), + ConnectionStatus::Connected(_) | ConnectionStatus::Recovering => { + connection.status = ConnectionStatus::Failed(message.clone()); + (message, true) + } + } + }; + if newly_failed { + self.notify_connection_changed(); + fail_all_in_flight_work(self, message.clone()).await; + } + } +} + +impl ExecServerClient { + pub(super) fn spawn_rpc_reader( + &self, + rpc_client: &Arc, + mut events_rx: mpsc::Receiver, + ) { + let inner = Arc::downgrade(&self.inner); + let rpc_client = Arc::downgrade(rpc_client); + tokio::spawn(async move { + while let Some(event) = events_rx.recv().await { + let (Some(inner), Some(rpc_client)) = (inner.upgrade(), rpc_client.upgrade()) + else { + return; + }; + match event { + RpcClientEvent::Notification(notification) => { + if let Err(error) = handle_server_notification(&inner, notification).await { + rpc_client.close_transport().await; + inner.request_recovery( + rpc_client, + format!("exec-server notification handling failed: {error}"), + ); + return; + } + } + RpcClientEvent::Disconnected { reason } => { + inner.request_recovery(rpc_client, disconnected_message(reason.as_deref())); + return; + } + } + } + }); + } +} + +fn is_retryable_recovery_error(error: &ExecServerError) -> bool { + is_transport_closed_error(error) + || matches!( + error, + ExecServerError::WebSocketConnectTimeout { .. } + | ExecServerError::WebSocketConnect { .. } + | ExecServerError::InitializeTimedOut { .. } + ) + || matches!( + error, + ExecServerError::EnvironmentRegistryRequest(error) + if error.is_connect() || error.is_timeout() + ) + || matches!( + error, + ExecServerError::EnvironmentRegistryHttp { status, .. } + if status.is_server_error() + || *status == reqwest::StatusCode::REQUEST_TIMEOUT + || *status == reqwest::StatusCode::TOO_MANY_REQUESTS + ) + || matches!( + error, + ExecServerError::Server { code, .. } + if *code == SESSION_ALREADY_ATTACHED_ERROR_CODE + ) +} diff --git a/codex-rs/exec-server/src/client_transport.rs b/codex-rs/exec-server/src/client_transport.rs index c31a8fcdf..b079d3272 100644 --- a/codex-rs/exec-server/src/client_transport.rs +++ b/codex-rs/exec-server/src/client_transport.rs @@ -1,4 +1,7 @@ use std::process::Stdio; +use std::sync::Arc; +use std::time::Duration; + use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tokio::process::Command; @@ -14,12 +17,15 @@ use crate::ExecServerClient; use crate::ExecServerError; use crate::client_api::DEFAULT_REMOTE_EXEC_SERVER_CONNECT_TIMEOUT; use crate::client_api::DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT; +use crate::client_api::ExecServerClientConnectOptions; use crate::client_api::NoiseRendezvousConnectArgs; use crate::client_api::NoiseRendezvousConnectBundle; +use crate::client_api::NoiseRendezvousConnectProvider; use crate::client_api::RemoteExecServerConnectArgs; use crate::client_api::StdioExecServerCommand; use crate::client_api::StdioExecServerConnectArgs; use crate::connection::JsonRpcConnection; +use crate::noise_channel::NoiseChannelIdentity; use crate::noise_relay::NoiseHarnessConnectionArgs; use crate::noise_relay::noise_harness_connection_from_websocket; use crate::noise_relay::noise_relay_websocket_config; @@ -27,6 +33,57 @@ use crate::relay::harness_connection_from_websocket; const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment"; +/// Reopens the transport for one logical exec-server client session. +/// +/// URL connections reuse their configured endpoint. Noise connections retain +/// the harness identity but fetch a fresh single-use authorization bundle for +/// every physical connection attempt. +#[derive(Clone)] +pub(crate) enum ExecServerReconnectStrategy { + WebSocket(RemoteExecServerConnectArgs), + NoiseRendezvous { + provider: Arc, + identity: NoiseChannelIdentity, + client_name: String, + connect_timeout: Duration, + initialize_timeout: Duration, + }, +} + +impl ExecServerReconnectStrategy { + pub(crate) async fn resume( + &self, + session_id: &str, + ) -> Result<(JsonRpcConnection, ExecServerClientConnectOptions), ExecServerError> { + match self { + Self::WebSocket(args) => { + let mut args = args.clone(); + args.resume_session_id = Some(session_id.to_string()); + let connection = ExecServerClient::open_websocket_connection(&args).await?; + Ok((connection, args.into())) + } + Self::NoiseRendezvous { + provider, + identity, + client_name, + connect_timeout, + initialize_timeout, + } => { + let bundle = provider.connect_bundle(identity.public_key()).await?; + ExecServerClient::open_noise_rendezvous_connection(NoiseRendezvousConnectArgs { + bundle, + harness_identity: identity.clone(), + client_name: client_name.clone(), + connect_timeout: *connect_timeout, + initialize_timeout: *initialize_timeout, + resume_session_id: Some(session_id.to_string()), + }) + .await + } + } + } +} + impl ExecServerClient { /// Open the selected transport and run the common JSON-RPC initialization. /// Noise connection details are fetched here so reconnects get a fresh URL @@ -53,16 +110,25 @@ impl ExecServerClient { provider, identity, } => { - let bundle = provider.connect_bundle(identity.public_key()).await?; - Self::connect_noise_rendezvous(NoiseRendezvousConnectArgs { - bundle, - harness_identity: identity, + let reconnect_strategy = ExecServerReconnectStrategy::NoiseRendezvous { + provider: Arc::clone(&provider), + identity: identity.clone(), client_name: ENVIRONMENT_CLIENT_NAME.to_string(), connect_timeout: DEFAULT_REMOTE_EXEC_SERVER_CONNECT_TIMEOUT, initialize_timeout: DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT, - resume_session_id: None, - }) - .await + }; + let bundle = provider.connect_bundle(identity.public_key()).await?; + let (connection, options) = + Self::open_noise_rendezvous_connection(NoiseRendezvousConnectArgs { + bundle, + harness_identity: identity, + client_name: ENVIRONMENT_CLIENT_NAME.to_string(), + connect_timeout: DEFAULT_REMOTE_EXEC_SERVER_CONNECT_TIMEOUT, + initialize_timeout: DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT, + resume_session_id: None, + }) + .await?; + Self::connect_with_recovery(connection, options, Some(reconnect_strategy)).await } crate::client_api::ExecServerTransportParams::StdioCommand { command, @@ -82,6 +148,19 @@ impl ExecServerClient { pub async fn connect_websocket( args: RemoteExecServerConnectArgs, ) -> Result { + let connection = Self::open_websocket_connection(&args).await?; + let options = args.clone().into(); + Self::connect_with_recovery( + connection, + options, + Some(ExecServerReconnectStrategy::WebSocket(args)), + ) + .await + } + + pub(crate) async fn open_websocket_connection( + args: &RemoteExecServerConnectArgs, + ) -> Result { ensure_rustls_crypto_provider(); let websocket_url = args.websocket_url.clone(); let connect_timeout = args.connect_timeout; @@ -102,15 +181,26 @@ impl ExecServerClient { } else { JsonRpcConnection::from_websocket(stream, connection_label) }; - Self::connect(connection, args.into()).await + Ok(connection) } - /// Connect to one exec-server through an authenticated rendezvous stream. + /// Connect to one exec-server through an authenticated rendezvous stream + /// using a caller-supplied single-use authorization bundle. + /// /// The executor key is pinned before JSON-RPC starts; the websocket carries - /// only ciphertext after that. + /// only ciphertext after that. Environment-managed connections use a + /// retained [`NoiseRendezvousConnectProvider`] so recovery can fetch a fresh + /// bundle for each reconnect. pub async fn connect_noise_rendezvous( args: NoiseRendezvousConnectArgs, ) -> Result { + let (connection, options) = Self::open_noise_rendezvous_connection(args).await?; + Self::connect(connection, options).await + } + + pub(crate) async fn open_noise_rendezvous_connection( + args: NoiseRendezvousConnectArgs, + ) -> Result<(JsonRpcConnection, ExecServerClientConnectOptions), ExecServerError> { ensure_rustls_crypto_provider(); // Keep the registry-issued URL, key, and authorization together for this // connection attempt. @@ -164,15 +254,14 @@ impl ExecServerClient { harness_key_authorization, }, ); - Self::connect( + Ok(( connection, - crate::client_api::ExecServerClientConnectOptions { + ExecServerClientConnectOptions { client_name, initialize_timeout, resume_session_id, }, - ) - .await + )) } pub(crate) async fn connect_stdio_command( diff --git a/codex-rs/exec-server/src/local_process.rs b/codex-rs/exec-server/src/local_process.rs index 78a53897e..d8ee65160 100644 --- a/codex-rs/exec-server/src/local_process.rs +++ b/codex-rs/exec-server/src/local_process.rs @@ -81,8 +81,10 @@ struct RunningProcess { closed: bool, } +struct ProcessStart; + enum ProcessEntry { - Starting, + Starting(Arc), Running(Box), } @@ -128,7 +130,7 @@ impl LocalProcess { processes .drain() .filter_map(|(_, process)| match process { - ProcessEntry::Starting => None, + ProcessEntry::Starting(_) => None, ProcessEntry::Running(process) => Some(process), }) .collect::>() @@ -163,6 +165,7 @@ impl LocalProcess { )) })?; + let start = Arc::new(ProcessStart); { let mut process_map = self.inner.processes.lock().await; if process_map.contains_key(&process_id) { @@ -170,7 +173,10 @@ impl LocalProcess { "process {process_id} already exists" ))); } - process_map.insert(process_id.clone(), ProcessEntry::Starting); + process_map.insert( + process_id.clone(), + ProcessEntry::Starting(Arc::clone(&start)), + ); } let env = child_env(¶ms); @@ -207,7 +213,10 @@ impl LocalProcess { Ok(spawned) => spawned, Err(err) => { let mut process_map = self.inner.processes.lock().await; - if matches!(process_map.get(&process_id), Some(ProcessEntry::Starting)) { + if matches!( + process_map.get(&process_id), + Some(ProcessEntry::Starting(current)) if Arc::ptr_eq(current, &start) + ) { process_map.remove(&process_id); } return Err(internal_error(err.to_string())); @@ -222,6 +231,16 @@ impl LocalProcess { ); { let mut process_map = self.inner.processes.lock().await; + if !matches!( + process_map.get(&process_id), + Some(ProcessEntry::Starting(current)) if Arc::ptr_eq(current, &start) + ) { + drop(process_map); + spawned.session.terminate(); + return Err(invalid_request(format!( + "process {process_id} start was cancelled" + ))); + } process_map.insert( process_id.clone(), ProcessEntry::Running(Box::new(RunningProcess { @@ -320,7 +339,9 @@ impl LocalProcess { break; } } - + if params.max_bytes.is_none() { + next_seq = process.next_seq; + } ( ReadResponse { chunks, @@ -408,7 +429,7 @@ impl LocalProcess { .signal(pty_process_signal(params.signal)) .map_err(|err| internal_error(format!("failed to signal process: {err}")))? } - Some(ProcessEntry::Starting) | None => {} + Some(ProcessEntry::Starting(_)) | None => {} } } @@ -420,7 +441,7 @@ impl LocalProcess { params: TerminateParams, ) -> Result { let running = { - let process_map = self.inner.processes.lock().await; + let mut process_map = self.inner.processes.lock().await; match process_map.get(¶ms.process_id) { Some(ProcessEntry::Running(process)) => { if process.exit_code.is_some() { @@ -429,7 +450,11 @@ impl LocalProcess { process.session.terminate(); true } - Some(ProcessEntry::Starting) | None => false, + Some(ProcessEntry::Starting(_)) => { + process_map.remove(¶ms.process_id); + true + } + None => false, } }; @@ -915,6 +940,16 @@ mod tests { ) .await .expect("process should close"); + let replay_after_exit = backend + .exec_read(ReadParams { + process_id: process.process_id.clone(), + after_seq: Some(1), + max_bytes: None, + wait_ms: Some(0), + }) + .await + .expect("closed process should remain readable"); + assert_eq!(replay_after_exit.next_seq, 4); backend.shutdown().await; } diff --git a/codex-rs/exec-server/src/remote_process.rs b/codex-rs/exec-server/src/remote_process.rs index e130114cc..72f41ae70 100644 --- a/codex-rs/exec-server/src/remote_process.rs +++ b/codex-rs/exec-server/src/remote_process.rs @@ -35,13 +35,8 @@ impl RemoteProcess { &self, params: ExecParams, ) -> Result { - let process_id = params.process_id.clone(); let client = self.client.get().await?; - let session = client.register_session(&process_id).await?; - if let Err(err) = client.exec(params).await { - session.unregister().await; - return Err(err); - } + let session = client.start_process(params).await?; Ok(StartedExecProcess { process: Arc::new(RemoteExecProcess { session }), diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index 981a1c1a8..82cf2b200 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -25,6 +26,8 @@ use crate::connection::JsonRpcConnection; use crate::connection::JsonRpcConnectionEvent; use crate::connection::JsonRpcTransport; +pub(crate) const SESSION_ALREADY_ATTACHED_ERROR_CODE: i64 = -32010; + #[derive(Debug)] pub(crate) enum RpcCallError { /// The underlying JSON-RPC transport closed before this call completed. @@ -225,6 +228,7 @@ pub(crate) struct RpcClient { // immediately when the socket closes, even if no JSON-RPC error response // can be delivered for their request id. disconnected_rx: watch::Receiver, + closed: Arc, next_request_id: AtomicI64, transport_tasks: Vec>, transport: JsonRpcTransport, @@ -241,9 +245,11 @@ impl RpcClient { transport, } = connection; let pending = Arc::new(Mutex::new(HashMap::::new())); + let closed = Arc::new(AtomicBool::new(false)); let (event_tx, event_rx) = mpsc::channel(128); let pending_for_reader = Arc::clone(&pending); + let closed_for_reader = Arc::clone(&closed); let transport_for_reader = transport.clone(); let reader_task = tokio::spawn(async move { let disconnect_reason = loop { @@ -269,12 +275,13 @@ impl RpcClient { } }; + closed_for_reader.store(true, Ordering::Release); + drain_pending(&pending_for_reader).await; let _ = event_tx .send(RpcClientEvent::Disconnected { reason: disconnect_reason, }) .await; - drain_pending(&pending_for_reader).await; transport_for_reader.terminate(); }); @@ -283,6 +290,7 @@ impl RpcClient { write_tx, pending, disconnected_rx, + closed, next_request_id: AtomicI64::new(1), transport_tasks, transport, @@ -296,24 +304,31 @@ impl RpcClient { &self, method: &str, params: &P, - ) -> Result<(), serde_json::Error> { - let params = serde_json::to_value(params)?; + ) -> Result<(), RpcCallError> { + let params = serde_json::to_value(params).map_err(RpcCallError::Json)?; + if self.closed.load(Ordering::Acquire) || *self.disconnected_rx.borrow() { + return Err(RpcCallError::Closed); + } self.write_tx .send(JSONRPCMessage::Notification(JSONRPCNotification { method: method.to_string(), params: Some(params), })) .await - .map_err(|_| { - serde_json::Error::io(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "JSON-RPC transport closed", - )) - }) + .map_err(|_| RpcCallError::Closed) } pub(crate) fn is_disconnected(&self) -> bool { - *self.disconnected_rx.borrow() + self.closed.load(Ordering::Acquire) || *self.disconnected_rx.borrow() + } + + pub(crate) async fn close_transport(&self) { + self.closed.store(true, Ordering::Release); + self.transport.terminate(); + for task in &self.transport_tasks { + task.abort(); + } + drain_pending(&self.pending).await; } pub(crate) async fn call(&self, method: &str, params: &P) -> Result @@ -328,7 +343,7 @@ impl RpcClient { // Registering the pending request and checking disconnect must be // atomic with the reader's drain_pending path. Otherwise a call // can sneak in after the drain and wait forever. - if *self.disconnected_rx.borrow() { + if self.closed.load(Ordering::Acquire) || *self.disconnected_rx.borrow() { return Err(RpcCallError::Closed); } pending.insert(request_id.clone(), response_tx); @@ -417,6 +432,14 @@ pub(crate) fn invalid_request(message: String) -> JSONRPCErrorError { } } +pub(crate) fn session_already_attached(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: SESSION_ALREADY_ATTACHED_ERROR_CODE, + data: None, + message, + } +} + pub(crate) fn method_not_found(message: String) -> JSONRPCErrorError { JSONRPCErrorError { code: -32601, diff --git a/codex-rs/exec-server/src/server/handler/tests.rs b/codex-rs/exec-server/src/server/handler/tests.rs index 97bfba534..0bd682bd9 100644 --- a/codex-rs/exec-server/src/server/handler/tests.rs +++ b/codex-rs/exec-server/src/server/handler/tests.rs @@ -257,7 +257,7 @@ async fn active_session_resume_is_rejected() { .await .expect_err("active session resume should fail"); - assert_eq!(err.code, -32600); + assert_eq!(err.code, crate::rpc::SESSION_ALREADY_ATTACHED_ERROR_CODE); assert_eq!( err.message, format!( diff --git a/codex-rs/exec-server/src/server/session_registry.rs b/codex-rs/exec-server/src/server/session_registry.rs index 82c779c6b..59da3b50a 100644 --- a/codex-rs/exec-server/src/server/session_registry.rs +++ b/codex-rs/exec-server/src/server/session_registry.rs @@ -9,12 +9,13 @@ use uuid::Uuid; use crate::rpc::RpcNotificationSender; use crate::rpc::invalid_request; +use crate::rpc::session_already_attached; use crate::server::process_handler::ProcessHandler; #[cfg(test)] const DETACHED_SESSION_TTL: Duration = Duration::from_millis(200); #[cfg(not(test))] -const DETACHED_SESSION_TTL: Duration = Duration::from_secs(10); +const DETACHED_SESSION_TTL: Duration = Duration::from_secs(30); pub(crate) struct SessionRegistry { sessions: Mutex>>, @@ -82,7 +83,7 @@ impl SessionRegistry { })?; Ok(AttachOutcome::Expired { session_id, entry }) } else if entry.has_active_connection() { - Err(invalid_request(format!( + Err(session_already_attached(format!( "session {session_id} is already attached to another connection" ))) } else { @@ -176,6 +177,7 @@ impl SessionEntry { return false; } + self.process.set_notification_sender(/*notifications*/ None); attachment.current_connection_id = None; attachment.detached_connection_id = Some(connection_id); attachment.detached_expires_at = Some(tokio::time::Instant::now() + DETACHED_SESSION_TTL); @@ -245,10 +247,6 @@ impl SessionHandle { return; } - self.entry - .process - .set_notification_sender(/*notifications*/ None); - let registry = Arc::clone(&self.registry); let session_id = self.entry.session_id.clone(); let connection_id = self.connection_id; diff --git a/codex-rs/exec-server/tests/common/exec_server.rs b/codex-rs/exec-server/tests/common/exec_server.rs index f1bf03d25..ad089d5cd 100644 --- a/codex-rs/exec-server/tests/common/exec_server.rs +++ b/codex-rs/exec-server/tests/common/exec_server.rs @@ -14,8 +14,13 @@ use futures::StreamExt; use tempfile::TempDir; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; +use tokio::io::copy_bidirectional; +use tokio::net::TcpListener; +use tokio::net::TcpStream; use tokio::process::Child; use tokio::process::Command; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; use tokio::time::Instant; use tokio::time::sleep; use tokio::time::timeout; @@ -48,6 +53,20 @@ pub(crate) struct TestCodexHelperPaths { pub(crate) codex_linux_sandbox_exe: Option, } +pub(crate) struct DisconnectableWebSocketProxy { + websocket_url: String, + pause_tx: Option>, + blocked_connection_rx: Option>, + resume_tx: Option>, + task: JoinHandle<()>, +} + +impl Drop for DisconnectableWebSocketProxy { + fn drop(&mut self) { + self.task.abort(); + } +} + pub(crate) fn test_codex_helper_paths() -> anyhow::Result { let (helper_binary, codex_linux_sandbox_exe) = super::current_test_binary_helper_paths()?; Ok(TestCodexHelperPaths { @@ -106,6 +125,35 @@ impl ExecServerHarness { Ok(()) } + pub(crate) async fn disconnectable_websocket_proxy( + &self, + ) -> anyhow::Result { + let upstream = self + .websocket_url + .strip_prefix("ws://") + .ok_or_else(|| anyhow!("exec-server websocket URL must use ws://"))? + .to_string(); + let listener = TcpListener::bind("127.0.0.1:0").await?; + let websocket_url = format!("ws://{}", listener.local_addr()?); + let (pause_tx, pause_rx) = oneshot::channel(); + let (blocked_connection_tx, blocked_connection_rx) = oneshot::channel(); + let (resume_tx, resume_rx) = oneshot::channel(); + let task = tokio::spawn(run_disconnectable_proxy( + listener, + upstream, + pause_rx, + blocked_connection_tx, + resume_rx, + )); + Ok(DisconnectableWebSocketProxy { + websocket_url, + pause_tx: Some(pause_tx), + blocked_connection_rx: Some(blocked_connection_rx), + resume_tx: Some(resume_tx), + task, + }) + } + pub(crate) async fn send_request( &mut self, method: &str, @@ -213,6 +261,85 @@ impl ExecServerHarness { } } +impl DisconnectableWebSocketProxy { + pub(crate) fn websocket_url(&self) -> &str { + &self.websocket_url + } + + pub(crate) async fn pause_and_disconnect(&mut self) -> anyhow::Result<()> { + self.pause_tx + .take() + .ok_or_else(|| anyhow!("disconnectable websocket proxy is already paused"))? + .send(()) + .map_err(|_| anyhow!("disconnectable websocket proxy stopped"))?; + let blocked_connection_rx = self + .blocked_connection_rx + .take() + .ok_or_else(|| anyhow!("disconnectable websocket proxy is already paused"))?; + timeout(CONNECT_TIMEOUT, blocked_connection_rx) + .await + .map_err(|_| anyhow!("timed out waiting for client reconnect attempt"))? + .map_err(|_| anyhow!("disconnectable websocket proxy stopped"))?; + Ok(()) + } + + pub(crate) fn resume(&mut self) -> anyhow::Result<()> { + self.resume_tx + .take() + .ok_or_else(|| anyhow!("disconnectable websocket proxy is already resumed"))? + .send(()) + .map_err(|_| anyhow!("disconnectable websocket proxy stopped"))?; + Ok(()) + } +} + +async fn run_disconnectable_proxy( + listener: TcpListener, + upstream: String, + pause_rx: oneshot::Receiver<()>, + blocked_connection_tx: oneshot::Sender<()>, + mut resume_rx: oneshot::Receiver<()>, +) { + let Ok((mut downstream, _)) = listener.accept().await else { + return; + }; + let Ok(mut upstream_stream) = TcpStream::connect(&upstream).await else { + return; + }; + tokio::select! { + _ = copy_bidirectional(&mut downstream, &mut upstream_stream) => return, + _ = pause_rx => {} + } + drop(downstream); + drop(upstream_stream); + + let mut blocked_connection_tx = Some(blocked_connection_tx); + loop { + tokio::select! { + _ = &mut resume_rx => break, + accepted = listener.accept() => { + let Ok((blocked, _)) = accepted else { + break; + }; + drop(blocked); + if let Some(blocked_connection_tx) = blocked_connection_tx.take() { + let _ = blocked_connection_tx.send(()); + } + } + } + } + + loop { + let Ok((mut downstream, _)) = listener.accept().await else { + return; + }; + let Ok(mut upstream_stream) = TcpStream::connect(&upstream).await else { + continue; + }; + let _ = copy_bidirectional(&mut downstream, &mut upstream_stream).await; + } +} + async fn connect_websocket_when_ready( websocket_url: &str, ) -> anyhow::Result<( diff --git a/codex-rs/exec-server/tests/exec_process.rs b/codex-rs/exec-server/tests/exec_process.rs index e5522f407..11a3f3ece 100644 --- a/codex-rs/exec-server/tests/exec_process.rs +++ b/codex-rs/exec-server/tests/exec_process.rs @@ -1,5 +1,6 @@ mod common; +use std::collections::HashMap; use std::sync::Arc; use anyhow::Context; @@ -21,6 +22,7 @@ use tempfile::TempDir; use test_case::test_case; use tokio::sync::watch; use tokio::time::Duration; +use tokio::time::sleep; use tokio::time::timeout; use common::DELAYED_OUTPUT_AFTER_EXIT_PARENT_ARG; @@ -30,7 +32,7 @@ use common::exec_server::exec_server; struct ProcessContext { backend: Arc, - server: Option, + _server: Option, } #[derive(Debug, PartialEq, Eq)] @@ -55,13 +57,13 @@ async fn create_process_context(use_remote: bool) -> Result { let environment = Environment::create_for_tests(Some(server.websocket_url().to_string()))?; Ok(ProcessContext { backend: environment.get_exec_backend(), - server: Some(server), + _server: Some(server), }) } else { let environment = Environment::create_for_tests(/*exec_server_url*/ None)?; Ok(ProcessContext { backend: environment.get_exec_backend(), - server: None, + _server: None, }) } } @@ -634,88 +636,145 @@ async fn assert_exec_process_preserves_queued_events_before_subscribe( #[cfg_attr(not(unix), ignore = "Unix-only exec-server process test")] // Serialize tests that launch a real exec-server process through the full CLI. #[serial_test::serial(remote_exec_server)] -async fn remote_exec_process_reports_transport_disconnect() -> Result<()> { - let mut context = create_process_context(/*use_remote*/ true).await?; - let session = context - .backend +async fn remote_exec_process_recovers_after_transport_disconnect() -> Result<()> { + let server = exec_server().await?; + let mut proxy = server.disconnectable_websocket_proxy().await?; + let environment = Environment::create_for_tests(Some(proxy.websocket_url().to_string()))?; + let backend = environment.get_exec_backend(); + let temp_dir = TempDir::new()?; + let gate_path = temp_dir.path().join("release-output"); + let emitted_path = temp_dir.path().join("output-emitted"); + let session = backend .start(ExecParams { - process_id: ProcessId::from("proc-disconnect"), + process_id: ProcessId::from("proc-recover"), argv: vec![ "/bin/sh".to_string(), "-c".to_string(), - "sleep 10".to_string(), + concat!( + "printf 'ready:%s\\n' \"$$\"; ", + "while [ ! -f \"$GATE\" ]; do /bin/sleep 0.01; done; ", + "printf 'during:%s\\n' \"$$\"; ", + ": > \"$EMITTED\"; ", + "IFS= read -r line; ", + "printf 'after:%s:%s\\n' \"$$\" \"$line\"; ", + "exit 7", + ) + .to_string(), ], cwd: PathUri::from_path(std::env::current_dir()?)?, env_policy: /*env_policy*/ None, - env: Default::default(), + env: HashMap::from([ + ( + "GATE".to_string(), + gate_path.to_string_lossy().into_owned(), + ), + ( + "EMITTED".to_string(), + emitted_path.to_string_lossy().into_owned(), + ), + ]), tty: false, - pipe_stdin: false, + pipe_stdin: true, arg0: None, }) .await?; let process = Arc::clone(&session.process); let mut events = process.subscribe_events(); - let process_for_pending_read = Arc::clone(&process); - let pending_read = tokio::spawn(async move { - process_for_pending_read + let mut output = Vec::new(); + let mut last_seq = 0; + while !output.ends_with(b"\n") { + match timeout(Duration::from_secs(5), events.recv()).await?? { + ExecProcessEvent::Output(chunk) => { + assert_eq!(chunk.seq, last_seq + 1); + last_seq = chunk.seq; + output.extend_from_slice(&chunk.chunk.into_inner()); + } + event => anyhow::bail!("expected ready output before disconnect, got {event:?}"), + } + } + let ready = String::from_utf8(output.clone())?; + let pid = ready + .strip_prefix("ready:") + .and_then(|line| line.strip_suffix('\n')) + .context("ready output should contain the process id")? + .to_string(); + + proxy.pause_and_disconnect().await?; + tokio::fs::write(&gate_path, b"").await?; + timeout(Duration::from_secs(5), async { + while tokio::fs::metadata(&emitted_path).await.is_err() { + sleep(Duration::from_millis(10)).await; + } + }) + .await + .context("process did not emit output while disconnected")?; + + let process_for_read = Arc::clone(&process); + let mut pending_read = tokio::spawn(async move { + process_for_read .read( - /*after_seq*/ None, + /*after_seq*/ Some(last_seq), /*max_bytes*/ None, - /*wait_ms*/ Some(60_000), + /*wait_ms*/ Some(0), ) .await }); - let server = context - .server - .as_mut() - .expect("remote context should include exec-server harness"); - server.shutdown().await?; - - let event = timeout(Duration::from_secs(2), events.recv()).await??; - let ExecProcessEvent::Failed(event_message) = event else { - anyhow::bail!("expected process failure event, got {event:?}"); - }; assert!( - event_message.starts_with("exec-server transport disconnected"), - "unexpected failure event: {event_message}" + timeout(Duration::from_millis(200), &mut pending_read) + .await + .is_err(), + "process reads should wait while recovery is in progress" + ); + proxy.resume()?; + + let recovered_read = timeout(Duration::from_secs(5), pending_read) + .await + .context("timed out waiting for a read after recovery")??; + let recovered_read = recovered_read?; + assert_eq!(recovered_read.failure, None); + let recovered_output = recovered_read + .chunks + .into_iter() + .flat_map(|chunk| chunk.chunk.into_inner()) + .collect::>(); + assert_eq!( + String::from_utf8(recovered_output)?, + format!("during:{pid}\n") ); - let pending_response = timeout(Duration::from_secs(2), pending_read).await???; - let pending_message = pending_response - .failure - .expect("pending read should surface disconnect as a failure"); - assert!( - pending_message.starts_with("exec-server transport disconnected"), - "unexpected pending failure message: {pending_message}" - ); + let write = timeout(Duration::from_secs(5), process.write(b"hello\n".to_vec())) + .await + .context("timed out waiting for a write after recovery")??; + assert_eq!(write.status, WriteStatus::Accepted); - let mut wake_rx = process.subscribe_wake(); - let response = read_process_until_change(process, &mut wake_rx, /*after_seq*/ None).await?; - let message = response - .failure - .expect("disconnect should surface as a failure"); - assert!( - message.starts_with("exec-server transport disconnected"), - "unexpected failure message: {message}" - ); - assert!( - response.closed, - "disconnect should close the process session" - ); - - let write_result = timeout( - Duration::from_secs(2), - session.process.write(b"hello".to_vec()), - ) - .await - .context("timed out waiting for write after disconnect")?; - let write_error = write_result.expect_err("write after disconnect should fail"); - assert!( - write_error - .to_string() - .starts_with("exec-server transport disconnected"), - "unexpected write error: {write_error}" + let mut saw_exit = false; + loop { + match timeout(Duration::from_secs(5), events.recv()).await?? { + ExecProcessEvent::Output(chunk) => { + assert_eq!(chunk.seq, last_seq + 1); + last_seq = chunk.seq; + output.extend_from_slice(&chunk.chunk.into_inner()); + } + ExecProcessEvent::Exited { seq, exit_code } => { + assert_eq!(seq, last_seq + 1); + assert_eq!(exit_code, 7); + last_seq = seq; + saw_exit = true; + } + ExecProcessEvent::Closed { seq } => { + assert!(saw_exit, "closed must be delivered after exit"); + assert_eq!(seq, last_seq + 1); + break; + } + ExecProcessEvent::Failed(message) => { + anyhow::bail!("process recovery failed: {message}"); + } + } + } + assert_eq!( + String::from_utf8(output)?, + format!("ready:{pid}\nduring:{pid}\nafter:{pid}:hello\n") ); Ok(())