code-mode: linearize cell terminal state (#29286)

## Summary

- Introduce a single cell terminal-state machine for completion and
termination.
- Make stored-value commits atomic with the winning terminal outcome.
- Buffer terminal results for later observation and cover
termination-before-commit behavior.

## Why

Completion, termination, observation, and stored-value updates must
agree on one linearized outcome under cancellation races.

## Impact

Terminal delivery becomes deterministic, and a terminated cell can no longer
commit stored values once termination has won the race.

## Validation

- Focused terminal-state regression passed.
- Stack-tip validation: `just test -p codex-code-mode -p
codex-code-mode-protocol` (70 passed).
- Parent branch:
`cconger/code-mode-runtime-compact-03b-session-runtime`.
This commit is contained in:
Channing Conger
2026-06-21 12:05:24 -07:00
committed by GitHub
Unverified
parent 63f009e9da
commit f774455c3a
5 changed files with 640 additions and 193 deletions
+165 -140
View File
@@ -25,7 +25,11 @@ pub(crate) use self::types::CellError;
pub(crate) use self::types::CellEventFuture;
pub(crate) use self::types::CellHandle;
pub(crate) use self::types::CellHost;
pub(crate) use self::types::CellState;
pub(crate) use self::types::CellToolCall;
pub(crate) use self::types::CompletionCommit;
use self::types::CompletionDelivery;
use self::types::ObservationDelivery;
use crate::runtime::PendingRuntimeMode;
use crate::runtime::RuntimeCommand;
use crate::runtime::RuntimeControlCommand;
@@ -34,6 +38,7 @@ use crate::runtime::spawn_runtime;
use crate::session_runtime::CellEvent;
use crate::session_runtime::CreateCellRequest as CellRequest;
use crate::session_runtime::ObserveMode;
use crate::session_runtime::OutputItem;
use crate::session_runtime::ToolName as CellToolName;
pub(crate) struct CellActor;
@@ -44,6 +49,7 @@ impl CellActor {
stored_values: HashMap<String, JsonValue>,
host: Arc<H>,
initial_observe_mode: ObserveMode,
cell_state: Arc<CellState>,
) -> Result<
(
CellHandle,
@@ -61,15 +67,14 @@ impl CellActor {
event_tx,
PendingRuntimeMode::PauseUntilResumed,
)?;
let cancellation_token = CancellationToken::new();
let handle = CellHandle::new(command_tx, cancellation_token.clone());
let handle = CellHandle::new(command_tx, Arc::clone(&cell_state));
let task = run_cell(
host,
CellContext {
runtime_tx,
runtime_control_tx,
runtime_terminate_handle,
cancellation_token,
cell_state,
},
event_rx,
command_rx,
@@ -88,7 +93,7 @@ struct CellContext {
runtime_tx: std::sync::mpsc::Sender<RuntimeCommand>,
runtime_control_tx: std::sync::mpsc::Sender<RuntimeControlCommand>,
runtime_terminate_handle: v8::IsolateHandle,
cancellation_token: CancellationToken,
cell_state: Arc<CellState>,
}
struct Observer {
@@ -96,115 +101,98 @@ struct Observer {
response_tx: oneshot::Sender<Result<CellEvent, CellError>>,
}
/// A termination that is in progress for the cell.
///
/// `response_tx` is the reply channel of the caller that asked for the
/// termination; it is `None` when nobody is waiting for the answer (for
/// example when the command channel closed before the cell completed).
struct Termination {
response_tx: Option<oneshot::Sender<Result<CellEvent, CellError>>>,
}
async fn run_cell<H: CellHost>(
host: Arc<H>,
context: CellContext,
mut event_rx: mpsc::UnboundedReceiver<RuntimeEvent>,
mut command_rx: mpsc::UnboundedReceiver<CellCommand>,
command_rx: mpsc::UnboundedReceiver<CellCommand>,
initial_observer: Observer,
) {
let CellContext {
runtime_tx,
runtime_control_tx,
runtime_terminate_handle,
cancellation_token,
cell_state,
} = context;
let cancellation_token = cell_state.cancellation_token();
let callback_cancellation_token = cancellation_token.child_token();
let mut content_items = Vec::new();
let mut pending_tool_call_ids = Vec::new();
let mut completed_event = None;
let mut observer = Some(initial_observer);
let mut termination: Option<Termination> = None;
let mut termination = false;
let mut runtime_closed = false;
let mut runtime_paused = false;
let mut yield_timer: Option<std::pin::Pin<Box<tokio::time::Sleep>>> = None;
let mut notification_tasks = JoinSet::new();
let mut tool_tasks = JoinSet::new();
let mut command_rx = Some(command_rx);
loop {
let yield_deadline_elapsed = yield_timer
.as_ref()
.is_some_and(|yield_timer| yield_timer.deadline() <= tokio::time::Instant::now());
tokio::select! {
biased;
maybe_command = command_rx.recv() => {
let Some(command) = maybe_command else {
if completed_event.is_some() {
break;
}
termination = Some(Termination { response_tx: None });
begin_termination(
&runtime_tx,
&runtime_control_tx,
&runtime_terminate_handle,
&cancellation_token,
_ = cancellation_token.cancelled(), if !termination => {
termination = true;
yield_timer = None;
drop(command_rx.take());
begin_termination(
&runtime_tx,
&runtime_control_tx,
&runtime_terminate_handle,
&cancellation_token,
);
if runtime_closed {
finish_callbacks(
&callback_cancellation_token,
&mut notification_tasks,
&mut tool_tasks,
CallbackCompletion::Cancel,
).await;
finish_termination(
&cell_state,
observer.take().map(|observer| observer.response_tx),
CellEvent::Terminated {
content_items: std::mem::take(&mut content_items),
},
);
if runtime_closed {
break;
}
break;
}
}
maybe_command = async {
match command_rx.as_mut() {
Some(command_rx) => command_rx.recv().await,
None => std::future::pending::<Option<CellCommand>>().await,
}
} => {
let Some(CellCommand::Observe { mode, response_tx }) = maybe_command else {
cancellation_token.cancel();
continue;
};
match command {
CellCommand::Observe { mode, response_tx } => {
if let Some(event) = completed_event.take() {
let _ = response_tx.send(Ok(event));
break;
}
if observer.is_some() || termination.is_some() {
let _ = response_tx.send(Err(CellError::Busy));
continue;
}
observer = Some(Observer { mode, response_tx });
yield_timer = observer.as_ref().and_then(observer_timer);
resume_for_observation(
mode,
&mut runtime_paused,
&runtime_tx,
&runtime_control_tx,
);
}
CellCommand::Terminate { response_tx } => {
if let Some(event) = completed_event.take() {
if let Some(response_tx) = response_tx {
let _ = response_tx.send(Ok(event));
}
break;
}
if termination.is_some() {
if let Some(response_tx) = response_tx {
let _ = response_tx.send(Err(CellError::AlreadyTerminating));
}
continue;
}
termination = Some(Termination { response_tx });
yield_timer = None;
begin_termination(
&runtime_tx,
&runtime_control_tx,
&runtime_terminate_handle,
&cancellation_token,
);
if runtime_closed {
finish_callbacks(
&cancellation_token,
&mut notification_tasks,
&mut tool_tasks,
CallbackCompletion::Cancel,
).await;
send_termination_events(
observer.take(),
termination.take(),
CellEvent::Terminated {
content_items: std::mem::take(&mut content_items),
},
);
break;
}
}
let response_tx = match cell_state.route_observation(response_tx) {
ObservationDelivery::Running(response_tx) => response_tx,
ObservationDelivery::Delivered => break,
ObservationDelivery::Buffered | ObservationDelivery::Closed => continue,
};
if observer
.as_ref()
.is_some_and(|observer| observer.response_tx.is_closed())
{
observer = None;
yield_timer = None;
}
if observer.is_some() || termination {
let _ = response_tx.send(Err(CellError::Busy));
continue;
}
observer = Some(Observer { mode, response_tx });
yield_timer = observer.as_ref().and_then(observer_timer);
resume_for_observation(
mode,
&mut runtime_paused,
&runtime_tx,
&runtime_control_tx,
);
}
_ = async {
if let Some(yield_timer) = yield_timer.as_mut() {
@@ -230,38 +218,57 @@ async fn run_cell<H: CellHost>(
}, if !yield_deadline_elapsed => {
let Some(event) = maybe_event else {
runtime_closed = true;
if termination.is_some() {
if termination || cancellation_token.is_cancelled() {
finish_callbacks(
&cancellation_token,
&callback_cancellation_token,
&mut notification_tasks,
&mut tool_tasks,
CallbackCompletion::Cancel,
).await;
send_termination_events(
observer.take(),
termination.take(),
finish_termination(
&cell_state,
observer.take().map(|observer| observer.response_tx),
CellEvent::Terminated {
content_items: std::mem::take(&mut content_items),
},
);
break;
}
if completed_event.is_none() {
finish_callbacks(
&cancellation_token,
&mut notification_tasks,
&mut tool_tasks,
CallbackCompletion::DrainNotifications,
).await;
let event = CellEvent::Completed {
content_items: std::mem::take(&mut content_items),
error_text: Some("exec runtime ended unexpectedly".to_string()),
};
if send_or_buffer_completion(
finish_callbacks(
&callback_cancellation_token,
&mut notification_tasks,
&mut tool_tasks,
CallbackCompletion::DrainNotifications,
)
.await;
let event = CellEvent::Completed {
content_items: std::mem::take(&mut content_items),
error_text: Some("exec runtime ended unexpectedly".to_string()),
};
let rejected_event = match host
.commit_completion(
HashMap::new(),
event,
&mut observer,
&mut completed_event,
) {
Arc::clone(&cell_state),
)
.await
{
CompletionCommit::Committed => None,
CompletionCommit::Rejected(event) => Some(event),
};
match cell_state.deliver_completion(
observer.take().map(|observer| observer.response_tx),
) {
CompletionDelivery::Delivered => break,
CompletionDelivery::Buffered => {}
CompletionDelivery::Rejected(response_tx) => {
finish_termination(
&cell_state,
response_tx,
CellEvent::Terminated {
content_items: rejected_completion_content(rejected_event),
},
);
break;
}
}
@@ -314,7 +321,7 @@ async fn run_cell<H: CellHost>(
Arc::clone(&host),
call_id,
text,
cancellation_token.child_token(),
callback_cancellation_token.child_token(),
);
}
RuntimeEvent::ToolCall { id, name, kind, input } => {
@@ -332,22 +339,22 @@ async fn run_cell<H: CellHost>(
input,
},
runtime_tx.clone(),
cancellation_token.child_token(),
callback_cancellation_token.child_token(),
);
}
RuntimeEvent::Result { stored_value_writes, error_text } => {
runtime_closed = true;
yield_timer = None;
if termination.is_some() {
if termination || cancellation_token.is_cancelled() {
finish_callbacks(
&cancellation_token,
&callback_cancellation_token,
&mut notification_tasks,
&mut tool_tasks,
CallbackCompletion::Cancel,
).await;
send_termination_events(
observer.take(),
termination.take(),
finish_termination(
&cell_state,
observer.take().map(|observer| observer.response_tx),
CellEvent::Terminated {
content_items: std::mem::take(&mut content_items),
},
@@ -355,22 +362,42 @@ async fn run_cell<H: CellHost>(
break;
}
finish_callbacks(
&cancellation_token,
&callback_cancellation_token,
&mut notification_tasks,
&mut tool_tasks,
CallbackCompletion::DrainNotifications,
).await;
host.commit_stored_values(stored_value_writes).await;
)
.await;
let event = CellEvent::Completed {
content_items: std::mem::take(&mut content_items),
error_text,
};
if send_or_buffer_completion(
event,
&mut observer,
&mut completed_event,
let rejected_event = match host
.commit_completion(
stored_value_writes,
event,
Arc::clone(&cell_state),
)
.await
{
CompletionCommit::Committed => None,
CompletionCommit::Rejected(event) => Some(event),
};
match cell_state.deliver_completion(
observer.take().map(|observer| observer.response_tx),
) {
break;
CompletionDelivery::Delivered => break,
CompletionDelivery::Buffered => {}
CompletionDelivery::Rejected(response_tx) => {
finish_termination(
&cell_state,
response_tx,
CellEvent::Terminated {
content_items: rejected_completion_content(rejected_event),
},
);
break;
}
}
}
}
@@ -383,6 +410,9 @@ async fn run_cell<H: CellHost>(
}
}
}
// Reject requests that arrive while asynchronous terminal cleanup runs.
cell_state.tombstone();
drop(command_rx.take());
begin_termination(
&runtime_tx,
&runtime_control_tx,
@@ -390,7 +420,7 @@ async fn run_cell<H: CellHost>(
&cancellation_token,
);
finish_callbacks(
&cancellation_token,
&callback_cancellation_token,
&mut notification_tasks,
&mut tool_tasks,
CallbackCompletion::Cancel,
@@ -399,34 +429,29 @@ async fn run_cell<H: CellHost>(
host.closed().await;
}
fn send_or_buffer_completion(
event: CellEvent,
observer: &mut Option<Observer>,
completed_event: &mut Option<CellEvent>,
) -> bool {
if observer.is_some() {
send_observer_event(observer.take(), event);
true
} else {
*completed_event = Some(event);
false
}
}
fn send_observer_event(observer: Option<Observer>, event: CellEvent) {
if let Some(observer) = observer {
let _ = observer.response_tx.send(Ok(event));
}
}
fn send_termination_events(
observer: Option<Observer>,
termination: Option<Termination>,
/// Recovers the output items carried by a completion that was rejected.
///
/// Returns an empty list when there is no rejected event, and the
/// `content_items` of a rejected `CellEvent::Completed` otherwise.
///
/// # Panics
///
/// Panics if the rejected event is anything other than `CellEvent::Completed`.
fn rejected_completion_content(event: Option<CellEvent>) -> Vec<OutputItem> {
    let Some(event) = event else {
        return Vec::new();
    };
    match event {
        CellEvent::Completed { content_items, .. } => content_items,
        event => panic!("completion commit rejected an unexpected event: {event:?}"),
    }
}
fn finish_termination(
cell_state: &CellState,
observer_tx: Option<oneshot::Sender<Result<CellEvent, CellError>>>,
event: CellEvent,
) {
send_observer_event(observer, event.clone());
if let Some(response_tx) = termination.and_then(|termination| termination.response_tx) {
let _ = response_tx.send(Ok(event));
if let Some(event) = cell_state.finish_termination(event)
&& let Some(observer_tx) = observer_tx
{
let _ = observer_tx.send(Ok(event));
}
}
+94 -3
View File
@@ -33,7 +33,14 @@ impl CellHost for TestHost {
Ok(())
}
// Test stub: the body is empty, so the stored-value writes are discarded.
async fn commit_stored_values(&self, _stored_value_writes: HashMap<String, JsonValue>) {}
/// Test stub: publishes `event` through `cell_state` with a side effect that
/// does nothing. The stored-value writes are ignored.
async fn commit_completion(
    &self,
    _stored_value_writes: HashMap<String, JsonValue>,
    event: CellEvent,
    cell_state: Arc<CellState>,
) -> CompletionCommit {
    let no_stored_value_side_effect = || {};
    cell_state.commit_completion(event, no_stored_value_side_effect)
}
async fn closed(&self) {}
}
@@ -64,14 +71,15 @@ fn spawn_cell_actor_harness(initial_observe_mode: ObserveMode) -> CellActorHarne
PendingRuntimeMode::PauseUntilResumed,
)
.unwrap();
let handle = CellHandle::new(command_tx, CancellationToken::new());
let cell_state = Arc::new(CellState::new(CancellationToken::new()));
let handle = CellHandle::new(command_tx, Arc::clone(&cell_state));
let task = tokio::spawn(run_cell(
Arc::new(TestHost),
CellContext {
runtime_tx,
runtime_control_tx,
runtime_terminate_handle,
cancellation_token: CancellationToken::new(),
cell_state,
},
event_rx,
command_rx,
@@ -142,3 +150,86 @@ async fn queued_termination_preempts_unobserved_runtime_completion() {
assert_eq!(harness.initial_event_rx.await.unwrap(), terminated);
harness.task.await.unwrap();
}
// A completion that was committed but never observed is handed to the first
// termination request; a second request is told termination is already under
// way, and the final hand-off still reports the completion.
#[tokio::test]
async fn only_the_first_termination_claims_a_buffered_completion() {
let cell_state = CellState::new(CancellationToken::new());
let completion = CellEvent::Completed {
content_items: Vec::new(),
error_text: None,
};
// Commit while the cell is still running, with a no-op side effect.
assert_eq!(
cell_state.commit_completion(completion.clone(), || {}),
CompletionCommit::Committed
);
// No observer is waiting, so the completion stays buffered in the state.
assert!(matches!(
cell_state.deliver_completion(/*response_tx*/ None),
CompletionDelivery::Buffered
));
// The first request claims the buffered completion; it is created before the
// second request so that the second one sees the claim.
let first_termination = cell_state.request_termination();
assert_eq!(
cell_state.request_termination().await,
Err(CellError::AlreadyTerminating)
);
assert_eq!(first_termination.await, Ok(completion.clone()));
// The completion, not the `Terminated` event passed in, is what comes back.
assert_eq!(
cell_state.finish_termination(CellEvent::Terminated {
content_items: Vec::new(),
}),
Some(completion)
);
}
// Once termination has been requested, a later completion commit is rejected
// and its side-effect closure never runs.
#[tokio::test]
async fn termination_claim_prevents_stored_value_commit() {
let cell_state = CellState::new(CancellationToken::new());
// Request termination first so that it wins against the commit below.
let termination = cell_state.request_termination();
let mut commit_ran = false;
let completion = CellEvent::Completed {
content_items: Vec::new(),
error_text: None,
};
// The rejected event is handed back unchanged.
assert_eq!(
cell_state.commit_completion(completion.clone(), || commit_ran = true),
CompletionCommit::Rejected(completion)
);
assert!(!commit_ran);
let terminated = CellEvent::Terminated {
content_items: Vec::new(),
};
// Finishing the termination returns the event for observers and resolves the
// pending termination request with the same event.
assert_eq!(
cell_state.finish_termination(terminated.clone()),
Some(terminated.clone())
);
assert_eq!(termination.await, Ok(terminated));
}
// Delivering a completion to an observer whose receiver is already gone must
// keep the event buffered so that a later observation can still receive it.
#[test]
fn failed_completion_delivery_rebuffers_the_event() {
let cell_state = CellState::new(CancellationToken::new());
let event = CellEvent::Completed {
content_items: Vec::new(),
error_text: None,
};
assert_eq!(
cell_state.commit_completion(event.clone(), || {}),
CompletionCommit::Committed
);
// Drop the receiver up front so the send inside `deliver_completion` fails.
let (response_tx, response_rx) = oneshot::channel();
drop(response_rx);
assert!(matches!(
cell_state.deliver_completion(Some(response_tx)),
CompletionDelivery::Buffered
));
// The failed delivery must not have closed the cell to observers.
assert!(cell_state.accepting_observations());
// A fresh observation with a live receiver gets the buffered completion.
let (response_tx, mut response_rx) = oneshot::channel();
assert!(matches!(
cell_state.route_observation(response_tx),
ObservationDelivery::Delivered
));
assert_eq!(response_rx.try_recv(), Ok(Ok(event)));
}
+241 -41
View File
@@ -2,8 +2,7 @@ use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Mutex;
use serde_json::Value as JsonValue;
use tokio::sync::mpsc;
@@ -32,10 +31,11 @@ pub(crate) struct CellToolCall {
pub(crate) input: Option<JsonValue>,
}
/// Connects a cell actor to session-owned callbacks and lifecycle state.
/// Connects a cell actor to session-owned callbacks and stored values.
///
/// Implementations must honor callback cancellation and must not return from
/// `closed` until the session can no longer route requests to the cell.
/// Implementations should forward callback cancellation to downstream work.
/// Implementations must not return from `closed` until the session can no longer
/// route requests to the cell.
pub(crate) trait CellHost: Send + Sync + 'static {
fn invoke_tool(
&self,
@@ -50,10 +50,12 @@ pub(crate) trait CellHost: Send + Sync + 'static {
cancellation_token: CancellationToken,
) -> impl Future<Output = Result<(), String>> + Send;
fn commit_stored_values(
fn commit_completion(
&self,
stored_value_writes: HashMap<String, JsonValue>,
) -> impl Future<Output = ()> + Send;
event: CellEvent,
cell_state: Arc<CellState>,
) -> impl Future<Output = CompletionCommit> + Send;
fn closed(&self) -> impl Future<Output = ()> + Send;
}
@@ -61,23 +63,21 @@ pub(crate) trait CellHost: Send + Sync + 'static {
#[derive(Clone)]
pub(crate) struct CellHandle {
command_tx: mpsc::UnboundedSender<CellCommand>,
cancellation_token: CancellationToken,
termination_requested: Arc<AtomicBool>,
state: Arc<CellState>,
}
impl CellHandle {
pub(super) fn new(
command_tx: mpsc::UnboundedSender<CellCommand>,
cancellation_token: CancellationToken,
state: Arc<CellState>,
) -> Self {
Self {
command_tx,
cancellation_token,
termination_requested: Arc::new(AtomicBool::new(false)),
}
Self { command_tx, state }
}
pub(crate) fn observe(&self, mode: ObserveMode) -> CellEventFuture {
if !self.state.accepting_observations() {
return closed_event();
}
let (response_tx, response_rx) = oneshot::channel();
if self
.command_tx
@@ -90,33 +90,232 @@ impl CellHandle {
}
pub(crate) fn terminate(&self) -> CellEventFuture {
if self
.termination_requested
.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
.is_err()
{
return Box::pin(async { Err(CellError::AlreadyTerminating) });
}
let (response_tx, response_rx) = oneshot::channel();
if self
.command_tx
.send(CellCommand::Terminate {
response_tx: Some(response_tx),
})
.is_err()
{
self.termination_requested.store(false, Ordering::Relaxed);
return closed_event();
}
response_event(response_rx)
self.state.request_termination()
}
pub(crate) fn shutdown(&self) {
self.termination_requested.store(true, Ordering::Relaxed);
self.state.cancellation_token().cancel();
}
}
/// The single linearization point for a cell's terminal outcome.
///
/// Callback cancellation tokens are children of the cell token, so a terminal
/// decision cancels runtime work and its callbacks together.
///
/// The mutex is held only for synchronous phase transitions and terminal
/// delivery. Runtime execution, observation waits, and callbacks never run
/// while it is held.
pub(crate) struct CellState {
// Current phase of the state machine; every transition takes this lock.
phase: Mutex<CellPhase>,
// Token cancelled by the terminal transitions; clones are handed out through
// `cancellation_token()`.
cancellation_token: CancellationToken,
}
/// Phases of a cell's terminal-state machine.
enum CellPhase {
/// No terminal outcome has been decided yet. This is the initial phase.
Running,
/// Termination was requested while the cell was running. `response_tx`
/// answers that request once the termination is finished.
Terminating {
response_tx: oneshot::Sender<Result<CellEvent, CellError>>,
},
/// A completion was committed and is held here until it is delivered to an
/// observer or claimed by a termination request.
Completed(CellEvent),
/// A termination request took the held completion. The event is kept so the
/// final hand-off can still report it.
CompletionClaimed(CellEvent),
/// Final phase: the outcome has been handed off and nothing more is accepted.
Tombstone,
}
/// Outcome of trying to hand a committed completion to an observer.
pub(crate) enum CompletionDelivery {
/// The completion was sent to the observer.
Delivered,
/// The completion stays buffered in the state: either no observer channel
/// was supplied or its receiver had already been dropped.
Buffered,
/// The state held no committed completion. The observer channel, if any, is
/// handed back untouched.
Rejected(Option<oneshot::Sender<Result<CellEvent, CellError>>>),
}
/// Result of atomically publishing a completed cell and its session side effects.
#[derive(Debug, PartialEq)]
pub(crate) enum CompletionCommit {
/// The side effect ran and the completion is now the cell's outcome.
Committed,
/// The cell was no longer running or was already cancelled. The side effect
/// did not run and the event is returned to the caller.
Rejected(CellEvent),
}
/// Outcome of routing a new observation request through the cell state.
pub(crate) enum ObservationDelivery {
/// The cell is still running. The reply channel is returned so the caller
/// can keep it until there is something to report.
Running(oneshot::Sender<Result<CellEvent, CellError>>),
/// A buffered completion was sent on the reply channel.
Delivered,
/// A completion is buffered but the reply channel's receiver was already
/// gone, so the completion stays buffered.
Buffered,
/// The cell is terminating or finished; `CellError::Closed` was sent on the
/// reply channel.
Closed,
}
impl CellState {
pub(crate) fn new(cancellation_token: CancellationToken) -> Self {
Self {
phase: Mutex::new(CellPhase::Running),
cancellation_token,
}
}
pub(crate) fn accepting_observations(&self) -> bool {
let accepting_phase = matches!(
*self
.phase
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner),
CellPhase::Running | CellPhase::Completed(_)
);
accepting_phase && !self.cancellation_token.is_cancelled()
}
pub(crate) fn request_termination(&self) -> CellEventFuture {
let mut phase = self
.phase
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
match std::mem::replace(&mut *phase, CellPhase::Tombstone) {
CellPhase::Running => {
let (response_tx, response_rx) = oneshot::channel();
*phase = CellPhase::Terminating { response_tx };
self.cancellation_token.cancel();
response_event(response_rx)
}
CellPhase::Terminating { response_tx } => {
*phase = CellPhase::Terminating { response_tx };
Box::pin(async { Err(CellError::AlreadyTerminating) })
}
CellPhase::Completed(event) => {
*phase = CellPhase::CompletionClaimed(event.clone());
self.cancellation_token.cancel();
ready_event(event)
}
CellPhase::CompletionClaimed(event) => {
*phase = CellPhase::CompletionClaimed(event);
Box::pin(async { Err(CellError::AlreadyTerminating) })
}
CellPhase::Tombstone => closed_event(),
}
}
pub(crate) fn commit_completion(
&self,
event: CellEvent,
commit: impl FnOnce(),
) -> CompletionCommit {
let mut phase = self
.phase
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
if !matches!(*phase, CellPhase::Running) || self.cancellation_token.is_cancelled() {
return CompletionCommit::Rejected(event);
}
commit();
*phase = CellPhase::Completed(event);
CompletionCommit::Committed
}
pub(crate) fn deliver_completion(
&self,
response_tx: Option<oneshot::Sender<Result<CellEvent, CellError>>>,
) -> CompletionDelivery {
let mut phase = self
.phase
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let event = match std::mem::replace(&mut *phase, CellPhase::Tombstone) {
CellPhase::Completed(event) => event,
previous => {
*phase = previous;
return CompletionDelivery::Rejected(response_tx);
}
};
let Some(response_tx) = response_tx else {
*phase = CellPhase::Completed(event);
return CompletionDelivery::Buffered;
};
match response_tx.send(Ok(event)) {
Ok(()) => {
self.cancellation_token.cancel();
CompletionDelivery::Delivered
}
Err(Ok(event)) => {
*phase = CellPhase::Completed(event);
CompletionDelivery::Buffered
}
Err(Err(error)) => {
panic!("completion delivery unexpectedly carried an actor error: {error:?}")
}
}
}
pub(crate) fn route_observation(
&self,
response_tx: oneshot::Sender<Result<CellEvent, CellError>>,
) -> ObservationDelivery {
let mut phase = self
.phase
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
match std::mem::replace(&mut *phase, CellPhase::Tombstone) {
CellPhase::Running => {
*phase = CellPhase::Running;
ObservationDelivery::Running(response_tx)
}
CellPhase::Completed(event) => match response_tx.send(Ok(event)) {
Ok(()) => {
self.cancellation_token.cancel();
ObservationDelivery::Delivered
}
Err(Ok(event)) => {
*phase = CellPhase::Completed(event);
ObservationDelivery::Buffered
}
Err(Err(error)) => {
panic!("completion delivery unexpectedly carried an actor error: {error:?}")
}
},
CellPhase::Terminating {
response_tx: termination_tx,
} => {
*phase = CellPhase::Terminating {
response_tx: termination_tx,
};
let _ = response_tx.send(Err(CellError::Closed));
ObservationDelivery::Closed
}
CellPhase::CompletionClaimed(event) => {
*phase = CellPhase::CompletionClaimed(event);
let _ = response_tx.send(Err(CellError::Closed));
ObservationDelivery::Closed
}
CellPhase::Tombstone => {
let _ = response_tx.send(Err(CellError::Closed));
ObservationDelivery::Closed
}
}
}
pub(crate) fn finish_termination(&self, event: CellEvent) -> Option<CellEvent> {
let mut phase = self
.phase
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let observer_event = match std::mem::replace(&mut *phase, CellPhase::Tombstone) {
CellPhase::Running => Some(event),
CellPhase::Terminating { response_tx } => {
let _ = response_tx.send(Ok(event.clone()));
Some(event)
}
CellPhase::Completed(completed_event) => Some(completed_event),
CellPhase::CompletionClaimed(completed_event) => Some(completed_event),
CellPhase::Tombstone => None,
};
self.cancellation_token.cancel();
let _ = self
.command_tx
.send(CellCommand::Terminate { response_tx: None });
observer_event
}
pub(crate) fn tombstone(&self) {
*self
.phase
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) = CellPhase::Tombstone;
self.cancellation_token.cancel();
}
pub(crate) fn cancellation_token(&self) -> CancellationToken {
self.cancellation_token.clone()
}
}
@@ -125,15 +324,16 @@ pub(super) enum CellCommand {
mode: ObserveMode,
response_tx: oneshot::Sender<Result<CellEvent, CellError>>,
},
Terminate {
response_tx: Option<oneshot::Sender<Result<CellEvent, CellError>>>,
},
}
/// Wraps a reply receiver in a future that yields the reply, or
/// `CellError::Closed` when the sender was dropped without answering.
fn response_event(response_rx: oneshot::Receiver<Result<CellEvent, CellError>>) -> CellEventFuture {
    Box::pin(async move {
        match response_rx.await {
            Ok(outcome) => outcome,
            Err(_sender_dropped) => Err(CellError::Closed),
        }
    })
}
fn ready_event(event: CellEvent) -> CellEventFuture {
Box::pin(async move { Ok(event) })
}
fn closed_event() -> CellEventFuture {
Box::pin(async { Err(CellError::Closed) })
}
+30 -9
View File
@@ -29,7 +29,9 @@ use crate::cell_actor::CellError;
use crate::cell_actor::CellEventFuture;
use crate::cell_actor::CellHandle;
use crate::cell_actor::CellHost;
use crate::cell_actor::CellState;
use crate::cell_actor::CellToolCall;
use crate::cell_actor::CompletionCommit;
type RuntimeEventFuture = Pin<Box<dyn Future<Output = Result<CellEvent, Error>> + Send + 'static>>;
@@ -165,9 +167,15 @@ impl<D: SessionRuntimeDelegate> SessionRuntime<D> {
if cells.contains_key(&cell_id) {
return Err(Error::DuplicateCell(cell_id));
}
let (handle, initial_event, task) =
CellActor::prepare(request, stored_values, host, initial_observe_mode)
.map_err(Error::Runtime)?;
let cell_state = Arc::new(CellState::new(CancellationToken::new()));
let (handle, initial_event, task) = CellActor::prepare(
request,
stored_values,
host,
initial_observe_mode,
cell_state,
)
.map_err(Error::Runtime)?;
cells.insert(cell_id.clone(), handle);
drop(cells);
tokio::spawn(task);
@@ -251,12 +259,21 @@ impl<D: SessionRuntimeDelegate> CellHost for RuntimeCellHost<D> {
.await
}
async fn commit_stored_values(&self, stored_value_writes: HashMap<String, JsonValue>) {
self.inner
.stored_values
.lock()
.await
.extend(stored_value_writes);
// Applies a cell's stored-value writes and publishes its completion as one step.
//
// Returns `CompletionCommit::Rejected(event)` without touching the store when
// the cell's token is cancelled before the store lock is obtained. Otherwise
// the decision is left to `CellState::commit_completion`, which is given a
// closure that merges the writes into the session store.
async fn commit_completion(
&self,
stored_value_writes: HashMap<String, JsonValue>,
event: CellEvent,
cell_state: Arc<CellState>,
) -> CompletionCommit {
let cancellation_token = cell_state.cancellation_token();
// `biased` polls the cancellation branch first, so a cancelled cell never
// takes the store lock even when the lock is free.
let mut stored_values = tokio::select! {
biased;
_ = cancellation_token.cancelled() => {
return CompletionCommit::Rejected(event);
}
stored_values = self.inner.stored_values.lock() => stored_values,
};
// The store guard is still held here, so the writes and the phase change
// happen together or not at all.
cell_state.commit_completion(event, || stored_values.extend(stored_value_writes))
}
async fn closed(&self) {
@@ -276,3 +293,7 @@ fn actor_error(cell_id: &CellId, error: CellError) -> Error {
CellError::Closed => Error::ClosedCell(cell_id.clone()),
}
}
#[cfg(test)]
#[path = "tests.rs"]
mod tests;
@@ -0,0 +1,110 @@
use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use std::time::Duration;
use pretty_assertions::assert_eq;
use serde_json::Value as JsonValue;
use tokio_util::sync::CancellationToken;
use super::*;
use crate::cell_actor::CompletionCommit;
// Stateless delegate stub for the session-runtime test below.
struct RecordingDelegate;
impl SessionRuntimeDelegate for RecordingDelegate {
// Every tool call succeeds with a JSON null; the arguments are ignored.
async fn invoke_tool(
&self,
_invocation: NestedToolCall,
_cancellation_token: CancellationToken,
) -> Result<JsonValue, String> {
Ok(JsonValue::Null)
}
// Every notification succeeds; the arguments are ignored.
async fn notify(
&self,
_call_id: String,
_cell_id: CellId,
_text: String,
_cancellation_token: CancellationToken,
) -> Result<(), String> {
Ok(())
}
// Cell-closed callbacks are ignored.
fn cell_closed(&self, _cell_id: &CellId) {}
}
// A completion commit that is still waiting for the session's stored-value
// lock must be rejected when termination is requested in the meantime, and a
// cell started afterwards must not see the value it tried to write.
#[tokio::test]
async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_load_it() {
let runtime = SessionRuntime::new(Arc::new(RecordingDelegate));
let cell_state = Arc::new(CellState::new(CancellationToken::new()));
let host = RuntimeCellHost {
cell_id: CellId::new("terminating-writer"),
inner: Arc::clone(&runtime.inner),
};
let completion = CellEvent::Completed {
content_items: vec![OutputItem::Text {
text: "uncommitted output".to_string(),
}],
error_text: None,
};
// Hold the store lock so the commit below cannot get past its wait.
let stored_values = runtime.inner.stored_values.lock().await;
let commit = host.commit_completion(
HashMap::from([(
"candidate".to_string(),
JsonValue::String("lost".to_string()),
)]),
completion.clone(),
Arc::clone(&cell_state),
);
tokio::pin!(commit);
// Poll once by hand with a no-op waker to confirm the commit is parked.
let waker = Waker::noop();
let mut context = Context::from_waker(waker);
assert!(matches!(commit.as_mut().poll(&mut context), Poll::Pending));
// Termination is requested before the lock is released, so it wins.
let termination = cell_state.request_termination();
drop(stored_values);
assert_eq!(commit.await, CompletionCommit::Rejected(completion));
let terminated = CellEvent::Terminated {
content_items: Vec::new(),
};
assert_eq!(
cell_state.finish_termination(terminated.clone()),
Some(terminated.clone())
);
assert_eq!(termination.await, Ok(terminated));
// The rejected commit must not have written its value to the session store.
assert!(
!runtime
.inner
.stored_values
.lock()
.await
.contains_key("candidate")
);
// A cell that runs afterwards loads the key and prints `undefined`.
let reader = runtime
.execute(
CreateCellRequest {
tool_call_id: "reader".to_string(),
enabled_tools: Vec::new(),
source: r#"text(String(load("candidate")));"#.to_string(),
},
ObserveMode::YieldAfter(Duration::from_secs(1)),
)
.await
.unwrap();
assert_eq!(
reader.initial_event().await,
Ok(CellEvent::Completed {
content_items: vec![OutputItem::Text {
text: "undefined".to_string(),
}],
error_text: None,
})
);
runtime.shutdown().await.unwrap();
}