From 085ffb445642452d6fcf79dcda34e88f4d108afb Mon Sep 17 00:00:00 2001 From: jif-oai Date: Fri, 10 Apr 2026 14:11:47 +0100 Subject: [PATCH] feat: move exec-server ownership (#16344) This introduces session-scoped ownership for exec-server so ws disconnects no longer immediately kill running remote exec processes, and it prepares the protocol for reconnect-based resume. - add session_id / resume_session_id to the exec-server initialize handshake - move process ownership under a shared session registry - detach sessions on websocket disconnect and expire them after a TTL instead of killing processes immediately (we will resume based on this) - allow a new connection to resume an existing session and take over notifications/ownership - I use UUID to make them not predictable as we don't have auth for now - make detached-session expiry authoritative at resume time so teardown wins at the TTL boundary - reject long-poll process/read calls that get resumed out from under an older attachment --------- Co-authored-by: Codex --- codex-rs/Cargo.lock | 1 + codex-rs/exec-server/Cargo.toml | 1 + codex-rs/exec-server/src/client.rs | 38 ++- codex-rs/exec-server/src/client_api.rs | 2 + codex-rs/exec-server/src/connection.rs | 44 ++- codex-rs/exec-server/src/environment.rs | 35 +-- codex-rs/exec-server/src/local_process.rs | 94 ++---- codex-rs/exec-server/src/protocol.rs | 6 +- codex-rs/exec-server/src/rpc.rs | 3 +- codex-rs/exec-server/src/server.rs | 1 + codex-rs/exec-server/src/server/handler.rs | 147 ++++++++-- .../exec-server/src/server/handler/tests.rs | 268 +++++++++++++++-- .../exec-server/src/server/process_handler.rs | 16 +- codex-rs/exec-server/src/server/processor.rs | 276 +++++++++++++++++- codex-rs/exec-server/src/server/registry.rs | 12 +- .../src/server/session_registry.rs | 259 ++++++++++++++++ codex-rs/exec-server/src/server/transport.rs | 15 +- .../exec-server/tests/common/exec_server.rs | 11 + codex-rs/exec-server/tests/initialize.rs | 4 +- 
codex-rs/exec-server/tests/process.rs | 138 +++++++++ codex-rs/exec-server/tests/websocket.rs | 4 +- 21 files changed, 1203 insertions(+), 172 deletions(-) create mode 100644 codex-rs/exec-server/src/server/session_registry.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index f09f5d3a3..bf76650bf 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2124,6 +2124,7 @@ dependencies = [ "tokio", "tokio-tungstenite", "tracing", + "uuid", ] [[package]] diff --git a/codex-rs/exec-server/Cargo.toml b/codex-rs/exec-server/Cargo.toml index 41d65cf3d..6bf5528d3 100644 --- a/codex-rs/exec-server/Cargo.toml +++ b/codex-rs/exec-server/Cargo.toml @@ -40,6 +40,7 @@ tokio = { workspace = true, features = [ ] } tokio-tungstenite = { workspace = true } tracing = { workspace = true } +uuid = { workspace = true, features = ["v4"] } [dev-dependencies] anyhow = { workspace = true } diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 3121d4ce2..993d2ae01 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -71,6 +71,7 @@ impl Default for ExecServerClientConnectOptions { Self { client_name: "codex-core".to_string(), initialize_timeout: INITIALIZE_TIMEOUT, + resume_session_id: None, } } } @@ -80,6 +81,7 @@ impl From for ExecServerClientConnectOptions { Self { client_name: value.client_name, initialize_timeout: value.initialize_timeout, + resume_session_id: value.resume_session_id, } } } @@ -91,6 +93,7 @@ impl RemoteExecServerConnectArgs { client_name, connect_timeout: CONNECT_TIMEOUT, initialize_timeout: INITIALIZE_TIMEOUT, + resume_session_id: None, } } } @@ -118,6 +121,7 @@ struct Inner { // need serialization so concurrent register/remove operations do not // overwrite each other's copy-on-write updates. 
sessions_write_lock: Mutex<()>, + session_id: std::sync::RwLock>, reader_task: tokio::task::JoinHandle<()>, } @@ -190,14 +194,29 @@ impl ExecServerClient { let ExecServerClientConnectOptions { client_name, initialize_timeout, + resume_session_id, } = options; timeout(initialize_timeout, async { - let response = self + let response: InitializeResponse = self .inner .client - .call(INITIALIZE_METHOD, &InitializeParams { client_name }) + .call( + INITIALIZE_METHOD, + &InitializeParams { + client_name, + resume_session_id, + }, + ) .await?; + { + let mut session_id = self + .inner + .session_id + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *session_id = Some(response.session_id.clone()); + } self.notify_initialized().await?; Ok(response) }) @@ -350,6 +369,14 @@ impl ExecServerClient { self.inner.remove_session(process_id).await; } + pub fn session_id(&self) -> Option { + self.inner + .session_id + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + } + async fn connect( connection: JsonRpcConnection, options: ExecServerClientConnectOptions, @@ -388,6 +415,7 @@ impl ExecServerClient { client: rpc_client, sessions: ArcSwap::from_pointee(HashMap::new()), sessions_write_lock: Mutex::new(()), + session_id: std::sync::RwLock::new(None), reader_task, } }); @@ -693,8 +721,10 @@ mod tests { &mut server_writer, JSONRPCMessage::Response(JSONRPCResponse { id: request.id, - result: serde_json::to_value(InitializeResponse {}) - .expect("initialize response should serialize"), + result: serde_json::to_value(InitializeResponse { + session_id: "session-1".to_string(), + }) + .expect("initialize response should serialize"), }), ) .await; diff --git a/codex-rs/exec-server/src/client_api.rs b/codex-rs/exec-server/src/client_api.rs index 6e8976341..ac4371e2e 100644 --- a/codex-rs/exec-server/src/client_api.rs +++ b/codex-rs/exec-server/src/client_api.rs @@ -5,6 +5,7 @@ use std::time::Duration; pub struct ExecServerClientConnectOptions { pub 
client_name: String, pub initialize_timeout: Duration, + pub resume_session_id: Option, } /// WebSocket connection arguments for a remote exec-server. @@ -14,4 +15,5 @@ pub struct RemoteExecServerConnectArgs { pub client_name: String, pub connect_timeout: Duration, pub initialize_timeout: Duration, + pub resume_session_id: Option, } diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index 89f19560c..21eac6b4c 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -4,6 +4,7 @@ use futures::StreamExt; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; use tokio::sync::mpsc; +use tokio::sync::watch; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::tungstenite::Message; @@ -28,6 +29,7 @@ pub(crate) enum JsonRpcConnectionEvent { pub(crate) struct JsonRpcConnection { outgoing_tx: mpsc::Sender, incoming_rx: mpsc::Receiver, + disconnected_rx: watch::Receiver, task_handles: Vec>, } @@ -40,9 +42,11 @@ impl JsonRpcConnection { { let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (disconnected_tx, disconnected_rx) = watch::channel(false); let reader_label = connection_label.clone(); let incoming_tx_for_reader = incoming_tx.clone(); + let disconnected_tx_for_reader = disconnected_tx.clone(); let reader_task = tokio::spawn(async move { let mut lines = BufReader::new(reader).lines(); loop { @@ -73,12 +77,18 @@ impl JsonRpcConnection { } } Ok(None) => { - send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; + send_disconnected( + &incoming_tx_for_reader, + &disconnected_tx_for_reader, + /*reason*/ None, + ) + .await; break; } Err(err) => { send_disconnected( &incoming_tx_for_reader, + &disconnected_tx_for_reader, Some(format!( "failed to read JSON-RPC message from {reader_label}: {err}" )), @@ -96,6 +106,7 @@ impl JsonRpcConnection { if let Err(err) = 
write_jsonrpc_line_message(&mut writer, &message).await { send_disconnected( &incoming_tx, + &disconnected_tx, Some(format!( "failed to write JSON-RPC message to {connection_label}: {err}" )), @@ -109,6 +120,7 @@ impl JsonRpcConnection { Self { outgoing_tx, incoming_rx, + disconnected_rx, task_handles: vec![reader_task, writer_task], } } @@ -119,10 +131,12 @@ impl JsonRpcConnection { { let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (disconnected_tx, disconnected_rx) = watch::channel(false); let (mut websocket_writer, mut websocket_reader) = stream.split(); let reader_label = connection_label.clone(); let incoming_tx_for_reader = incoming_tx.clone(); + let disconnected_tx_for_reader = disconnected_tx.clone(); let reader_task = tokio::spawn(async move { loop { match websocket_reader.next().await { @@ -171,7 +185,12 @@ impl JsonRpcConnection { } } Some(Ok(Message::Close(_))) => { - send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; + send_disconnected( + &incoming_tx_for_reader, + &disconnected_tx_for_reader, + /*reason*/ None, + ) + .await; break; } Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {} @@ -179,6 +198,7 @@ impl JsonRpcConnection { Some(Err(err)) => { send_disconnected( &incoming_tx_for_reader, + &disconnected_tx_for_reader, Some(format!( "failed to read websocket JSON-RPC message from {reader_label}: {err}" )), @@ -187,7 +207,12 @@ impl JsonRpcConnection { break; } None => { - send_disconnected(&incoming_tx_for_reader, /*reason*/ None).await; + send_disconnected( + &incoming_tx_for_reader, + &disconnected_tx_for_reader, + /*reason*/ None, + ) + .await; break; } } @@ -202,6 +227,7 @@ impl JsonRpcConnection { { send_disconnected( &incoming_tx, + &disconnected_tx, Some(format!( "failed to write websocket JSON-RPC message to {connection_label}: {err}" )), @@ -213,6 +239,7 @@ impl JsonRpcConnection { Err(err) => { send_disconnected( 
&incoming_tx, + &disconnected_tx, Some(format!( "failed to serialize JSON-RPC message for {connection_label}: {err}" )), @@ -227,6 +254,7 @@ impl JsonRpcConnection { Self { outgoing_tx, incoming_rx, + disconnected_rx, task_handles: vec![reader_task, writer_task], } } @@ -236,16 +264,24 @@ impl JsonRpcConnection { ) -> ( mpsc::Sender, mpsc::Receiver, + watch::Receiver, Vec>, ) { - (self.outgoing_tx, self.incoming_rx, self.task_handles) + ( + self.outgoing_tx, + self.incoming_rx, + self.disconnected_rx, + self.task_handles, + ) } } async fn send_disconnected( incoming_tx: &mpsc::Sender, + disconnected_tx: &watch::Sender, reason: Option, ) { + let _ = disconnected_tx.send(true); let _ = incoming_tx .send(JsonRpcConnectionEvent::Disconnected { reason }) .await; diff --git a/codex-rs/exec-server/src/environment.rs b/codex-rs/exec-server/src/environment.rs index 3323db006..00b7fbb0c 100644 --- a/codex-rs/exec-server/src/environment.rs +++ b/codex-rs/exec-server/src/environment.rs @@ -105,18 +105,10 @@ pub struct Environment { impl Default for Environment { fn default() -> Self { - let local_process = LocalProcess::default(); - if let Err(err) = local_process.initialize() { - panic!("default local process initialization should succeed: {err:?}"); - } - if let Err(err) = local_process.initialized() { - panic!("default local process should accept initialized notification: {err}"); - } - Self { exec_server_url: None, remote_exec_server_client: None, - exec_backend: Arc::new(local_process), + exec_backend: Arc::new(LocalProcess::default()), } } } @@ -146,6 +138,7 @@ impl Environment { client_name: "codex-environment".to_string(), connect_timeout: std::time::Duration::from_secs(5), initialize_timeout: std::time::Duration::from_secs(5), + resume_session_id: None, }) .await?, ) @@ -153,24 +146,12 @@ impl Environment { None }; - let exec_backend: Arc = match remote_exec_server_client.clone() { - Some(client) => Arc::new(RemoteProcess::new(client)), - None if 
exec_server_url.is_some() => { - return Err(ExecServerError::Protocol( - "remote mode should have an exec-server client".to_string(), - )); - } - None => { - let local_process = LocalProcess::default(); - local_process - .initialize() - .map_err(|err| ExecServerError::Protocol(err.message))?; - local_process - .initialized() - .map_err(ExecServerError::Protocol)?; - Arc::new(local_process) - } - }; + let exec_backend: Arc = + if let Some(client) = remote_exec_server_client.clone() { + Arc::new(RemoteProcess::new(client)) + } else { + Arc::new(LocalProcess::default()) + }; Ok(Self { exec_server_url, diff --git a/codex-rs/exec-server/src/local_process.rs b/codex-rs/exec-server/src/local_process.rs index 7c7f0af8c..5a5a6a6e6 100644 --- a/codex-rs/exec-server/src/local_process.rs +++ b/codex-rs/exec-server/src/local_process.rs @@ -1,8 +1,6 @@ use std::collections::HashMap; use std::collections::VecDeque; use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering; use std::time::Duration; use async_trait::async_trait; @@ -26,7 +24,6 @@ use crate::protocol::ExecOutputDeltaNotification; use crate::protocol::ExecOutputStream; use crate::protocol::ExecParams; use crate::protocol::ExecResponse; -use crate::protocol::InitializeResponse; use crate::protocol::ProcessOutputChunk; use crate::protocol::ReadParams; use crate::protocol::ReadResponse; @@ -74,10 +71,8 @@ enum ProcessEntry { } struct Inner { - notifications: RpcNotificationSender, + notifications: std::sync::RwLock>, processes: Mutex>, - initialize_requested: AtomicBool, - initialized: AtomicBool, } #[derive(Clone)] @@ -104,10 +99,8 @@ impl LocalProcess { pub(crate) fn new(notifications: RpcNotificationSender) -> Self { Self { inner: Arc::new(Inner { - notifications, + notifications: std::sync::RwLock::new(Some(notifications)), processes: Mutex::new(HashMap::new()), - initialize_requested: AtomicBool::new(false), - initialized: AtomicBool::new(false), }), } } @@ -128,45 +121,19 @@ impl 
LocalProcess { } } - pub(crate) fn initialize(&self) -> Result { - if self.inner.initialize_requested.swap(true, Ordering::SeqCst) { - return Err(invalid_request( - "initialize may only be sent once per connection".to_string(), - )); - } - Ok(InitializeResponse {}) - } - - pub(crate) fn initialized(&self) -> Result<(), String> { - if !self.inner.initialize_requested.load(Ordering::SeqCst) { - return Err("received `initialized` notification before `initialize`".into()); - } - self.inner.initialized.store(true, Ordering::SeqCst); - Ok(()) - } - - pub(crate) fn require_initialized_for( - &self, - method_family: &str, - ) -> Result<(), JSONRPCErrorError> { - if !self.inner.initialize_requested.load(Ordering::SeqCst) { - return Err(invalid_request(format!( - "client must call initialize before using {method_family} methods" - ))); - } - if !self.inner.initialized.load(Ordering::SeqCst) { - return Err(invalid_request(format!( - "client must send initialized before using {method_family} methods" - ))); - } - Ok(()) + pub(crate) fn set_notification_sender(&self, notifications: Option) { + let mut notification_sender = self + .inner + .notifications + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *notification_sender = notifications; } async fn start_process( &self, params: ExecParams, ) -> Result<(ExecResponse, watch::Sender), JSONRPCErrorError> { - self.require_initialized_for("exec")?; let process_id = params.process_id.clone(); let (program, args) = params .argv @@ -277,7 +244,6 @@ impl LocalProcess { &self, params: ReadParams, ) -> Result { - self.require_initialized_for("exec")?; let _process_id = params.process_id.clone(); let after_seq = params.after_seq.unwrap_or(0); let max_bytes = params.max_bytes.unwrap_or(usize::MAX); @@ -354,7 +320,6 @@ impl LocalProcess { &self, params: WriteParams, ) -> Result { - self.require_initialized_for("exec")?; let _process_id = params.process_id.clone(); let _input_bytes = params.chunk.0.len(); let writer_tx = { 
@@ -391,7 +356,6 @@ impl LocalProcess { &self, params: TerminateParams, ) -> Result { - self.require_initialized_for("exec")?; let _process_id = params.process_id.clone(); let running = { let process_map = self.inner.processes.lock().await; @@ -546,13 +510,10 @@ async fn stream_output( } }; output_notify.notify_waiters(); - if inner - .notifications - .notify(crate::protocol::EXEC_OUTPUT_DELTA_METHOD, ¬ification) - .await - .is_err() - { - break; + if let Some(notifications) = notification_sender(&inner) { + let _ = notifications + .notify(crate::protocol::EXEC_OUTPUT_DELTA_METHOD, ¬ification) + .await; } } @@ -584,13 +545,11 @@ async fn watch_exit( }; output_notify.notify_waiters(); if let Some(notification) = notification - && inner - .notifications - .notify(crate::protocol::EXEC_EXITED_METHOD, ¬ification) - .await - .is_err() + && let Some(notifications) = notification_sender(&inner) { - return; + let _ = notifications + .notify(crate::protocol::EXEC_EXITED_METHOD, ¬ification) + .await; } maybe_emit_closed(process_id.clone(), Arc::clone(&inner)).await; @@ -645,10 +604,17 @@ async fn maybe_emit_closed(process_id: ProcessId, inner: Arc) { return; }; - if inner - .notifications - .notify(EXEC_CLOSED_METHOD, ¬ification) - .await - .is_err() - {} + if let Some(notifications) = notification_sender(&inner) { + let _ = notifications + .notify(EXEC_CLOSED_METHOD, ¬ification) + .await; + } +} + +fn notification_sender(inner: &Inner) -> Option { + inner + .notifications + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() } diff --git a/codex-rs/exec-server/src/protocol.rs b/codex-rs/exec-server/src/protocol.rs index 8a61a9de1..54034bdc5 100644 --- a/codex-rs/exec-server/src/protocol.rs +++ b/codex-rs/exec-server/src/protocol.rs @@ -46,11 +46,15 @@ impl From> for ByteChunk { #[serde(rename_all = "camelCase")] pub struct InitializeParams { pub client_name: String, + #[serde(default)] + pub resume_session_id: Option, } #[derive(Debug, Clone, PartialEq, 
Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -pub struct InitializeResponse {} +pub struct InitializeResponse { + pub session_id: String, +} #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] diff --git a/codex-rs/exec-server/src/rpc.rs b/codex-rs/exec-server/src/rpc.rs index bf7cb27a5..30a3f70bb 100644 --- a/codex-rs/exec-server/src/rpc.rs +++ b/codex-rs/exec-server/src/rpc.rs @@ -179,7 +179,8 @@ pub(crate) struct RpcClient { impl RpcClient { pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver) { - let (write_tx, mut incoming_rx, transport_tasks) = connection.into_parts(); + let (write_tx, mut incoming_rx, _disconnected_rx, transport_tasks) = + connection.into_parts(); let pending = Arc::new(Mutex::new(HashMap::::new())); let (event_tx, event_rx) = mpsc::channel(128); diff --git a/codex-rs/exec-server/src/server.rs b/codex-rs/exec-server/src/server.rs index 46de5aa49..44dc0a5d0 100644 --- a/codex-rs/exec-server/src/server.rs +++ b/codex-rs/exec-server/src/server.rs @@ -3,6 +3,7 @@ mod handler; mod process_handler; mod processor; mod registry; +mod session_registry; mod transport; pub(crate) use handler::ExecServerHandler; diff --git a/codex-rs/exec-server/src/server/handler.rs b/codex-rs/exec-server/src/server/handler.rs index 39e26548b..65bfc15fd 100644 --- a/codex-rs/exec-server/src/server/handler.rs +++ b/codex-rs/exec-server/src/server/handler.rs @@ -1,3 +1,8 @@ +use std::sync::Arc; +use std::sync::Mutex as StdMutex; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; + use codex_app_server_protocol::JSONRPCErrorError; use crate::protocol::ExecParams; @@ -16,6 +21,7 @@ use crate::protocol::FsRemoveParams; use crate::protocol::FsRemoveResponse; use crate::protocol::FsWriteFileParams; use crate::protocol::FsWriteFileResponse; +use crate::protocol::InitializeParams; use crate::protocol::InitializeResponse; use crate::protocol::ReadParams; use 
crate::protocol::ReadResponse; @@ -24,65 +30,126 @@ use crate::protocol::TerminateResponse; use crate::protocol::WriteParams; use crate::protocol::WriteResponse; use crate::rpc::RpcNotificationSender; +use crate::rpc::invalid_request; use crate::server::file_system_handler::FileSystemHandler; -use crate::server::process_handler::ProcessHandler; +use crate::server::session_registry::SessionHandle; +use crate::server::session_registry::SessionRegistry; -#[derive(Clone)] pub(crate) struct ExecServerHandler { - process: ProcessHandler, + session_registry: Arc, + notifications: RpcNotificationSender, + session: StdMutex>, file_system: FileSystemHandler, + initialize_requested: AtomicBool, + initialized: AtomicBool, } impl ExecServerHandler { - pub(crate) fn new(notifications: RpcNotificationSender) -> Self { + pub(crate) fn new( + session_registry: Arc, + notifications: RpcNotificationSender, + ) -> Self { Self { - process: ProcessHandler::new(notifications), + session_registry, + notifications, + session: StdMutex::new(None), file_system: FileSystemHandler::default(), + initialize_requested: AtomicBool::new(false), + initialized: AtomicBool::new(false), } } pub(crate) async fn shutdown(&self) { - self.process.shutdown().await; + if let Some(session) = self.session() { + session.detach().await; + } } - pub(crate) fn initialize(&self) -> Result { - self.process.initialize() + pub(crate) fn is_session_attached(&self) -> bool { + self.session() + .is_none_or(|session| session.is_session_attached()) + } + + pub(crate) async fn initialize( + &self, + params: InitializeParams, + ) -> Result { + if self.initialize_requested.swap(true, Ordering::SeqCst) { + return Err(invalid_request( + "initialize may only be sent once per connection".to_string(), + )); + } + + let session = match self + .session_registry + .attach(params.resume_session_id.clone(), self.notifications.clone()) + .await + { + Ok(session) => session, + Err(error) => { + self.initialize_requested.store(false, 
Ordering::SeqCst); + return Err(error); + } + }; + let session_id = session.session_id().to_string(); + tracing::debug!( + session_id, + connection_id = %session.connection_id(), + "exec-server session attached" + ); + *self + .session + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(session); + Ok(InitializeResponse { session_id }) } pub(crate) fn initialized(&self) -> Result<(), String> { - self.process.initialized() + if !self.initialize_requested.load(Ordering::SeqCst) { + return Err("received `initialized` notification before `initialize`".into()); + } + self.require_session_attached() + .map_err(|error| error.message)?; + self.initialized.store(true, Ordering::SeqCst); + Ok(()) } pub(crate) async fn exec(&self, params: ExecParams) -> Result { - self.process.exec(params).await + let session = self.require_initialized_for("exec")?; + session.process().exec(params).await } pub(crate) async fn exec_read( &self, params: ReadParams, ) -> Result { - self.process.exec_read(params).await + let session = self.require_initialized_for("exec")?; + let response = session.process().exec_read(params).await?; + self.require_session_attached()?; + Ok(response) } pub(crate) async fn exec_write( &self, params: WriteParams, ) -> Result { - self.process.exec_write(params).await + let session = self.require_initialized_for("exec")?; + session.process().exec_write(params).await } pub(crate) async fn terminate( &self, params: TerminateParams, ) -> Result { - self.process.terminate(params).await + let session = self.require_initialized_for("exec")?; + session.process().terminate(params).await } pub(crate) async fn fs_read_file( &self, params: FsReadFileParams, ) -> Result { - self.process.require_initialized_for("filesystem")?; + self.require_initialized_for("filesystem")?; self.file_system.read_file(params).await } @@ -90,7 +157,7 @@ impl ExecServerHandler { &self, params: FsWriteFileParams, ) -> Result { - self.process.require_initialized_for("filesystem")?; + 
self.require_initialized_for("filesystem")?; self.file_system.write_file(params).await } @@ -98,7 +165,7 @@ impl ExecServerHandler { &self, params: FsCreateDirectoryParams, ) -> Result { - self.process.require_initialized_for("filesystem")?; + self.require_initialized_for("filesystem")?; self.file_system.create_directory(params).await } @@ -106,7 +173,7 @@ impl ExecServerHandler { &self, params: FsGetMetadataParams, ) -> Result { - self.process.require_initialized_for("filesystem")?; + self.require_initialized_for("filesystem")?; self.file_system.get_metadata(params).await } @@ -114,7 +181,7 @@ impl ExecServerHandler { &self, params: FsReadDirectoryParams, ) -> Result { - self.process.require_initialized_for("filesystem")?; + self.require_initialized_for("filesystem")?; self.file_system.read_directory(params).await } @@ -122,7 +189,7 @@ impl ExecServerHandler { &self, params: FsRemoveParams, ) -> Result { - self.process.require_initialized_for("filesystem")?; + self.require_initialized_for("filesystem")?; self.file_system.remove(params).await } @@ -130,9 +197,49 @@ impl ExecServerHandler { &self, params: FsCopyParams, ) -> Result { - self.process.require_initialized_for("filesystem")?; + self.require_initialized_for("filesystem")?; self.file_system.copy(params).await } + + fn require_initialized_for( + &self, + method_family: &str, + ) -> Result { + if !self.initialize_requested.load(Ordering::SeqCst) { + return Err(invalid_request(format!( + "client must call initialize before using {method_family} methods" + ))); + } + let session = self.require_session_attached()?; + if !self.initialized.load(Ordering::SeqCst) { + return Err(invalid_request(format!( + "client must send initialized before using {method_family} methods" + ))); + } + Ok(session) + } + + fn require_session_attached(&self) -> Result { + let Some(session) = self.session() else { + return Err(invalid_request( + "client must call initialize before using methods".to_string(), + )); + }; + if 
session.is_session_attached() { + return Ok(session); + } + + Err(invalid_request( + "session has been resumed by another connection".to_string(), + )) + } + + fn session(&self) -> Option { + self.session + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + } } #[cfg(test)] diff --git a/codex-rs/exec-server/src/server/handler/tests.rs b/codex-rs/exec-server/src/server/handler/tests.rs index 0bd78fffb..7b16ae535 100644 --- a/codex-rs/exec-server/src/server/handler/tests.rs +++ b/codex-rs/exec-server/src/server/handler/tests.rs @@ -4,43 +4,81 @@ use std::time::Duration; use pretty_assertions::assert_eq; use tokio::sync::mpsc; +use uuid::Uuid; use super::ExecServerHandler; use crate::ProcessId; use crate::protocol::ExecParams; -use crate::protocol::InitializeResponse; +use crate::protocol::InitializeParams; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; use crate::protocol::TerminateParams; use crate::protocol::TerminateResponse; use crate::rpc::RpcNotificationSender; +use crate::server::session_registry::SessionRegistry; fn exec_params(process_id: &str) -> ExecParams { - let mut env = HashMap::new(); - if let Some(path) = std::env::var_os("PATH") { - env.insert("PATH".to_string(), path.to_string_lossy().into_owned()); - } + exec_params_with_argv(process_id, sleep_argv()) +} + +fn exec_params_with_argv(process_id: &str, argv: Vec) -> ExecParams { ExecParams { process_id: ProcessId::from(process_id), - argv: vec![ - "bash".to_string(), - "-lc".to_string(), - "sleep 0.1".to_string(), - ], + argv, cwd: std::env::current_dir().expect("cwd"), - env, + env: inherited_path_env(), tty: false, arg0: None, } } +fn inherited_path_env() -> HashMap { + let mut env = HashMap::new(); + if let Some(path) = std::env::var_os("PATH") { + env.insert("PATH".to_string(), path.to_string_lossy().into_owned()); + } + env +} + +fn sleep_argv() -> Vec { + shell_argv("sleep 0.1", "ping -n 2 127.0.0.1 >NUL") +} + +fn shell_argv(unix_script: &str, 
windows_script: &str) -> Vec { + if cfg!(windows) { + vec![ + windows_command_processor(), + "/C".to_string(), + windows_script.to_string(), + ] + } else { + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + unix_script.to_string(), + ] + } +} + +fn windows_command_processor() -> String { + std::env::var("COMSPEC").unwrap_or_else(|_| "cmd.exe".to_string()) +} + async fn initialized_handler() -> Arc { let (outgoing_tx, _outgoing_rx) = mpsc::channel(16); - let handler = Arc::new(ExecServerHandler::new(RpcNotificationSender::new( - outgoing_tx, - ))); - assert_eq!( - handler.initialize().expect("initialize"), - InitializeResponse {} - ); + let registry = SessionRegistry::new(); + let handler = Arc::new(ExecServerHandler::new( + registry, + RpcNotificationSender::new(outgoing_tx), + )); + let initialize_response = handler + .initialize(InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: None, + }) + .await + .expect("initialize"); + Uuid::parse_str(&initialize_response.session_id).expect("session id should be a UUID"); handler.initialized().expect("initialized"); handler } @@ -101,3 +139,197 @@ async fn terminate_reports_false_after_process_exit() { handler.shutdown().await; } + +#[tokio::test] +async fn long_poll_read_fails_after_session_resume() { + let (first_tx, _first_rx) = mpsc::channel(16); + let registry = SessionRegistry::new(); + let first_handler = Arc::new(ExecServerHandler::new( + Arc::clone(®istry), + RpcNotificationSender::new(first_tx), + )); + let initialize_response = first_handler + .initialize(InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: None, + }) + .await + .expect("initialize"); + first_handler.initialized().expect("initialized"); + + first_handler + .exec(exec_params_with_argv( + "proc-long-poll", + shell_argv( + "sleep 0.1; printf resumed", + "ping -n 2 127.0.0.1 >NUL && echo resumed", + ), + )) + .await + .expect("start process"); + + let first_read_handler = 
Arc::clone(&first_handler); + let read_task = tokio::spawn(async move { + first_read_handler + .exec_read(ReadParams { + process_id: ProcessId::from("proc-long-poll"), + after_seq: None, + max_bytes: None, + wait_ms: Some(500), + }) + .await + }); + + tokio::time::sleep(Duration::from_millis(50)).await; + first_handler.shutdown().await; + + let (second_tx, _second_rx) = mpsc::channel(16); + let second_handler = Arc::new(ExecServerHandler::new( + registry, + RpcNotificationSender::new(second_tx), + )); + second_handler + .initialize(InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: Some(initialize_response.session_id), + }) + .await + .expect("initialize second connection"); + second_handler + .initialized() + .expect("initialized second connection"); + + let err = read_task + .await + .expect("read task should join") + .expect_err("evicted long-poll read should fail"); + assert_eq!(err.code, -32600); + assert_eq!( + err.message, + "session has been resumed by another connection" + ); + + second_handler.shutdown().await; +} + +#[tokio::test] +async fn active_session_resume_is_rejected() { + let (first_tx, _first_rx) = mpsc::channel(16); + let registry = SessionRegistry::new(); + let first_handler = Arc::new(ExecServerHandler::new( + Arc::clone(®istry), + RpcNotificationSender::new(first_tx), + )); + let initialize_response = first_handler + .initialize(InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: None, + }) + .await + .expect("initialize"); + + let (second_tx, _second_rx) = mpsc::channel(16); + let second_handler = Arc::new(ExecServerHandler::new( + registry, + RpcNotificationSender::new(second_tx), + )); + let err = second_handler + .initialize(InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: Some(initialize_response.session_id.clone()), + }) + .await + .expect_err("active session resume should fail"); + + assert_eq!(err.code, -32600); + assert_eq!( 
+ err.message, + format!( + "session {} is already attached to another connection", + initialize_response.session_id + ) + ); + + first_handler.shutdown().await; +} + +#[tokio::test] +async fn output_and_exit_are_retained_after_notification_receiver_closes() { + let (outgoing_tx, outgoing_rx) = mpsc::channel(16); + let handler = Arc::new(ExecServerHandler::new( + SessionRegistry::new(), + RpcNotificationSender::new(outgoing_tx), + )); + handler + .initialize(InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: None, + }) + .await + .expect("initialize"); + handler.initialized().expect("initialized"); + + let process_id = ProcessId::from("proc-notification-fail"); + handler + .exec(exec_params_with_argv( + process_id.as_str(), + shell_argv( + "sleep 0.05; printf 'first\\n'; sleep 0.05; printf 'second\\n'", + "echo first && ping -n 2 127.0.0.1 >NUL && echo second", + ), + )) + .await + .expect("start process"); + + drop(outgoing_rx); + + let (output, exit_code) = read_process_until_closed(&handler, process_id.clone()).await; + assert_eq!(output.replace("\r\n", "\n"), "first\nsecond\n"); + assert_eq!(exit_code, Some(0)); + + tokio::time::sleep(Duration::from_millis(100)).await; + handler + .exec(exec_params(process_id.as_str())) + .await + .expect("process id should be reusable after exit retention"); + + handler.shutdown().await; +} + +async fn read_process_until_closed( + handler: &ExecServerHandler, + process_id: ProcessId, +) -> (String, Option) { + let deadline = tokio::time::Instant::now() + Duration::from_secs(2); + let mut output = String::new(); + let mut exit_code = None; + let mut after_seq = None; + + loop { + let response: ReadResponse = handler + .exec_read(ReadParams { + process_id: process_id.clone(), + after_seq, + max_bytes: None, + wait_ms: Some(500), + }) + .await + .expect("read process"); + + for chunk in response.chunks { + output.push_str(&String::from_utf8_lossy(&chunk.chunk.into_inner())); + after_seq = 
Some(chunk.seq); + } + if response.exited { + exit_code = response.exit_code; + } + if response.closed { + return (output, exit_code); + } + after_seq = response.next_seq.checked_sub(1).or(after_seq); + assert!( + tokio::time::Instant::now() < deadline, + "process should close within 2s" + ); + } +} diff --git a/codex-rs/exec-server/src/server/process_handler.rs b/codex-rs/exec-server/src/server/process_handler.rs index 6f22890d3..38fbace1c 100644 --- a/codex-rs/exec-server/src/server/process_handler.rs +++ b/codex-rs/exec-server/src/server/process_handler.rs @@ -3,7 +3,6 @@ use codex_app_server_protocol::JSONRPCErrorError; use crate::local_process::LocalProcess; use crate::protocol::ExecParams; use crate::protocol::ExecResponse; -use crate::protocol::InitializeResponse; use crate::protocol::ReadParams; use crate::protocol::ReadResponse; use crate::protocol::TerminateParams; @@ -28,19 +27,8 @@ impl ProcessHandler { self.process.shutdown().await; } - pub(crate) fn initialize(&self) -> Result { - self.process.initialize() - } - - pub(crate) fn initialized(&self) -> Result<(), String> { - self.process.initialized() - } - - pub(crate) fn require_initialized_for( - &self, - method_family: &str, - ) -> Result<(), JSONRPCErrorError> { - self.process.require_initialized_for(method_family) + pub(crate) fn set_notification_sender(&self, notifications: Option) { + self.process.set_notification_sender(notifications); } pub(crate) async fn exec(&self, params: ExecParams) -> Result { diff --git a/codex-rs/exec-server/src/server/processor.rs b/codex-rs/exec-server/src/server/processor.rs index 518a1a78e..bd30a9819 100644 --- a/codex-rs/exec-server/src/server/processor.rs +++ b/codex-rs/exec-server/src/server/processor.rs @@ -14,14 +14,33 @@ use crate::rpc::invalid_request; use crate::rpc::method_not_found; use crate::server::ExecServerHandler; use crate::server::registry::build_router; +use crate::server::session_registry::SessionRegistry; -pub(crate) async fn 
run_connection(connection: JsonRpcConnection) { +#[derive(Clone)] +pub(crate) struct ConnectionProcessor { + session_registry: Arc, +} + +impl ConnectionProcessor { + pub(crate) fn new() -> Self { + Self { + session_registry: SessionRegistry::new(), + } + } + + pub(crate) async fn run_connection(&self, connection: JsonRpcConnection) { + run_connection(connection, Arc::clone(&self.session_registry)).await; + } +} + +async fn run_connection(connection: JsonRpcConnection, session_registry: Arc) { let router = Arc::new(build_router()); - let (json_outgoing_tx, mut incoming_rx, connection_tasks) = connection.into_parts(); + let (json_outgoing_tx, mut incoming_rx, mut disconnected_rx, connection_tasks) = + connection.into_parts(); let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); let notifications = RpcNotificationSender::new(outgoing_tx.clone()); - let handler = Arc::new(ExecServerHandler::new(notifications)); + let handler = Arc::new(ExecServerHandler::new(session_registry, notifications)); let outbound_task = tokio::spawn(async move { while let Some(message) = outgoing_rx.recv().await { @@ -40,6 +59,10 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) { // Process inbound events sequentially to preserve initialize/initialized ordering. while let Some(event) = incoming_rx.recv().await { + if !handler.is_session_attached() { + debug!("exec-server connection evicted after session resume"); + break; + } match event { JsonRpcConnectionEvent::MalformedMessage { reason } => { warn!("ignoring malformed exec-server message: {reason}"); @@ -57,7 +80,13 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) { JsonRpcConnectionEvent::Message(message) => match message { codex_app_server_protocol::JSONRPCMessage::Request(request) => { if let Some(route) = router.request_route(request.method.as_str()) { - let message = route(handler.clone(), request).await; + let message = tokio::select! 
{ + message = route(Arc::clone(&handler), request) => message, + _ = disconnected_rx.changed() => { + debug!("exec-server transport disconnected while handling request"); + break; + } + }; if outgoing_tx.send(message).await.is_err() { break; } @@ -84,7 +113,16 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) { ); break; }; - if let Err(err) = route(handler.clone(), notification).await { + let result = tokio::select! { + result = route(Arc::clone(&handler), notification) => result, + _ = disconnected_rx.changed() => { + debug!( + "exec-server transport disconnected while handling notification" + ); + break; + } + }; + if let Err(err) = result { warn!("closing exec-server connection after protocol error: {err}"); break; } @@ -114,6 +152,7 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) { } handler.shutdown().await; + drop(handler); drop(outgoing_tx); for task in connection_tasks { task.abort(); @@ -121,3 +160,230 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) { } let _ = outbound_task.await; } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::Arc; + use std::time::Duration; + + use codex_app_server_protocol::JSONRPCMessage; + use codex_app_server_protocol::JSONRPCNotification; + use codex_app_server_protocol::JSONRPCRequest; + use codex_app_server_protocol::JSONRPCResponse; + use codex_app_server_protocol::RequestId; + use serde::Serialize; + use serde::de::DeserializeOwned; + use tokio::io::AsyncBufReadExt; + use tokio::io::AsyncWriteExt; + use tokio::io::BufReader; + use tokio::io::DuplexStream; + use tokio::io::Lines; + use tokio::io::duplex; + use tokio::task::JoinHandle; + use tokio::time::timeout; + + use super::run_connection; + use crate::ProcessId; + use crate::connection::JsonRpcConnection; + use crate::protocol::EXEC_METHOD; + use crate::protocol::EXEC_READ_METHOD; + use crate::protocol::EXEC_TERMINATE_METHOD; + use crate::protocol::ExecParams; + use 
crate::protocol::ExecResponse; + use crate::protocol::INITIALIZE_METHOD; + use crate::protocol::INITIALIZED_METHOD; + use crate::protocol::InitializeParams; + use crate::protocol::InitializeResponse; + use crate::protocol::ReadParams; + use crate::protocol::TerminateParams; + use crate::protocol::TerminateResponse; + use crate::server::session_registry::SessionRegistry; + + #[tokio::test] + async fn transport_disconnect_detaches_session_during_in_flight_read() { + let registry = SessionRegistry::new(); + let (mut first_writer, mut first_lines, first_task) = + spawn_test_connection(Arc::clone(®istry), "first"); + + send_request( + &mut first_writer, + /*id*/ 1, + INITIALIZE_METHOD, + &InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: None, + }, + ) + .await; + let initialize_response: InitializeResponse = + read_response(&mut first_lines, /*expected_id*/ 1).await; + send_notification(&mut first_writer, INITIALIZED_METHOD, &()).await; + + let process_id = ProcessId::from("proc-long-poll"); + send_request( + &mut first_writer, + /*id*/ 2, + EXEC_METHOD, + &exec_params(process_id.clone()), + ) + .await; + let _: ExecResponse = read_response(&mut first_lines, /*expected_id*/ 2).await; + + send_request( + &mut first_writer, + /*id*/ 3, + EXEC_READ_METHOD, + &ReadParams { + process_id: process_id.clone(), + after_seq: None, + max_bytes: None, + wait_ms: Some(5_000), + }, + ) + .await; + drop(first_writer); + tokio::time::sleep(Duration::from_millis(25)).await; + + let (mut second_writer, mut second_lines, second_task) = + spawn_test_connection(Arc::clone(®istry), "second"); + send_request( + &mut second_writer, + /*id*/ 1, + INITIALIZE_METHOD, + &InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: Some(initialize_response.session_id.clone()), + }, + ) + .await; + let second_initialize_response = timeout( + Duration::from_secs(1), + read_response::(&mut second_lines, /*expected_id*/ 1), + ) + .await + 
.expect("resume initialize should not wait for the old read to finish"); + assert_eq!( + second_initialize_response.session_id, + initialize_response.session_id + ); + timeout(Duration::from_secs(1), first_task) + .await + .expect("first processor should exit") + .expect("first processor should join"); + send_notification(&mut second_writer, INITIALIZED_METHOD, &()).await; + + send_request( + &mut second_writer, + /*id*/ 2, + EXEC_TERMINATE_METHOD, + &TerminateParams { process_id }, + ) + .await; + let _: TerminateResponse = read_response(&mut second_lines, /*expected_id*/ 2).await; + + drop(second_writer); + drop(second_lines); + timeout(Duration::from_secs(1), second_task) + .await + .expect("second processor should exit") + .expect("second processor should join"); + } + + fn spawn_test_connection( + registry: Arc, + label: &str, + ) -> (DuplexStream, Lines>, JoinHandle<()>) { + let (client_writer, server_reader) = duplex(1 << 20); + let (server_writer, client_reader) = duplex(1 << 20); + let connection = + JsonRpcConnection::from_stdio(server_reader, server_writer, label.to_string()); + let task = tokio::spawn(run_connection(connection, registry)); + (client_writer, BufReader::new(client_reader).lines(), task) + } + + async fn send_request( + writer: &mut DuplexStream, + id: i64, + method: &str, + params: &P, + ) { + write_message( + writer, + &JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(id), + method: method.to_string(), + params: Some(serde_json::to_value(params).expect("serialize params")), + trace: None, + }), + ) + .await; + } + + async fn send_notification(writer: &mut DuplexStream, method: &str, params: &P) { + write_message( + writer, + &JSONRPCMessage::Notification(JSONRPCNotification { + method: method.to_string(), + params: Some(serde_json::to_value(params).expect("serialize params")), + }), + ) + .await; + } + + async fn write_message(writer: &mut DuplexStream, message: &JSONRPCMessage) { + let encoded = 
/// Builds the argv for a shell command that waits briefly and then prints
/// `late`.
///
/// On Windows the command runs through `%COMSPEC%` (falling back to `cmd.exe`)
/// and uses `ping` for the delay; elsewhere it runs through `/bin/sh` with
/// `sleep 1`.
fn sleep_then_print_argv() -> Vec<String> {
    if cfg!(windows) {
        vec![
            std::env::var("COMSPEC").unwrap_or_else(|_| "cmd.exe".to_string()),
            "/C".to_string(),
            "ping -n 3 127.0.0.1 >NUL && echo late".to_string(),
        ]
    } else {
        vec![
            "/bin/sh".to_string(),
            "-c".to_string(),
            "sleep 1; printf late".to_string(),
        ]
    }
}
router.notification( INITIALIZED_METHOD, |handler: Arc, _params: serde_json::Value| async move { handler.initialized() }, ); + router.request( + INITIALIZE_METHOD, + |handler: Arc, params: InitializeParams| async move { + handler.initialize(params).await + }, + ); router.request( EXEC_METHOD, |handler: Arc, params: ExecParams| async move { handler.exec(params).await }, diff --git a/codex-rs/exec-server/src/server/session_registry.rs b/codex-rs/exec-server/src/server/session_registry.rs new file mode 100644 index 000000000..82c779c6b --- /dev/null +++ b/codex-rs/exec-server/src/server/session_registry.rs @@ -0,0 +1,259 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; +use std::time::Duration; + +use codex_app_server_protocol::JSONRPCErrorError; +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::rpc::RpcNotificationSender; +use crate::rpc::invalid_request; +use crate::server::process_handler::ProcessHandler; + +#[cfg(test)] +const DETACHED_SESSION_TTL: Duration = Duration::from_millis(200); +#[cfg(not(test))] +const DETACHED_SESSION_TTL: Duration = Duration::from_secs(10); + +pub(crate) struct SessionRegistry { + sessions: Mutex>>, +} + +struct SessionEntry { + session_id: String, + process: ProcessHandler, + attachment: StdMutex, +} + +struct AttachmentState { + current_connection_id: Option, + detached_connection_id: Option, + detached_expires_at: Option, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +struct ConnectionId(Uuid); + +impl std::fmt::Display for ConnectionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Clone)] +pub(crate) struct SessionHandle { + registry: Arc, + entry: Arc, + connection_id: ConnectionId, +} + +impl SessionRegistry { + pub(crate) fn new() -> Arc { + Arc::new(Self { + sessions: Mutex::new(HashMap::new()), + }) + } + + pub(crate) async fn attach( + self: &Arc, + resume_session_id: Option, + notifications: 
RpcNotificationSender, + ) -> Result { + enum AttachOutcome { + Attached(Arc), + Expired { + session_id: String, + entry: Arc, + }, + } + + let connection_id = ConnectionId(Uuid::new_v4()); + let outcome = { + let mut sessions = self.sessions.lock().await; + if let Some(session_id) = resume_session_id { + let entry = sessions + .get(&session_id) + .cloned() + .ok_or_else(|| invalid_request(format!("unknown session id {session_id}")))?; + if entry.is_expired(tokio::time::Instant::now()) { + let entry = sessions.remove(&session_id).ok_or_else(|| { + invalid_request(format!("unknown session id {session_id}")) + })?; + Ok(AttachOutcome::Expired { session_id, entry }) + } else if entry.has_active_connection() { + Err(invalid_request(format!( + "session {session_id} is already attached to another connection" + ))) + } else { + entry.process.set_notification_sender(Some(notifications)); + entry.attach(connection_id); + Ok(AttachOutcome::Attached(entry)) + } + } else { + let session_id = Uuid::new_v4().to_string(); + let entry = Arc::new(SessionEntry::new( + session_id.clone(), + ProcessHandler::new(notifications), + connection_id, + )); + sessions.insert(session_id, Arc::clone(&entry)); + Ok(AttachOutcome::Attached(entry)) + } + }; + let entry = match outcome? 
{ + AttachOutcome::Attached(entry) => entry, + AttachOutcome::Expired { session_id, entry } => { + entry.process.shutdown().await; + return Err(invalid_request(format!("unknown session id {session_id}"))); + } + }; + + Ok(SessionHandle { + registry: Arc::clone(self), + entry, + connection_id, + }) + } + + async fn expire_if_detached(&self, session_id: String, connection_id: ConnectionId) { + tokio::time::sleep(DETACHED_SESSION_TTL).await; + + let removed = { + let mut sessions = self.sessions.lock().await; + let Some(entry) = sessions.get(&session_id) else { + return; + }; + if !entry.is_detached_connection_expired(connection_id, tokio::time::Instant::now()) { + return; + } + sessions.remove(&session_id) + }; + + if let Some(entry) = removed { + entry.process.shutdown().await; + } + } +} + +impl Default for SessionRegistry { + fn default() -> Self { + Self { + sessions: Mutex::new(HashMap::new()), + } + } +} + +impl SessionEntry { + fn new(session_id: String, process: ProcessHandler, connection_id: ConnectionId) -> Self { + Self { + session_id, + process, + attachment: StdMutex::new(AttachmentState { + current_connection_id: Some(connection_id), + detached_connection_id: None, + detached_expires_at: None, + }), + } + } + + fn attach(&self, connection_id: ConnectionId) { + let mut attachment = self + .attachment + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + attachment.current_connection_id = Some(connection_id); + attachment.detached_connection_id = None; + attachment.detached_expires_at = None; + } + + fn detach(&self, connection_id: ConnectionId) -> bool { + let mut attachment = self + .attachment + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if attachment.current_connection_id != Some(connection_id) { + return false; + } + + attachment.current_connection_id = None; + attachment.detached_connection_id = Some(connection_id); + attachment.detached_expires_at = Some(tokio::time::Instant::now() + DETACHED_SESSION_TTL); + true 
+ } + + fn has_active_connection(&self) -> bool { + self.attachment + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .current_connection_id + .is_some() + } + + fn is_attached_to(&self, connection_id: ConnectionId) -> bool { + self.attachment + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .current_connection_id + == Some(connection_id) + } + + fn is_expired(&self, now: tokio::time::Instant) -> bool { + self.attachment + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .detached_expires_at + .is_some_and(|deadline| now >= deadline) + } + + fn is_detached_connection_expired( + &self, + connection_id: ConnectionId, + now: tokio::time::Instant, + ) -> bool { + let attachment = self + .attachment + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + attachment.current_connection_id.is_none() + && attachment.detached_connection_id == Some(connection_id) + && attachment + .detached_expires_at + .is_some_and(|deadline| now >= deadline) + } +} + +impl SessionHandle { + pub(crate) fn session_id(&self) -> &str { + &self.entry.session_id + } + + pub(crate) fn connection_id(&self) -> String { + self.connection_id.to_string() + } + + pub(crate) fn is_session_attached(&self) -> bool { + self.entry.is_attached_to(self.connection_id) + } + + pub(crate) fn process(&self) -> &ProcessHandler { + &self.entry.process + } + + pub(crate) async fn detach(&self) { + if !self.entry.detach(self.connection_id) { + return; + } + + self.entry + .process + .set_notification_sender(/*notifications*/ None); + + let registry = Arc::clone(&self.registry); + let session_id = self.entry.session_id.clone(); + let connection_id = self.connection_id; + tokio::spawn(async move { + registry.expire_if_detached(session_id, connection_id).await; + }); + } +} diff --git a/codex-rs/exec-server/src/server/transport.rs b/codex-rs/exec-server/src/server/transport.rs index 4726465cc..de94b1af6 100644 --- a/codex-rs/exec-server/src/server/transport.rs 
+++ b/codex-rs/exec-server/src/server/transport.rs @@ -5,7 +5,7 @@ use tokio_tungstenite::accept_async; use tracing::warn; use crate::connection::JsonRpcConnection; -use crate::server::processor::run_connection; +use crate::server::processor::ConnectionProcessor; pub const DEFAULT_LISTEN_URL: &str = "ws://127.0.0.1:0"; @@ -58,19 +58,22 @@ async fn run_websocket_listener( ) -> Result<(), Box> { let listener = TcpListener::bind(bind_address).await?; let local_addr = listener.local_addr()?; + let processor = ConnectionProcessor::new(); tracing::info!("codex-exec-server listening on ws://{local_addr}"); println!("ws://{local_addr}"); loop { let (stream, peer_addr) = listener.accept().await?; + let processor = processor.clone(); tokio::spawn(async move { match accept_async(stream).await { Ok(websocket) => { - run_connection(JsonRpcConnection::from_websocket( - websocket, - format!("exec-server websocket {peer_addr}"), - )) - .await; + processor + .run_connection(JsonRpcConnection::from_websocket( + websocket, + format!("exec-server websocket {peer_addr}"), + )) + .await; } Err(err) => { warn!( diff --git a/codex-rs/exec-server/tests/common/exec_server.rs b/codex-rs/exec-server/tests/common/exec_server.rs index c7c120ee1..65d6f1ec6 100644 --- a/codex-rs/exec-server/tests/common/exec_server.rs +++ b/codex-rs/exec-server/tests/common/exec_server.rs @@ -64,6 +64,17 @@ impl ExecServerHarness { &self.websocket_url } + pub(crate) async fn disconnect_websocket(&mut self) -> anyhow::Result<()> { + self.websocket.close(None).await?; + Ok(()) + } + + pub(crate) async fn reconnect_websocket(&mut self) -> anyhow::Result<()> { + let (websocket, _) = connect_websocket_when_ready(&self.websocket_url).await?; + self.websocket = websocket; + Ok(()) + } + pub(crate) async fn send_request( &mut self, method: &str, diff --git a/codex-rs/exec-server/tests/initialize.rs b/codex-rs/exec-server/tests/initialize.rs index 0e95c9f9a..9c6739566 100644 --- a/codex-rs/exec-server/tests/initialize.rs 
+++ b/codex-rs/exec-server/tests/initialize.rs @@ -8,6 +8,7 @@ use codex_exec_server::InitializeParams; use codex_exec_server::InitializeResponse; use common::exec_server::exec_server; use pretty_assertions::assert_eq; +use uuid::Uuid; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn exec_server_accepts_initialize() -> anyhow::Result<()> { @@ -17,6 +18,7 @@ async fn exec_server_accepts_initialize() -> anyhow::Result<()> { "initialize", serde_json::to_value(InitializeParams { client_name: "exec-server-test".to_string(), + resume_session_id: None, })?, ) .await?; @@ -27,7 +29,7 @@ async fn exec_server_accepts_initialize() -> anyhow::Result<()> { }; assert_eq!(id, initialize_id); let initialize_response: InitializeResponse = serde_json::from_value(result)?; - assert_eq!(initialize_response, InitializeResponse {}); + Uuid::parse_str(&initialize_response.session_id)?; server.shutdown().await?; Ok(()) diff --git a/codex-rs/exec-server/tests/process.rs b/codex-rs/exec-server/tests/process.rs index c210b5a9c..cc96e6ef8 100644 --- a/codex-rs/exec-server/tests/process.rs +++ b/codex-rs/exec-server/tests/process.rs @@ -6,7 +6,10 @@ use codex_app_server_protocol::JSONRPCMessage; use codex_app_server_protocol::JSONRPCResponse; use codex_exec_server::ExecResponse; use codex_exec_server::InitializeParams; +use codex_exec_server::InitializeResponse; use codex_exec_server::ProcessId; +use codex_exec_server::ReadResponse; +use codex_exec_server::TerminateResponse; use common::exec_server::exec_server; use pretty_assertions::assert_eq; @@ -18,6 +21,7 @@ async fn exec_server_starts_process_over_websocket() -> anyhow::Result<()> { "initialize", serde_json::to_value(InitializeParams { client_name: "exec-server-test".to_string(), + resume_session_id: None, })?, ) .await?; @@ -70,3 +74,137 @@ async fn exec_server_starts_process_over_websocket() -> anyhow::Result<()> { server.shutdown().await?; Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] 
+async fn exec_server_resumes_detached_session_without_killing_processes() -> anyhow::Result<()> { + let mut server = exec_server().await?; + let initialize_id = server + .send_request( + "initialize", + serde_json::to_value(InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: None, + })?, + ) + .await?; + let response = server + .wait_for_event(|event| { + matches!( + event, + JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &initialize_id + ) + }) + .await?; + let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else { + panic!("expected initialize response"); + }; + let initialize_response: InitializeResponse = serde_json::from_value(result)?; + + server + .send_notification("initialized", serde_json::json!({})) + .await?; + + let process_start_id = server + .send_request( + "process/start", + serde_json::json!({ + "processId": "proc-resume", + "argv": ["/bin/sh", "-c", "sleep 5"], + "cwd": std::env::current_dir()?, + "env": {}, + "tty": false, + "arg0": null + }), + ) + .await?; + let _ = server + .wait_for_event(|event| { + matches!( + event, + JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &process_start_id + ) + }) + .await?; + + server.disconnect_websocket().await?; + server.reconnect_websocket().await?; + + let resume_initialize_id = server + .send_request( + "initialize", + serde_json::to_value(InitializeParams { + client_name: "exec-server-test".to_string(), + resume_session_id: Some(initialize_response.session_id.clone()), + })?, + ) + .await?; + let response = server + .wait_for_event(|event| { + matches!( + event, + JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &resume_initialize_id + ) + }) + .await?; + let JSONRPCMessage::Response(JSONRPCResponse { result, .. 
}) = response else { + panic!("expected resume initialize response"); + }; + let resumed_response: InitializeResponse = serde_json::from_value(result)?; + assert_eq!(resumed_response, initialize_response); + + server + .send_notification("initialized", serde_json::json!({})) + .await?; + + let process_read_id = server + .send_request( + "process/read", + serde_json::json!({ + "processId": "proc-resume", + "afterSeq": null, + "maxBytes": null, + "waitMs": 0 + }), + ) + .await?; + let response = server + .wait_for_event(|event| { + matches!( + event, + JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &process_read_id + ) + }) + .await?; + let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else { + panic!("expected process/read response"); + }; + let process_read_response: ReadResponse = serde_json::from_value(result)?; + assert!(process_read_response.failure.is_none()); + assert!(!process_read_response.exited); + assert!(!process_read_response.closed); + + let terminate_id = server + .send_request( + "process/terminate", + serde_json::json!({ + "processId": "proc-resume" + }), + ) + .await?; + let response = server + .wait_for_event(|event| { + matches!( + event, + JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &terminate_id + ) + }) + .await?; + let JSONRPCMessage::Response(JSONRPCResponse { result, .. 
}) = response else { + panic!("expected process/terminate response"); + }; + let terminate_response: TerminateResponse = serde_json::from_value(result)?; + assert_eq!(terminate_response, TerminateResponse { running: true }); + + server.shutdown().await?; + Ok(()) +} diff --git a/codex-rs/exec-server/tests/websocket.rs b/codex-rs/exec-server/tests/websocket.rs index f26efa520..64c9438b8 100644 --- a/codex-rs/exec-server/tests/websocket.rs +++ b/codex-rs/exec-server/tests/websocket.rs @@ -9,6 +9,7 @@ use codex_exec_server::InitializeParams; use codex_exec_server::InitializeResponse; use common::exec_server::exec_server; use pretty_assertions::assert_eq; +use uuid::Uuid; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn exec_server_reports_malformed_websocket_json_and_keeps_running() -> anyhow::Result<()> { @@ -36,6 +37,7 @@ async fn exec_server_reports_malformed_websocket_json_and_keeps_running() -> any "initialize", serde_json::to_value(InitializeParams { client_name: "exec-server-test".to_string(), + resume_session_id: None, })?, ) .await?; @@ -53,7 +55,7 @@ async fn exec_server_reports_malformed_websocket_json_and_keeps_running() -> any }; assert_eq!(id, initialize_id); let initialize_response: InitializeResponse = serde_json::from_value(result)?; - assert_eq!(initialize_response, InitializeResponse {}); + Uuid::parse_str(&initialize_response.session_id)?; server.shutdown().await?; Ok(())