diff --git a/codex-rs/code-mode/src/cell_actor/types.rs b/codex-rs/code-mode/src/cell_actor/types.rs index 020ae3dcb..ee345d6ab 100644 --- a/codex-rs/code-mode/src/cell_actor/types.rs +++ b/codex-rs/code-mode/src/cell_actor/types.rs @@ -92,16 +92,13 @@ impl CellHandle { pub(crate) fn terminate(&self) -> CellEventFuture { self.state.request_termination() } - - pub(crate) fn shutdown(&self) { - self.state.cancellation_token().cancel(); - } } /// The single linearization point for a cell's terminal outcome. /// -/// Callback cancellation tokens are children of the cell token, so a terminal -/// decision cancels runtime work and its callbacks together. +/// The cancellation token is a child of the owning session token. Callback +/// tokens are children of this token, so cancellation flows strictly from the +/// session to the cell and then to its callbacks. /// /// The mutex is held only for synchronous phase transitions and terminal /// delivery. Runtime execution, observation waits, and callbacks never run diff --git a/codex-rs/code-mode/src/runtime/mod.rs b/codex-rs/code-mode/src/runtime/mod.rs index bfc54f71b..42af5c0b9 100644 --- a/codex-rs/code-mode/src/runtime/mod.rs +++ b/codex-rs/code-mode/src/runtime/mod.rs @@ -420,7 +420,7 @@ await new Promise(() => {}); .send(RuntimeCommand::TimeoutFired { id: 1 }) .unwrap(); assert!( - tokio::time::timeout(Duration::from_millis(100), event_rx.recv()) + tokio::time::timeout(Duration::from_secs(1), event_rx.recv()) .await .is_err() ); diff --git a/codex-rs/code-mode/src/service_contract_tests.rs b/codex-rs/code-mode/src/service_contract_tests.rs index 948431032..8ead39640 100644 --- a/codex-rs/code-mode/src/service_contract_tests.rs +++ b/codex-rs/code-mode/src/service_contract_tests.rs @@ -289,7 +289,10 @@ async fn observed_natural_completion_wins_over_termination() { tokio::time::timeout(Duration::from_secs(1), async { loop { let response = service - .execute(execute_request(r#"text(String(load("finished")));"#)) + .execute(ExecuteRequest { + yield_time_ms: Some(60_000), + ..execute_request(r#"text(String(load("finished")));"#) + }) .await .unwrap() .initial_response() diff --git a/codex-rs/code-mode/src/service_tests.rs b/codex-rs/code-mode/src/service_tests.rs index d90d4b6af..6344bfdab 100644 --- a/codex-rs/code-mode/src/service_tests.rs +++ b/codex-rs/code-mode/src/service_tests.rs @@ -360,7 +360,7 @@ await Promise.all([ } ); - tokio::time::sleep(Duration::from_millis(1100)).await; + tokio::time::sleep(Duration::from_secs(2)).await; let resumed_response = tokio::time::timeout( Duration::from_secs(1), diff --git a/codex-rs/code-mode/src/session_runtime/mod.rs b/codex-rs/code-mode/src/session_runtime/mod.rs index 362dccda7..71cf3d613 100644 --- a/codex-rs/code-mode/src/session_runtime/mod.rs +++ b/codex-rs/code-mode/src/session_runtime/mod.rs @@ -4,13 +4,13 @@ use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; use serde_json::Value as JsonValue; use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; +use tokio_util::task::TaskTracker; pub(crate) use self::types::CellEvent; pub(crate) use self::types::CellId; @@ -43,8 +43,9 @@ pub(crate) struct SessionRuntime { struct Inner { stored_values: Mutex>, cells: Mutex>, + cell_tasks: TaskTracker, + shutdown_token: CancellationToken, delegate: Arc, - shutting_down: AtomicBool, next_cell_id: AtomicU64, } @@ -54,15 +55,16 @@ impl SessionRuntime { inner: Arc::new(Inner { stored_values: Mutex::new(HashMap::new()), cells: Mutex::new(HashMap::new()), + cell_tasks: TaskTracker::new(), + shutdown_token: CancellationToken::new(), delegate, - shutting_down: AtomicBool::new(false), next_cell_id: AtomicU64::new(1), }), } } pub(crate) fn is_alive(&self) -> bool { - !self.inner.shutting_down.load(Ordering::Acquire) + !self.inner.shutdown_token.is_cancelled() } pub(crate) async fn execute( @@ -70,6 +72,9 @@ impl SessionRuntime { request: CreateCellRequest, initial_observe_mode: ObserveMode, ) -> Result { + if self.inner.shutdown_token.is_cancelled() { + return Err(Error::ShuttingDown); + } let cell_id = self.allocate_cell_id(); let initial_event = self .start_cell(cell_id.clone(), request, initial_observe_mode) @@ -122,21 +127,13 @@ impl SessionRuntime { } pub(crate) async fn shutdown(&self) -> Result<(), Error> { - self.inner.shutting_down.store(true, Ordering::Release); - let handles = self - .inner - .cells - .lock() - .await - .values() - .cloned() - .collect::>(); - for handle in handles { - handle.shutdown(); - } - while !self.inner.cells.lock().await.is_empty() { - tokio::task::yield_now().await; - } + self.begin_shutdown(); + // Taking the registry lock ensures every cell that passed the shutdown + // check has registered its actor with the tracker before we wait. + let cells = self.inner.cells.lock().await; + self.inner.cell_tasks.close(); + drop(cells); + self.inner.cell_tasks.wait().await; Ok(()) } @@ -161,13 +158,13 @@ impl SessionRuntime { inner: Arc::clone(&self.inner), }); let mut cells = self.inner.cells.lock().await; - if self.inner.shutting_down.load(Ordering::Acquire) { + if self.inner.shutdown_token.is_cancelled() { return Err(Error::ShuttingDown); } if cells.contains_key(&cell_id) { return Err(Error::DuplicateCell(cell_id)); } - let cell_state = Arc::new(CellState::new(CancellationToken::new())); + let cell_state = Arc::new(CellState::new(self.inner.shutdown_token.child_token())); let (handle, initial_event, task) = CellActor::prepare( request, stored_values, @@ -177,18 +174,14 @@ impl SessionRuntime { ) .map_err(Error::Runtime)?; cells.insert(cell_id.clone(), handle); + self.inner.cell_tasks.spawn(task); drop(cells); - tokio::spawn(task); Ok(map_actor_event(cell_id, initial_event)) } fn begin_shutdown(&self) { - self.inner.shutting_down.store(true, Ordering::Release); - if let Ok(cells) = self.inner.cells.try_lock() { - for handle in cells.values() { - handle.shutdown(); - } - } + self.inner.shutdown_token.cancel(); + self.inner.cell_tasks.close(); } } diff --git a/codex-rs/code-mode/src/session_runtime/tests.rs b/codex-rs/code-mode/src/session_runtime/tests.rs index b4457b3f7..a1dc18c0e 100644 --- a/codex-rs/code-mode/src/session_runtime/tests.rs +++ b/codex-rs/code-mode/src/session_runtime/tests.rs @@ -108,3 +108,80 @@ async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_loa ); runtime.shutdown().await.unwrap(); } + +fn execute_request(source: &str) -> CreateCellRequest { + CreateCellRequest { + tool_call_id: "call-1".to_string(), + enabled_tools: Vec::new(), + source: source.to_string(), + } +} + +#[tokio::test] +#[expect( + clippy::await_holding_invalid_type, + reason = "test holds the registry lock to force admission ahead of shutdown" +)] +async fn shutdown_rejects_cell_admission_queued_before_the_registry_lock() { + let runtime = Arc::new(SessionRuntime::new(Arc::new(RecordingDelegate))); + let cells = runtime.inner.cells.lock().await; + + let execution = runtime.execute( + execute_request("while (true) {}"), + ObserveMode::YieldAfter(Duration::from_millis(/*millis*/ 1)), + ); + tokio::pin!(execution); + std::future::poll_fn(|context| match execution.as_mut().poll(context) { + Poll::Pending => Poll::Ready(()), + Poll::Ready(Ok(_)) => panic!("execution completed before the registry lock was released"), + Poll::Ready(Err(error)) => { + panic!("execution failed before the registry lock was released: {error}") + } + }) + .await; + + let shutdown = runtime.shutdown(); + tokio::pin!(shutdown); + std::future::poll_fn(|context| match shutdown.as_mut().poll(context) { + Poll::Pending => Poll::Ready(()), + Poll::Ready(Ok(())) => panic!("shutdown completed before acquiring the registry lock"), + Poll::Ready(Err(error)) => { + panic!("shutdown failed before acquiring the registry lock: {error}") + } + }) + .await; + + assert!(!runtime.is_alive()); + drop(cells); + assert!(matches!(execution.await, Err(Error::ShuttingDown))); + assert_eq!(shutdown.await, Ok(())); +} + +#[tokio::test] +async fn drop_terminates_cells_when_the_registry_is_locked() { + let runtime = SessionRuntime::new(Arc::new(RecordingDelegate)); + let started = runtime + .execute( + execute_request("while (true) {}"), + ObserveMode::YieldAfter(Duration::from_millis(/*millis*/ 1)), + ) + .await + .unwrap(); + assert_eq!(started.cell_id, CellId::new("1")); + assert_eq!( + started.initial_event().await, + Ok(CellEvent::Yielded { + content_items: Vec::new(), + }) + ); + + let inner = Arc::clone(&runtime.inner); + let cells = inner.cells.lock().await; + drop(runtime); + drop(cells); + + tokio::time::timeout(Duration::from_secs(/*secs*/ 1), inner.cell_tasks.wait()) + .await + .unwrap(); + assert!(inner.cell_tasks.is_empty()); +}