diff --git a/codex-rs/code-mode/src/cell_actor/mod.rs b/codex-rs/code-mode/src/cell_actor/mod.rs index 6e0ffb9e0..d288bd2db 100644 --- a/codex-rs/code-mode/src/cell_actor/mod.rs +++ b/codex-rs/code-mode/src/cell_actor/mod.rs @@ -25,7 +25,11 @@ pub(crate) use self::types::CellError; pub(crate) use self::types::CellEventFuture; pub(crate) use self::types::CellHandle; pub(crate) use self::types::CellHost; +pub(crate) use self::types::CellState; pub(crate) use self::types::CellToolCall; +pub(crate) use self::types::CompletionCommit; +use self::types::CompletionDelivery; +use self::types::ObservationDelivery; use crate::runtime::PendingRuntimeMode; use crate::runtime::RuntimeCommand; use crate::runtime::RuntimeControlCommand; @@ -34,6 +38,7 @@ use crate::runtime::spawn_runtime; use crate::session_runtime::CellEvent; use crate::session_runtime::CreateCellRequest as CellRequest; use crate::session_runtime::ObserveMode; +use crate::session_runtime::OutputItem; use crate::session_runtime::ToolName as CellToolName; pub(crate) struct CellActor; @@ -44,6 +49,7 @@ impl CellActor { stored_values: HashMap, host: Arc, initial_observe_mode: ObserveMode, + cell_state: Arc, ) -> Result< ( CellHandle, @@ -61,15 +67,14 @@ impl CellActor { event_tx, PendingRuntimeMode::PauseUntilResumed, )?; - let cancellation_token = CancellationToken::new(); - let handle = CellHandle::new(command_tx, cancellation_token.clone()); + let handle = CellHandle::new(command_tx, Arc::clone(&cell_state)); let task = run_cell( host, CellContext { runtime_tx, runtime_control_tx, runtime_terminate_handle, - cancellation_token, + cell_state, }, event_rx, command_rx, @@ -88,7 +93,7 @@ struct CellContext { runtime_tx: std::sync::mpsc::Sender, runtime_control_tx: std::sync::mpsc::Sender, runtime_terminate_handle: v8::IsolateHandle, - cancellation_token: CancellationToken, + cell_state: Arc, } struct Observer { @@ -96,115 +101,98 @@ struct Observer { response_tx: oneshot::Sender>, } -struct Termination { - response_tx: Option>>, -} - async fn run_cell( host: Arc, context: CellContext, mut event_rx: mpsc::UnboundedReceiver, - mut command_rx: mpsc::UnboundedReceiver, + command_rx: mpsc::UnboundedReceiver, initial_observer: Observer, ) { let CellContext { runtime_tx, runtime_control_tx, runtime_terminate_handle, - cancellation_token, + cell_state, } = context; + let cancellation_token = cell_state.cancellation_token(); + let callback_cancellation_token = cancellation_token.child_token(); let mut content_items = Vec::new(); let mut pending_tool_call_ids = Vec::new(); - let mut completed_event = None; let mut observer = Some(initial_observer); - let mut termination: Option = None; + let mut termination = false; let mut runtime_closed = false; let mut runtime_paused = false; let mut yield_timer: Option>> = None; let mut notification_tasks = JoinSet::new(); let mut tool_tasks = JoinSet::new(); - + let mut command_rx = Some(command_rx); loop { let yield_deadline_elapsed = yield_timer .as_ref() .is_some_and(|yield_timer| yield_timer.deadline() <= tokio::time::Instant::now()); tokio::select! { biased; - maybe_command = command_rx.recv() => { - let Some(command) = maybe_command else { - if completed_event.is_some() { - break; - } - termination = Some(Termination { response_tx: None }); - begin_termination( - &runtime_tx, - &runtime_control_tx, - &runtime_terminate_handle, - &cancellation_token, + _ = cancellation_token.cancelled(), if !termination => { + termination = true; + yield_timer = None; + drop(command_rx.take()); + begin_termination( + &runtime_tx, + &runtime_control_tx, + &runtime_terminate_handle, + &cancellation_token, + ); + if runtime_closed { + finish_callbacks( + &callback_cancellation_token, + &mut notification_tasks, + &mut tool_tasks, + CallbackCompletion::Cancel, + ).await; + finish_termination( + &cell_state, + observer.take().map(|observer| observer.response_tx), + CellEvent::Terminated { + content_items: std::mem::take(&mut content_items), + }, ); - if runtime_closed { - break; - } + break; + } + } + maybe_command = async { + match command_rx.as_mut() { + Some(command_rx) => command_rx.recv().await, + None => std::future::pending::>().await, + } + } => { + let Some(CellCommand::Observe { mode, response_tx }) = maybe_command else { + cancellation_token.cancel(); continue; }; - match command { - CellCommand::Observe { mode, response_tx } => { - if let Some(event) = completed_event.take() { - let _ = response_tx.send(Ok(event)); - break; - } - if observer.is_some() || termination.is_some() { - let _ = response_tx.send(Err(CellError::Busy)); - continue; - } - observer = Some(Observer { mode, response_tx }); - yield_timer = observer.as_ref().and_then(observer_timer); - resume_for_observation( - mode, - &mut runtime_paused, - &runtime_tx, - &runtime_control_tx, - ); - } - CellCommand::Terminate { response_tx } => { - if let Some(event) = completed_event.take() { - if let Some(response_tx) = response_tx { - let _ = response_tx.send(Ok(event)); - } - break; - } - if termination.is_some() { - if let Some(response_tx) = response_tx { - let _ = response_tx.send(Err(CellError::AlreadyTerminating)); - } - continue; - } - termination = Some(Termination { response_tx }); - yield_timer = None; - begin_termination( - &runtime_tx, - &runtime_control_tx, - &runtime_terminate_handle, - &cancellation_token, - ); - if runtime_closed { - finish_callbacks( - &cancellation_token, - &mut notification_tasks, - &mut tool_tasks, - CallbackCompletion::Cancel, - ).await; - send_termination_events( - observer.take(), - termination.take(), - CellEvent::Terminated { - content_items: std::mem::take(&mut content_items), - }, - ); - break; - } - } + let response_tx = match cell_state.route_observation(response_tx) { + ObservationDelivery::Running(response_tx) => response_tx, + ObservationDelivery::Delivered => break, + ObservationDelivery::Buffered | ObservationDelivery::Closed => continue, + }; + if observer + .as_ref() + .is_some_and(|observer| observer.response_tx.is_closed()) + { + observer = None; + yield_timer = None; } + if observer.is_some() || termination { + let _ = response_tx.send(Err(CellError::Busy)); + continue; + } + observer = Some(Observer { mode, response_tx }); + yield_timer = observer.as_ref().and_then(observer_timer); + resume_for_observation( + mode, + &mut runtime_paused, + &runtime_tx, + &runtime_control_tx, + ); } _ = async { if let Some(yield_timer) = yield_timer.as_mut() { @@ -230,38 +218,57 @@ async fn run_cell( }, if !yield_deadline_elapsed => { let Some(event) = maybe_event else { runtime_closed = true; - if termination.is_some() { + if termination || cancellation_token.is_cancelled() { finish_callbacks( - &cancellation_token, + &callback_cancellation_token, &mut notification_tasks, &mut tool_tasks, CallbackCompletion::Cancel, ).await; - send_termination_events( - observer.take(), - termination.take(), + finish_termination( + &cell_state, + observer.take().map(|observer| observer.response_tx), CellEvent::Terminated { content_items: std::mem::take(&mut content_items), }, ); break; } - if completed_event.is_none() { - finish_callbacks( - &cancellation_token, - &mut notification_tasks, - &mut tool_tasks, - CallbackCompletion::DrainNotifications, - ).await; - let event = CellEvent::Completed { - content_items: std::mem::take(&mut content_items), - error_text: Some("exec runtime ended unexpectedly".to_string()), - }; - if send_or_buffer_completion( + finish_callbacks( + &callback_cancellation_token, + &mut notification_tasks, + &mut tool_tasks, + CallbackCompletion::DrainNotifications, + ) + .await; + let event = CellEvent::Completed { + content_items: std::mem::take(&mut content_items), + error_text: Some("exec runtime ended unexpectedly".to_string()), + }; + let rejected_event = match host + .commit_completion( + HashMap::new(), event, - &mut observer, - &mut completed_event, - ) { + Arc::clone(&cell_state), + ) + .await + { + CompletionCommit::Committed => None, + CompletionCommit::Rejected(event) => Some(event), + }; + match cell_state.deliver_completion( + observer.take().map(|observer| observer.response_tx), + ) { + CompletionDelivery::Delivered => break, + CompletionDelivery::Buffered => {} + CompletionDelivery::Rejected(response_tx) => { + finish_termination( + &cell_state, + response_tx, + CellEvent::Terminated { + content_items: rejected_completion_content(rejected_event), + }, + ); break; } } @@ -314,7 +321,7 @@ async fn run_cell( Arc::clone(&host), call_id, text, - cancellation_token.child_token(), + callback_cancellation_token.child_token(), ); } RuntimeEvent::ToolCall { id, name, kind, input } => { @@ -332,22 +339,22 @@ async fn run_cell( input, }, runtime_tx.clone(), - cancellation_token.child_token(), + callback_cancellation_token.child_token(), ); } RuntimeEvent::Result { stored_value_writes, error_text } => { runtime_closed = true; yield_timer = None; - if termination.is_some() { + if termination || cancellation_token.is_cancelled() { finish_callbacks( - &cancellation_token, + &callback_cancellation_token, &mut notification_tasks, &mut tool_tasks, CallbackCompletion::Cancel, ).await; - send_termination_events( - observer.take(), - termination.take(), + finish_termination( + &cell_state, + observer.take().map(|observer| observer.response_tx), CellEvent::Terminated { content_items: std::mem::take(&mut content_items), }, @@ -355,22 +362,42 @@ async fn run_cell( break; } finish_callbacks( - &cancellation_token, + &callback_cancellation_token, &mut notification_tasks, &mut tool_tasks, CallbackCompletion::DrainNotifications, - ).await; - host.commit_stored_values(stored_value_writes).await; + ) + .await; let event = CellEvent::Completed { content_items: std::mem::take(&mut content_items), error_text, }; - if send_or_buffer_completion( - event, - &mut observer, - &mut completed_event, + let rejected_event = match host + .commit_completion( + stored_value_writes, + event, + Arc::clone(&cell_state), + ) + .await + { + CompletionCommit::Committed => None, + CompletionCommit::Rejected(event) => Some(event), + }; + match cell_state.deliver_completion( + observer.take().map(|observer| observer.response_tx), ) { - break; + CompletionDelivery::Delivered => break, + CompletionDelivery::Buffered => {} + CompletionDelivery::Rejected(response_tx) => { + finish_termination( + &cell_state, + response_tx, + CellEvent::Terminated { + content_items: rejected_completion_content(rejected_event), + }, + ); + break; + } } } } @@ -383,6 +410,9 @@ async fn run_cell( } } } + // Reject requests that arrive while asynchronous terminal cleanup runs. + cell_state.tombstone(); + drop(command_rx.take()); begin_termination( &runtime_tx, &runtime_control_tx, @@ -390,7 +420,7 @@ async fn run_cell( &cancellation_token, ); finish_callbacks( - &cancellation_token, + &callback_cancellation_token, &mut notification_tasks, &mut tool_tasks, CallbackCompletion::Cancel, @@ -399,34 +429,29 @@ async fn run_cell( host.closed().await; } -fn send_or_buffer_completion( - event: CellEvent, - observer: &mut Option, - completed_event: &mut Option, -) -> bool { - if observer.is_some() { - send_observer_event(observer.take(), event); - true - } else { - *completed_event = Some(event); - false - } -} - fn send_observer_event(observer: Option, event: CellEvent) { if let Some(observer) = observer { let _ = observer.response_tx.send(Ok(event)); } } -fn send_termination_events( - observer: Option, - termination: Option, +fn rejected_completion_content(event: Option) -> Vec { + match event { + Some(CellEvent::Completed { content_items, .. }) => content_items, + None => Vec::new(), + Some(event) => panic!("completion commit rejected an unexpected event: {event:?}"), + } +} + +fn finish_termination( + cell_state: &CellState, + observer_tx: Option>>, event: CellEvent, ) { - send_observer_event(observer, event.clone()); - if let Some(response_tx) = termination.and_then(|termination| termination.response_tx) { - let _ = response_tx.send(Ok(event)); + if let Some(event) = cell_state.finish_termination(event) + && let Some(observer_tx) = observer_tx + { + let _ = observer_tx.send(Ok(event)); } } diff --git a/codex-rs/code-mode/src/cell_actor/tests.rs b/codex-rs/code-mode/src/cell_actor/tests.rs index d7b72a103..7404b9fe1 100644 --- a/codex-rs/code-mode/src/cell_actor/tests.rs +++ b/codex-rs/code-mode/src/cell_actor/tests.rs @@ -33,7 +33,14 @@ impl CellHost for TestHost { Ok(()) } - async fn commit_stored_values(&self, _stored_value_writes: HashMap) {} + async fn commit_completion( + &self, + _stored_value_writes: HashMap, + event: CellEvent, + cell_state: Arc, + ) -> CompletionCommit { + cell_state.commit_completion(event, || {}) + } async fn closed(&self) {} } @@ -64,14 +71,15 @@ fn spawn_cell_actor_harness(initial_observe_mode: ObserveMode) -> CellActorHarne PendingRuntimeMode::PauseUntilResumed, ) .unwrap(); - let handle = CellHandle::new(command_tx, CancellationToken::new()); + let cell_state = Arc::new(CellState::new(CancellationToken::new())); + let handle = CellHandle::new(command_tx, Arc::clone(&cell_state)); let task = tokio::spawn(run_cell( Arc::new(TestHost), CellContext { runtime_tx, runtime_control_tx, runtime_terminate_handle, - cancellation_token: CancellationToken::new(), + cell_state, }, event_rx, command_rx, @@ -142,3 +150,86 @@ async fn queued_termination_preempts_unobserved_runtime_completion() { assert_eq!(harness.initial_event_rx.await.unwrap(), terminated); harness.task.await.unwrap(); } + +#[tokio::test] +async fn only_the_first_termination_claims_a_buffered_completion() { + let cell_state = CellState::new(CancellationToken::new()); + let completion = CellEvent::Completed { + content_items: Vec::new(), + error_text: None, + }; + assert_eq!( + cell_state.commit_completion(completion.clone(), || {}), + CompletionCommit::Committed + ); + assert!(matches!( + cell_state.deliver_completion(/*response_tx*/ None), + CompletionDelivery::Buffered + )); + + let first_termination = cell_state.request_termination(); + assert_eq!( + cell_state.request_termination().await, + Err(CellError::AlreadyTerminating) + ); + assert_eq!(first_termination.await, Ok(completion.clone())); + assert_eq!( + cell_state.finish_termination(CellEvent::Terminated { + content_items: Vec::new(), + }), + Some(completion) + ); +} + +#[tokio::test] +async fn termination_claim_prevents_stored_value_commit() { + let cell_state = CellState::new(CancellationToken::new()); + let termination = cell_state.request_termination(); + let mut commit_ran = false; + let completion = CellEvent::Completed { + content_items: Vec::new(), + error_text: None, + }; + + assert_eq!( + cell_state.commit_completion(completion.clone(), || commit_ran = true), + CompletionCommit::Rejected(completion) + ); + assert!(!commit_ran); + + let terminated = CellEvent::Terminated { + content_items: Vec::new(), + }; + assert_eq!( + cell_state.finish_termination(terminated.clone()), + Some(terminated.clone()) + ); + assert_eq!(termination.await, Ok(terminated)); +} + +#[test] +fn failed_completion_delivery_rebuffers_the_event() { + let cell_state = CellState::new(CancellationToken::new()); + let event = CellEvent::Completed { + content_items: Vec::new(), + error_text: None, + }; + assert_eq!( + cell_state.commit_completion(event.clone(), || {}), + CompletionCommit::Committed + ); + let (response_tx, response_rx) = oneshot::channel(); + drop(response_rx); + assert!(matches!( + cell_state.deliver_completion(Some(response_tx)), + CompletionDelivery::Buffered + )); + assert!(cell_state.accepting_observations()); + + let (response_tx, mut response_rx) = oneshot::channel(); + assert!(matches!( + cell_state.route_observation(response_tx), + ObservationDelivery::Delivered + )); + assert_eq!(response_rx.try_recv(), Ok(Ok(event))); +} diff --git a/codex-rs/code-mode/src/cell_actor/types.rs b/codex-rs/code-mode/src/cell_actor/types.rs index d82f898e7..020ae3dcb 100644 --- a/codex-rs/code-mode/src/cell_actor/types.rs +++ b/codex-rs/code-mode/src/cell_actor/types.rs @@ -2,8 +2,7 @@ use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering; +use std::sync::Mutex; use serde_json::Value as JsonValue; use tokio::sync::mpsc; @@ -32,10 +31,11 @@ pub(crate) struct CellToolCall { pub(crate) input: Option, } -/// Connects a cell actor to session-owned callbacks and lifecycle state. +/// Connects a cell actor to session-owned callbacks and stored values. /// -/// Implementations must honor callback cancellation and must not return from -/// `closed` until the session can no longer route requests to the cell. +/// Implementations should forward callback cancellation to downstream work. +/// Implementations must not return from `closed` until the session can no longer +/// route requests to the cell. pub(crate) trait CellHost: Send + Sync + 'static { fn invoke_tool( &self, @@ -50,10 +50,12 @@ pub(crate) trait CellHost: Send + Sync + 'static { cancellation_token: CancellationToken, ) -> impl Future> + Send; - fn commit_stored_values( + fn commit_completion( &self, stored_value_writes: HashMap, - ) -> impl Future + Send; + event: CellEvent, + cell_state: Arc, + ) -> impl Future + Send; fn closed(&self) -> impl Future + Send; } @@ -61,23 +63,21 @@ pub(crate) trait CellHost: Send + Sync + 'static { #[derive(Clone)] pub(crate) struct CellHandle { command_tx: mpsc::UnboundedSender, - cancellation_token: CancellationToken, - termination_requested: Arc, + state: Arc, } impl CellHandle { pub(super) fn new( command_tx: mpsc::UnboundedSender, - cancellation_token: CancellationToken, + state: Arc, ) -> Self { - Self { - command_tx, - cancellation_token, - termination_requested: Arc::new(AtomicBool::new(false)), - } + Self { command_tx, state } } pub(crate) fn observe(&self, mode: ObserveMode) -> CellEventFuture { + if !self.state.accepting_observations() { + return closed_event(); + } let (response_tx, response_rx) = oneshot::channel(); if self .command_tx @@ -90,33 +90,232 @@ impl CellHandle { } pub(crate) fn terminate(&self) -> CellEventFuture { - if self - .termination_requested - .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) - .is_err() - { - return Box::pin(async { Err(CellError::AlreadyTerminating) }); - } - let (response_tx, response_rx) = oneshot::channel(); - if self - .command_tx - .send(CellCommand::Terminate { - response_tx: Some(response_tx), - }) - .is_err() - { - self.termination_requested.store(false, Ordering::Relaxed); - return closed_event(); - } - response_event(response_rx) + self.state.request_termination() } pub(crate) fn shutdown(&self) { - self.termination_requested.store(true, Ordering::Relaxed); + self.state.cancellation_token().cancel(); + } +} + +/// The single linearization point for a cell's terminal outcome. +/// +/// Callback cancellation tokens are children of the cell token, so a terminal +/// decision cancels runtime work and its callbacks together. +/// +/// The mutex is held only for synchronous phase transitions and terminal +/// delivery. Runtime execution, observation waits, and callbacks never run +/// while it is held. +pub(crate) struct CellState { + phase: Mutex, + cancellation_token: CancellationToken, +} + +enum CellPhase { + Running, + Terminating { + response_tx: oneshot::Sender>, + }, + Completed(CellEvent), + CompletionClaimed(CellEvent), + Tombstone, +} + +pub(crate) enum CompletionDelivery { + Delivered, + Buffered, + Rejected(Option>>), +} + +/// Result of atomically publishing a completed cell and its session side effects. +#[derive(Debug, PartialEq)] +pub(crate) enum CompletionCommit { + Committed, + Rejected(CellEvent), +} + +pub(crate) enum ObservationDelivery { + Running(oneshot::Sender>), + Delivered, + Buffered, + Closed, +} + +impl CellState { + pub(crate) fn new(cancellation_token: CancellationToken) -> Self { + Self { + phase: Mutex::new(CellPhase::Running), + cancellation_token, + } + } + + pub(crate) fn accepting_observations(&self) -> bool { + let accepting_phase = matches!( + *self + .phase + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner), + CellPhase::Running | CellPhase::Completed(_) + ); + accepting_phase && !self.cancellation_token.is_cancelled() + } + + pub(crate) fn request_termination(&self) -> CellEventFuture { + let mut phase = self + .phase + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + match std::mem::replace(&mut *phase, CellPhase::Tombstone) { + CellPhase::Running => { + let (response_tx, response_rx) = oneshot::channel(); + *phase = CellPhase::Terminating { response_tx }; + self.cancellation_token.cancel(); + response_event(response_rx) + } + CellPhase::Terminating { response_tx } => { + *phase = CellPhase::Terminating { response_tx }; + Box::pin(async { Err(CellError::AlreadyTerminating) }) + } + CellPhase::Completed(event) => { + *phase = CellPhase::CompletionClaimed(event.clone()); + self.cancellation_token.cancel(); + ready_event(event) + } + CellPhase::CompletionClaimed(event) => { + *phase = CellPhase::CompletionClaimed(event); + Box::pin(async { Err(CellError::AlreadyTerminating) }) + } + CellPhase::Tombstone => closed_event(), + } + } + + pub(crate) fn commit_completion( + &self, + event: CellEvent, + commit: impl FnOnce(), + ) -> CompletionCommit { + let mut phase = self + .phase + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if !matches!(*phase, CellPhase::Running) || self.cancellation_token.is_cancelled() { + return CompletionCommit::Rejected(event); + } + commit(); + *phase = CellPhase::Completed(event); + CompletionCommit::Committed + } + + pub(crate) fn deliver_completion( + &self, + response_tx: Option>>, + ) -> CompletionDelivery { + let mut phase = self + .phase + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let event = match std::mem::replace(&mut *phase, CellPhase::Tombstone) { + CellPhase::Completed(event) => event, + previous => { + *phase = previous; + return CompletionDelivery::Rejected(response_tx); + } + }; + let Some(response_tx) = response_tx else { + *phase = CellPhase::Completed(event); + return CompletionDelivery::Buffered; + }; + match response_tx.send(Ok(event)) { + Ok(()) => { + self.cancellation_token.cancel(); + CompletionDelivery::Delivered + } + Err(Ok(event)) => { + *phase = CellPhase::Completed(event); + CompletionDelivery::Buffered + } + Err(Err(error)) => { + panic!("completion delivery unexpectedly carried an actor error: {error:?}") + } + } + } + + pub(crate) fn route_observation( + &self, + response_tx: oneshot::Sender>, + ) -> ObservationDelivery { + let mut phase = self + .phase + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + match std::mem::replace(&mut *phase, CellPhase::Tombstone) { + CellPhase::Running => { + *phase = CellPhase::Running; + ObservationDelivery::Running(response_tx) + } + CellPhase::Completed(event) => match response_tx.send(Ok(event)) { + Ok(()) => { + self.cancellation_token.cancel(); + ObservationDelivery::Delivered + } + Err(Ok(event)) => { + *phase = CellPhase::Completed(event); + ObservationDelivery::Buffered + } + Err(Err(error)) => { + panic!("completion delivery unexpectedly carried an actor error: {error:?}") + } + }, + CellPhase::Terminating { + response_tx: termination_tx, + } => { + *phase = CellPhase::Terminating { + response_tx: termination_tx, + }; + let _ = response_tx.send(Err(CellError::Closed)); + ObservationDelivery::Closed + } + CellPhase::CompletionClaimed(event) => { + *phase = CellPhase::CompletionClaimed(event); + let _ = response_tx.send(Err(CellError::Closed)); + ObservationDelivery::Closed + } + CellPhase::Tombstone => { + let _ = response_tx.send(Err(CellError::Closed)); + ObservationDelivery::Closed + } + } + } + + pub(crate) fn finish_termination(&self, event: CellEvent) -> Option { + let mut phase = self + .phase + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let observer_event = match std::mem::replace(&mut *phase, CellPhase::Tombstone) { + CellPhase::Running => Some(event), + CellPhase::Terminating { response_tx } => { + let _ = response_tx.send(Ok(event.clone())); + Some(event) + } + CellPhase::Completed(completed_event) => Some(completed_event), + CellPhase::CompletionClaimed(completed_event) => Some(completed_event), + CellPhase::Tombstone => None, + }; self.cancellation_token.cancel(); - let _ = self - .command_tx - .send(CellCommand::Terminate { response_tx: None }); + observer_event + } + + pub(crate) fn tombstone(&self) { + *self + .phase + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = CellPhase::Tombstone; + self.cancellation_token.cancel(); + } + + pub(crate) fn cancellation_token(&self) -> CancellationToken { + self.cancellation_token.clone() } } @@ -125,15 +324,16 @@ pub(super) enum CellCommand { mode: ObserveMode, response_tx: oneshot::Sender>, }, - Terminate { - response_tx: Option>>, - }, } fn response_event(response_rx: oneshot::Receiver>) -> CellEventFuture { Box::pin(async move { response_rx.await.unwrap_or(Err(CellError::Closed)) }) } +fn ready_event(event: CellEvent) -> CellEventFuture { + Box::pin(async move { Ok(event) }) +} + fn closed_event() -> CellEventFuture { Box::pin(async { Err(CellError::Closed) }) } diff --git a/codex-rs/code-mode/src/session_runtime/mod.rs b/codex-rs/code-mode/src/session_runtime/mod.rs index f72aadb10..362dccda7 100644 --- a/codex-rs/code-mode/src/session_runtime/mod.rs +++ b/codex-rs/code-mode/src/session_runtime/mod.rs @@ -29,7 +29,9 @@ use crate::cell_actor::CellError; use crate::cell_actor::CellEventFuture; use crate::cell_actor::CellHandle; use crate::cell_actor::CellHost; +use crate::cell_actor::CellState; use crate::cell_actor::CellToolCall; +use crate::cell_actor::CompletionCommit; type RuntimeEventFuture = Pin> + Send + 'static>>; @@ -165,9 +167,15 @@ impl SessionRuntime { if cells.contains_key(&cell_id) { return Err(Error::DuplicateCell(cell_id)); } - let (handle, initial_event, task) = - CellActor::prepare(request, stored_values, host, initial_observe_mode) - .map_err(Error::Runtime)?; + let cell_state = Arc::new(CellState::new(CancellationToken::new())); + let (handle, initial_event, task) = CellActor::prepare( + request, + stored_values, + host, + initial_observe_mode, + cell_state, + ) + .map_err(Error::Runtime)?; cells.insert(cell_id.clone(), handle); drop(cells); tokio::spawn(task); @@ -251,12 +259,21 @@ impl CellHost for RuntimeCellHost { .await } - async fn commit_stored_values(&self, stored_value_writes: HashMap) { - self.inner - .stored_values - .lock() - .await - .extend(stored_value_writes); + async fn commit_completion( + &self, + stored_value_writes: HashMap, + event: CellEvent, + cell_state: Arc, + ) -> CompletionCommit { + let cancellation_token = cell_state.cancellation_token(); + let mut stored_values = tokio::select! { + biased; + _ = cancellation_token.cancelled() => { + return CompletionCommit::Rejected(event); + } + stored_values = self.inner.stored_values.lock() => stored_values, + }; + cell_state.commit_completion(event, || stored_values.extend(stored_value_writes)) } async fn closed(&self) { @@ -276,3 +293,7 @@ fn actor_error(cell_id: &CellId, error: CellError) -> Error { CellError::Closed => Error::ClosedCell(cell_id.clone()), } } + +#[cfg(test)] +#[path = "tests.rs"] +mod tests; diff --git a/codex-rs/code-mode/src/session_runtime/tests.rs b/codex-rs/code-mode/src/session_runtime/tests.rs new file mode 100644 index 000000000..b4457b3f7 --- /dev/null +++ b/codex-rs/code-mode/src/session_runtime/tests.rs @@ -0,0 +1,110 @@ +use std::collections::HashMap; +use std::future::Future; +use std::sync::Arc; +use std::task::Context; +use std::task::Poll; +use std::task::Waker; +use std::time::Duration; + +use pretty_assertions::assert_eq; +use serde_json::Value as JsonValue; +use tokio_util::sync::CancellationToken; + +use super::*; +use crate::cell_actor::CompletionCommit; + +struct RecordingDelegate; + +impl SessionRuntimeDelegate for RecordingDelegate { + async fn invoke_tool( + &self, + _invocation: NestedToolCall, + _cancellation_token: CancellationToken, + ) -> Result { + Ok(JsonValue::Null) + } + + async fn notify( + &self, + _call_id: String, + _cell_id: CellId, + _text: String, + _cancellation_token: CancellationToken, + ) -> Result<(), String> { + Ok(()) + } + + fn cell_closed(&self, _cell_id: &CellId) {} +} + +#[tokio::test] +async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_load_it() { + let runtime = SessionRuntime::new(Arc::new(RecordingDelegate)); + let cell_state = Arc::new(CellState::new(CancellationToken::new())); + let host = RuntimeCellHost { + cell_id: CellId::new("terminating-writer"), + inner: Arc::clone(&runtime.inner), + }; + let completion = CellEvent::Completed { + content_items: vec![OutputItem::Text { + text: "uncommitted output".to_string(), + }], + error_text: None, + }; + + let stored_values = runtime.inner.stored_values.lock().await; + let commit = host.commit_completion( + HashMap::from([( + "candidate".to_string(), + JsonValue::String("lost".to_string()), + )]), + completion.clone(), + Arc::clone(&cell_state), + ); + tokio::pin!(commit); + let waker = Waker::noop(); + let mut context = Context::from_waker(waker); + assert!(matches!(commit.as_mut().poll(&mut context), Poll::Pending)); + + let termination = cell_state.request_termination(); + drop(stored_values); + assert_eq!(commit.await, CompletionCommit::Rejected(completion)); + let terminated = CellEvent::Terminated { + content_items: Vec::new(), + }; + assert_eq!( + cell_state.finish_termination(terminated.clone()), + Some(terminated.clone()) + ); + assert_eq!(termination.await, Ok(terminated)); + assert!( + !runtime + .inner + .stored_values + .lock() + .await + .contains_key("candidate") + ); + + let reader = runtime + .execute( + CreateCellRequest { + tool_call_id: "reader".to_string(), + enabled_tools: Vec::new(), + source: r#"text(String(load("candidate")));"#.to_string(), + }, + ObserveMode::YieldAfter(Duration::from_secs(1)), + ) + .await + .unwrap(); + assert_eq!( + reader.initial_event().await, + Ok(CellEvent::Completed { + content_items: vec![OutputItem::Text { + text: "undefined".to_string(), + }], + error_text: None, + }) + ); + runtime.shutdown().await.unwrap(); +}