mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
code-mode: linearize cell terminal state (#29286)
## Summary

- Introduce a single cell terminal-state machine for completion and termination.
- Make stored-value commits atomic with the winning terminal outcome.
- Buffer terminal results for later observation and cover termination-before-commit behavior.

## Why

Completion, termination, observation, and stored-value updates must agree on one linearized outcome under cancellation races.

## Impact

Terminal delivery becomes deterministic, and terminated cells cannot commit state after termination wins.

## Validation

- Focused terminal-state regression passed.
- Stack-tip validation: `just test -p codex-code-mode -p codex-code-mode-protocol` (70 passed).
- Parent branch: `cconger/code-mode-runtime-compact-03b-session-runtime`.
This commit is contained in:
committed by
GitHub
Unverified
parent
63f009e9da
commit
f774455c3a
@@ -25,7 +25,11 @@ pub(crate) use self::types::CellError;
|
||||
pub(crate) use self::types::CellEventFuture;
|
||||
pub(crate) use self::types::CellHandle;
|
||||
pub(crate) use self::types::CellHost;
|
||||
pub(crate) use self::types::CellState;
|
||||
pub(crate) use self::types::CellToolCall;
|
||||
pub(crate) use self::types::CompletionCommit;
|
||||
use self::types::CompletionDelivery;
|
||||
use self::types::ObservationDelivery;
|
||||
use crate::runtime::PendingRuntimeMode;
|
||||
use crate::runtime::RuntimeCommand;
|
||||
use crate::runtime::RuntimeControlCommand;
|
||||
@@ -34,6 +38,7 @@ use crate::runtime::spawn_runtime;
|
||||
use crate::session_runtime::CellEvent;
|
||||
use crate::session_runtime::CreateCellRequest as CellRequest;
|
||||
use crate::session_runtime::ObserveMode;
|
||||
use crate::session_runtime::OutputItem;
|
||||
use crate::session_runtime::ToolName as CellToolName;
|
||||
|
||||
pub(crate) struct CellActor;
|
||||
@@ -44,6 +49,7 @@ impl CellActor {
|
||||
stored_values: HashMap<String, JsonValue>,
|
||||
host: Arc<H>,
|
||||
initial_observe_mode: ObserveMode,
|
||||
cell_state: Arc<CellState>,
|
||||
) -> Result<
|
||||
(
|
||||
CellHandle,
|
||||
@@ -61,15 +67,14 @@ impl CellActor {
|
||||
event_tx,
|
||||
PendingRuntimeMode::PauseUntilResumed,
|
||||
)?;
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let handle = CellHandle::new(command_tx, cancellation_token.clone());
|
||||
let handle = CellHandle::new(command_tx, Arc::clone(&cell_state));
|
||||
let task = run_cell(
|
||||
host,
|
||||
CellContext {
|
||||
runtime_tx,
|
||||
runtime_control_tx,
|
||||
runtime_terminate_handle,
|
||||
cancellation_token,
|
||||
cell_state,
|
||||
},
|
||||
event_rx,
|
||||
command_rx,
|
||||
@@ -88,7 +93,7 @@ struct CellContext {
|
||||
runtime_tx: std::sync::mpsc::Sender<RuntimeCommand>,
|
||||
runtime_control_tx: std::sync::mpsc::Sender<RuntimeControlCommand>,
|
||||
runtime_terminate_handle: v8::IsolateHandle,
|
||||
cancellation_token: CancellationToken,
|
||||
cell_state: Arc<CellState>,
|
||||
}
|
||||
|
||||
struct Observer {
|
||||
@@ -96,115 +101,98 @@ struct Observer {
|
||||
response_tx: oneshot::Sender<Result<CellEvent, CellError>>,
|
||||
}
|
||||
|
||||
struct Termination {
|
||||
response_tx: Option<oneshot::Sender<Result<CellEvent, CellError>>>,
|
||||
}
|
||||
|
||||
async fn run_cell<H: CellHost>(
|
||||
host: Arc<H>,
|
||||
context: CellContext,
|
||||
mut event_rx: mpsc::UnboundedReceiver<RuntimeEvent>,
|
||||
mut command_rx: mpsc::UnboundedReceiver<CellCommand>,
|
||||
command_rx: mpsc::UnboundedReceiver<CellCommand>,
|
||||
initial_observer: Observer,
|
||||
) {
|
||||
let CellContext {
|
||||
runtime_tx,
|
||||
runtime_control_tx,
|
||||
runtime_terminate_handle,
|
||||
cancellation_token,
|
||||
cell_state,
|
||||
} = context;
|
||||
let cancellation_token = cell_state.cancellation_token();
|
||||
let callback_cancellation_token = cancellation_token.child_token();
|
||||
let mut content_items = Vec::new();
|
||||
let mut pending_tool_call_ids = Vec::new();
|
||||
let mut completed_event = None;
|
||||
let mut observer = Some(initial_observer);
|
||||
let mut termination: Option<Termination> = None;
|
||||
let mut termination = false;
|
||||
let mut runtime_closed = false;
|
||||
let mut runtime_paused = false;
|
||||
let mut yield_timer: Option<std::pin::Pin<Box<tokio::time::Sleep>>> = None;
|
||||
let mut notification_tasks = JoinSet::new();
|
||||
let mut tool_tasks = JoinSet::new();
|
||||
|
||||
let mut command_rx = Some(command_rx);
|
||||
loop {
|
||||
let yield_deadline_elapsed = yield_timer
|
||||
.as_ref()
|
||||
.is_some_and(|yield_timer| yield_timer.deadline() <= tokio::time::Instant::now());
|
||||
tokio::select! {
|
||||
biased;
|
||||
maybe_command = command_rx.recv() => {
|
||||
let Some(command) = maybe_command else {
|
||||
if completed_event.is_some() {
|
||||
break;
|
||||
}
|
||||
termination = Some(Termination { response_tx: None });
|
||||
begin_termination(
|
||||
&runtime_tx,
|
||||
&runtime_control_tx,
|
||||
&runtime_terminate_handle,
|
||||
&cancellation_token,
|
||||
_ = cancellation_token.cancelled(), if !termination => {
|
||||
termination = true;
|
||||
yield_timer = None;
|
||||
drop(command_rx.take());
|
||||
begin_termination(
|
||||
&runtime_tx,
|
||||
&runtime_control_tx,
|
||||
&runtime_terminate_handle,
|
||||
&cancellation_token,
|
||||
);
|
||||
if runtime_closed {
|
||||
finish_callbacks(
|
||||
&callback_cancellation_token,
|
||||
&mut notification_tasks,
|
||||
&mut tool_tasks,
|
||||
CallbackCompletion::Cancel,
|
||||
).await;
|
||||
finish_termination(
|
||||
&cell_state,
|
||||
observer.take().map(|observer| observer.response_tx),
|
||||
CellEvent::Terminated {
|
||||
content_items: std::mem::take(&mut content_items),
|
||||
},
|
||||
);
|
||||
if runtime_closed {
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
maybe_command = async {
|
||||
match command_rx.as_mut() {
|
||||
Some(command_rx) => command_rx.recv().await,
|
||||
None => std::future::pending::<Option<CellCommand>>().await,
|
||||
}
|
||||
} => {
|
||||
let Some(CellCommand::Observe { mode, response_tx }) = maybe_command else {
|
||||
cancellation_token.cancel();
|
||||
continue;
|
||||
};
|
||||
match command {
|
||||
CellCommand::Observe { mode, response_tx } => {
|
||||
if let Some(event) = completed_event.take() {
|
||||
let _ = response_tx.send(Ok(event));
|
||||
break;
|
||||
}
|
||||
if observer.is_some() || termination.is_some() {
|
||||
let _ = response_tx.send(Err(CellError::Busy));
|
||||
continue;
|
||||
}
|
||||
observer = Some(Observer { mode, response_tx });
|
||||
yield_timer = observer.as_ref().and_then(observer_timer);
|
||||
resume_for_observation(
|
||||
mode,
|
||||
&mut runtime_paused,
|
||||
&runtime_tx,
|
||||
&runtime_control_tx,
|
||||
);
|
||||
}
|
||||
CellCommand::Terminate { response_tx } => {
|
||||
if let Some(event) = completed_event.take() {
|
||||
if let Some(response_tx) = response_tx {
|
||||
let _ = response_tx.send(Ok(event));
|
||||
}
|
||||
break;
|
||||
}
|
||||
if termination.is_some() {
|
||||
if let Some(response_tx) = response_tx {
|
||||
let _ = response_tx.send(Err(CellError::AlreadyTerminating));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
termination = Some(Termination { response_tx });
|
||||
yield_timer = None;
|
||||
begin_termination(
|
||||
&runtime_tx,
|
||||
&runtime_control_tx,
|
||||
&runtime_terminate_handle,
|
||||
&cancellation_token,
|
||||
);
|
||||
if runtime_closed {
|
||||
finish_callbacks(
|
||||
&cancellation_token,
|
||||
&mut notification_tasks,
|
||||
&mut tool_tasks,
|
||||
CallbackCompletion::Cancel,
|
||||
).await;
|
||||
send_termination_events(
|
||||
observer.take(),
|
||||
termination.take(),
|
||||
CellEvent::Terminated {
|
||||
content_items: std::mem::take(&mut content_items),
|
||||
},
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
let response_tx = match cell_state.route_observation(response_tx) {
|
||||
ObservationDelivery::Running(response_tx) => response_tx,
|
||||
ObservationDelivery::Delivered => break,
|
||||
ObservationDelivery::Buffered | ObservationDelivery::Closed => continue,
|
||||
};
|
||||
if observer
|
||||
.as_ref()
|
||||
.is_some_and(|observer| observer.response_tx.is_closed())
|
||||
{
|
||||
observer = None;
|
||||
yield_timer = None;
|
||||
}
|
||||
if observer.is_some() || termination {
|
||||
let _ = response_tx.send(Err(CellError::Busy));
|
||||
continue;
|
||||
}
|
||||
observer = Some(Observer { mode, response_tx });
|
||||
yield_timer = observer.as_ref().and_then(observer_timer);
|
||||
resume_for_observation(
|
||||
mode,
|
||||
&mut runtime_paused,
|
||||
&runtime_tx,
|
||||
&runtime_control_tx,
|
||||
);
|
||||
}
|
||||
_ = async {
|
||||
if let Some(yield_timer) = yield_timer.as_mut() {
|
||||
@@ -230,38 +218,57 @@ async fn run_cell<H: CellHost>(
|
||||
}, if !yield_deadline_elapsed => {
|
||||
let Some(event) = maybe_event else {
|
||||
runtime_closed = true;
|
||||
if termination.is_some() {
|
||||
if termination || cancellation_token.is_cancelled() {
|
||||
finish_callbacks(
|
||||
&cancellation_token,
|
||||
&callback_cancellation_token,
|
||||
&mut notification_tasks,
|
||||
&mut tool_tasks,
|
||||
CallbackCompletion::Cancel,
|
||||
).await;
|
||||
send_termination_events(
|
||||
observer.take(),
|
||||
termination.take(),
|
||||
finish_termination(
|
||||
&cell_state,
|
||||
observer.take().map(|observer| observer.response_tx),
|
||||
CellEvent::Terminated {
|
||||
content_items: std::mem::take(&mut content_items),
|
||||
},
|
||||
);
|
||||
break;
|
||||
}
|
||||
if completed_event.is_none() {
|
||||
finish_callbacks(
|
||||
&cancellation_token,
|
||||
&mut notification_tasks,
|
||||
&mut tool_tasks,
|
||||
CallbackCompletion::DrainNotifications,
|
||||
).await;
|
||||
let event = CellEvent::Completed {
|
||||
content_items: std::mem::take(&mut content_items),
|
||||
error_text: Some("exec runtime ended unexpectedly".to_string()),
|
||||
};
|
||||
if send_or_buffer_completion(
|
||||
finish_callbacks(
|
||||
&callback_cancellation_token,
|
||||
&mut notification_tasks,
|
||||
&mut tool_tasks,
|
||||
CallbackCompletion::DrainNotifications,
|
||||
)
|
||||
.await;
|
||||
let event = CellEvent::Completed {
|
||||
content_items: std::mem::take(&mut content_items),
|
||||
error_text: Some("exec runtime ended unexpectedly".to_string()),
|
||||
};
|
||||
let rejected_event = match host
|
||||
.commit_completion(
|
||||
HashMap::new(),
|
||||
event,
|
||||
&mut observer,
|
||||
&mut completed_event,
|
||||
) {
|
||||
Arc::clone(&cell_state),
|
||||
)
|
||||
.await
|
||||
{
|
||||
CompletionCommit::Committed => None,
|
||||
CompletionCommit::Rejected(event) => Some(event),
|
||||
};
|
||||
match cell_state.deliver_completion(
|
||||
observer.take().map(|observer| observer.response_tx),
|
||||
) {
|
||||
CompletionDelivery::Delivered => break,
|
||||
CompletionDelivery::Buffered => {}
|
||||
CompletionDelivery::Rejected(response_tx) => {
|
||||
finish_termination(
|
||||
&cell_state,
|
||||
response_tx,
|
||||
CellEvent::Terminated {
|
||||
content_items: rejected_completion_content(rejected_event),
|
||||
},
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -314,7 +321,7 @@ async fn run_cell<H: CellHost>(
|
||||
Arc::clone(&host),
|
||||
call_id,
|
||||
text,
|
||||
cancellation_token.child_token(),
|
||||
callback_cancellation_token.child_token(),
|
||||
);
|
||||
}
|
||||
RuntimeEvent::ToolCall { id, name, kind, input } => {
|
||||
@@ -332,22 +339,22 @@ async fn run_cell<H: CellHost>(
|
||||
input,
|
||||
},
|
||||
runtime_tx.clone(),
|
||||
cancellation_token.child_token(),
|
||||
callback_cancellation_token.child_token(),
|
||||
);
|
||||
}
|
||||
RuntimeEvent::Result { stored_value_writes, error_text } => {
|
||||
runtime_closed = true;
|
||||
yield_timer = None;
|
||||
if termination.is_some() {
|
||||
if termination || cancellation_token.is_cancelled() {
|
||||
finish_callbacks(
|
||||
&cancellation_token,
|
||||
&callback_cancellation_token,
|
||||
&mut notification_tasks,
|
||||
&mut tool_tasks,
|
||||
CallbackCompletion::Cancel,
|
||||
).await;
|
||||
send_termination_events(
|
||||
observer.take(),
|
||||
termination.take(),
|
||||
finish_termination(
|
||||
&cell_state,
|
||||
observer.take().map(|observer| observer.response_tx),
|
||||
CellEvent::Terminated {
|
||||
content_items: std::mem::take(&mut content_items),
|
||||
},
|
||||
@@ -355,22 +362,42 @@ async fn run_cell<H: CellHost>(
|
||||
break;
|
||||
}
|
||||
finish_callbacks(
|
||||
&cancellation_token,
|
||||
&callback_cancellation_token,
|
||||
&mut notification_tasks,
|
||||
&mut tool_tasks,
|
||||
CallbackCompletion::DrainNotifications,
|
||||
).await;
|
||||
host.commit_stored_values(stored_value_writes).await;
|
||||
)
|
||||
.await;
|
||||
let event = CellEvent::Completed {
|
||||
content_items: std::mem::take(&mut content_items),
|
||||
error_text,
|
||||
};
|
||||
if send_or_buffer_completion(
|
||||
event,
|
||||
&mut observer,
|
||||
&mut completed_event,
|
||||
let rejected_event = match host
|
||||
.commit_completion(
|
||||
stored_value_writes,
|
||||
event,
|
||||
Arc::clone(&cell_state),
|
||||
)
|
||||
.await
|
||||
{
|
||||
CompletionCommit::Committed => None,
|
||||
CompletionCommit::Rejected(event) => Some(event),
|
||||
};
|
||||
match cell_state.deliver_completion(
|
||||
observer.take().map(|observer| observer.response_tx),
|
||||
) {
|
||||
break;
|
||||
CompletionDelivery::Delivered => break,
|
||||
CompletionDelivery::Buffered => {}
|
||||
CompletionDelivery::Rejected(response_tx) => {
|
||||
finish_termination(
|
||||
&cell_state,
|
||||
response_tx,
|
||||
CellEvent::Terminated {
|
||||
content_items: rejected_completion_content(rejected_event),
|
||||
},
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -383,6 +410,9 @@ async fn run_cell<H: CellHost>(
|
||||
}
|
||||
}
|
||||
}
|
||||
// Reject requests that arrive while asynchronous terminal cleanup runs.
|
||||
cell_state.tombstone();
|
||||
drop(command_rx.take());
|
||||
begin_termination(
|
||||
&runtime_tx,
|
||||
&runtime_control_tx,
|
||||
@@ -390,7 +420,7 @@ async fn run_cell<H: CellHost>(
|
||||
&cancellation_token,
|
||||
);
|
||||
finish_callbacks(
|
||||
&cancellation_token,
|
||||
&callback_cancellation_token,
|
||||
&mut notification_tasks,
|
||||
&mut tool_tasks,
|
||||
CallbackCompletion::Cancel,
|
||||
@@ -399,34 +429,29 @@ async fn run_cell<H: CellHost>(
|
||||
host.closed().await;
|
||||
}
|
||||
|
||||
fn send_or_buffer_completion(
|
||||
event: CellEvent,
|
||||
observer: &mut Option<Observer>,
|
||||
completed_event: &mut Option<CellEvent>,
|
||||
) -> bool {
|
||||
if observer.is_some() {
|
||||
send_observer_event(observer.take(), event);
|
||||
true
|
||||
} else {
|
||||
*completed_event = Some(event);
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn send_observer_event(observer: Option<Observer>, event: CellEvent) {
|
||||
if let Some(observer) = observer {
|
||||
let _ = observer.response_tx.send(Ok(event));
|
||||
}
|
||||
}
|
||||
|
||||
fn send_termination_events(
|
||||
observer: Option<Observer>,
|
||||
termination: Option<Termination>,
|
||||
/// Extracts the output items from a completion that the commit step rejected.
///
/// Returns the `content_items` of a `CellEvent::Completed`, or an empty vector
/// when there is no rejected event.
///
/// # Panics
///
/// Panics if the rejected event is any variant other than `CellEvent::Completed`.
fn rejected_completion_content(event: Option<CellEvent>) -> Vec<OutputItem> {
|
||||
match event {
|
||||
Some(CellEvent::Completed { content_items, .. }) => content_items,
|
||||
None => Vec::new(),
|
||||
Some(event) => panic!("completion commit rejected an unexpected event: {event:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn finish_termination(
|
||||
cell_state: &CellState,
|
||||
observer_tx: Option<oneshot::Sender<Result<CellEvent, CellError>>>,
|
||||
event: CellEvent,
|
||||
) {
|
||||
send_observer_event(observer, event.clone());
|
||||
if let Some(response_tx) = termination.and_then(|termination| termination.response_tx) {
|
||||
let _ = response_tx.send(Ok(event));
|
||||
if let Some(event) = cell_state.finish_termination(event)
|
||||
&& let Some(observer_tx) = observer_tx
|
||||
{
|
||||
let _ = observer_tx.send(Ok(event));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,14 @@ impl CellHost for TestHost {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn commit_stored_values(&self, _stored_value_writes: HashMap<String, JsonValue>) {}
|
||||
// Test host: ignores the stored-value writes and forwards the event to
// `CellState::commit_completion` with a no-op commit closure.
async fn commit_completion(
|
||||
&self,
|
||||
_stored_value_writes: HashMap<String, JsonValue>,
|
||||
event: CellEvent,
|
||||
cell_state: Arc<CellState>,
|
||||
) -> CompletionCommit {
|
||||
cell_state.commit_completion(event, || {})
|
||||
}
|
||||
|
||||
async fn closed(&self) {}
|
||||
}
|
||||
@@ -64,14 +71,15 @@ fn spawn_cell_actor_harness(initial_observe_mode: ObserveMode) -> CellActorHarne
|
||||
PendingRuntimeMode::PauseUntilResumed,
|
||||
)
|
||||
.unwrap();
|
||||
let handle = CellHandle::new(command_tx, CancellationToken::new());
|
||||
let cell_state = Arc::new(CellState::new(CancellationToken::new()));
|
||||
let handle = CellHandle::new(command_tx, Arc::clone(&cell_state));
|
||||
let task = tokio::spawn(run_cell(
|
||||
Arc::new(TestHost),
|
||||
CellContext {
|
||||
runtime_tx,
|
||||
runtime_control_tx,
|
||||
runtime_terminate_handle,
|
||||
cancellation_token: CancellationToken::new(),
|
||||
cell_state,
|
||||
},
|
||||
event_rx,
|
||||
command_rx,
|
||||
@@ -142,3 +150,86 @@ async fn queued_termination_preempts_unobserved_runtime_completion() {
|
||||
assert_eq!(harness.initial_event_rx.await.unwrap(), terminated);
|
||||
harness.task.await.unwrap();
|
||||
}
|
||||
|
||||
// With a completion committed and buffered (no observer), the first termination
// request receives that completion, a second request is told
// `AlreadyTerminating`, and `finish_termination` still reports the completion
// rather than the `Terminated` event passed to it.
#[tokio::test]
|
||||
async fn only_the_first_termination_claims_a_buffered_completion() {
|
||||
let cell_state = CellState::new(CancellationToken::new());
|
||||
let completion = CellEvent::Completed {
|
||||
content_items: Vec::new(),
|
||||
error_text: None,
|
||||
};
|
||||
assert_eq!(
|
||||
cell_state.commit_completion(completion.clone(), || {}),
|
||||
CompletionCommit::Committed
|
||||
);
|
||||
assert!(matches!(
|
||||
cell_state.deliver_completion(/*response_tx*/ None),
|
||||
CompletionDelivery::Buffered
|
||||
));
|
||||
|
||||
// The first request claims the buffered event; it is awaited only after the
// second request has been rejected.
let first_termination = cell_state.request_termination();
|
||||
assert_eq!(
|
||||
cell_state.request_termination().await,
|
||||
Err(CellError::AlreadyTerminating)
|
||||
);
|
||||
assert_eq!(first_termination.await, Ok(completion.clone()));
|
||||
assert_eq!(
|
||||
cell_state.finish_termination(CellEvent::Terminated {
|
||||
content_items: Vec::new(),
|
||||
}),
|
||||
Some(completion)
|
||||
);
|
||||
}
|
||||
|
||||
// Once a termination request has claimed the cell, a later completion commit is
// rejected without running its commit closure, and the pending termination
// future resolves with the `Terminated` event given to `finish_termination`.
#[tokio::test]
|
||||
async fn termination_claim_prevents_stored_value_commit() {
|
||||
let cell_state = CellState::new(CancellationToken::new());
|
||||
let termination = cell_state.request_termination();
|
||||
// Flipped by the commit closure; must stay false because the commit is rejected.
let mut commit_ran = false;
|
||||
let completion = CellEvent::Completed {
|
||||
content_items: Vec::new(),
|
||||
error_text: None,
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
cell_state.commit_completion(completion.clone(), || commit_ran = true),
|
||||
CompletionCommit::Rejected(completion)
|
||||
);
|
||||
assert!(!commit_ran);
|
||||
|
||||
let terminated = CellEvent::Terminated {
|
||||
content_items: Vec::new(),
|
||||
};
|
||||
assert_eq!(
|
||||
cell_state.finish_termination(terminated.clone()),
|
||||
Some(terminated.clone())
|
||||
);
|
||||
assert_eq!(termination.await, Ok(terminated));
|
||||
}
|
||||
|
||||
// Delivering a committed completion to an observer whose receiver is already
// dropped must keep the event buffered, so that the next observation receives it.
#[test]
|
||||
fn failed_completion_delivery_rebuffers_the_event() {
|
||||
let cell_state = CellState::new(CancellationToken::new());
|
||||
let event = CellEvent::Completed {
|
||||
content_items: Vec::new(),
|
||||
error_text: None,
|
||||
};
|
||||
assert_eq!(
|
||||
cell_state.commit_completion(event.clone(), || {}),
|
||||
CompletionCommit::Committed
|
||||
);
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
// Dropping the receiver makes the send inside `deliver_completion` fail.
drop(response_rx);
|
||||
assert!(matches!(
|
||||
cell_state.deliver_completion(Some(response_tx)),
|
||||
CompletionDelivery::Buffered
|
||||
));
|
||||
assert!(cell_state.accepting_observations());
|
||||
|
||||
let (response_tx, mut response_rx) = oneshot::channel();
|
||||
assert!(matches!(
|
||||
cell_state.route_observation(response_tx),
|
||||
ObservationDelivery::Delivered
|
||||
));
|
||||
assert_eq!(response_rx.try_recv(), Ok(Ok(event)));
|
||||
}
|
||||
|
||||
@@ -2,8 +2,7 @@ use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use serde_json::Value as JsonValue;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -32,10 +31,11 @@ pub(crate) struct CellToolCall {
|
||||
pub(crate) input: Option<JsonValue>,
|
||||
}
|
||||
|
||||
/// Connects a cell actor to session-owned callbacks and lifecycle state.
|
||||
/// Connects a cell actor to session-owned callbacks and stored values.
|
||||
///
|
||||
/// Implementations must honor callback cancellation and must not return from
|
||||
/// `closed` until the session can no longer route requests to the cell.
|
||||
/// Implementations should forward callback cancellation to downstream work.
|
||||
/// Implementations must not return from `closed` until the session can no longer
|
||||
/// route requests to the cell.
|
||||
pub(crate) trait CellHost: Send + Sync + 'static {
|
||||
fn invoke_tool(
|
||||
&self,
|
||||
@@ -50,10 +50,12 @@ pub(crate) trait CellHost: Send + Sync + 'static {
|
||||
cancellation_token: CancellationToken,
|
||||
) -> impl Future<Output = Result<(), String>> + Send;
|
||||
|
||||
fn commit_stored_values(
|
||||
fn commit_completion(
|
||||
&self,
|
||||
stored_value_writes: HashMap<String, JsonValue>,
|
||||
) -> impl Future<Output = ()> + Send;
|
||||
event: CellEvent,
|
||||
cell_state: Arc<CellState>,
|
||||
) -> impl Future<Output = CompletionCommit> + Send;
|
||||
|
||||
fn closed(&self) -> impl Future<Output = ()> + Send;
|
||||
}
|
||||
@@ -61,23 +63,21 @@ pub(crate) trait CellHost: Send + Sync + 'static {
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct CellHandle {
|
||||
command_tx: mpsc::UnboundedSender<CellCommand>,
|
||||
cancellation_token: CancellationToken,
|
||||
termination_requested: Arc<AtomicBool>,
|
||||
state: Arc<CellState>,
|
||||
}
|
||||
|
||||
impl CellHandle {
|
||||
pub(super) fn new(
|
||||
command_tx: mpsc::UnboundedSender<CellCommand>,
|
||||
cancellation_token: CancellationToken,
|
||||
state: Arc<CellState>,
|
||||
) -> Self {
|
||||
Self {
|
||||
command_tx,
|
||||
cancellation_token,
|
||||
termination_requested: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
Self { command_tx, state }
|
||||
}
|
||||
|
||||
pub(crate) fn observe(&self, mode: ObserveMode) -> CellEventFuture {
|
||||
if !self.state.accepting_observations() {
|
||||
return closed_event();
|
||||
}
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
if self
|
||||
.command_tx
|
||||
@@ -90,33 +90,232 @@ impl CellHandle {
|
||||
}
|
||||
|
||||
pub(crate) fn terminate(&self) -> CellEventFuture {
|
||||
if self
|
||||
.termination_requested
|
||||
.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
|
||||
.is_err()
|
||||
{
|
||||
return Box::pin(async { Err(CellError::AlreadyTerminating) });
|
||||
}
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
if self
|
||||
.command_tx
|
||||
.send(CellCommand::Terminate {
|
||||
response_tx: Some(response_tx),
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
self.termination_requested.store(false, Ordering::Relaxed);
|
||||
return closed_event();
|
||||
}
|
||||
response_event(response_rx)
|
||||
self.state.request_termination()
|
||||
}
|
||||
|
||||
pub(crate) fn shutdown(&self) {
|
||||
self.termination_requested.store(true, Ordering::Relaxed);
|
||||
self.state.cancellation_token().cancel();
|
||||
}
|
||||
}
|
||||
|
||||
/// The single linearization point for a cell's terminal outcome.
|
||||
///
|
||||
/// Callback cancellation tokens are children of the cell token, so a terminal
|
||||
/// decision cancels runtime work and its callbacks together.
|
||||
///
|
||||
/// The mutex is held only for synchronous phase transitions and terminal
|
||||
/// delivery. Runtime execution, observation waits, and callbacks never run
|
||||
/// while it is held.
|
||||
pub(crate) struct CellState {
|
||||
// Current lifecycle phase; every method reads or replaces it under this lock.
phase: Mutex<CellPhase>,
|
||||
// Cell-wide token. The terminal transitions in this type cancel it, and
// callers can obtain a clone through `cancellation_token()`.
cancellation_token: CancellationToken,
|
||||
}
|
||||
|
||||
/// Lifecycle phase stored inside `CellState`.
enum CellPhase {
|
||||
/// Initial phase: no terminal outcome has been chosen yet.
Running,
|
||||
/// A termination request won while the cell was running.
Terminating {
|
||||
/// Sender whose receiver is held by the future returned to the terminator.
response_tx: oneshot::Sender<Result<CellEvent, CellError>>,
|
||||
},
|
||||
/// A completion was committed and is buffered until it can be delivered.
Completed(CellEvent),
|
||||
/// A termination request has already been handed the buffered completion.
CompletionClaimed(CellEvent),
|
||||
/// Final phase: observations and termination requests are answered with
/// `CellError::Closed`.
Tombstone,
|
||||
}
|
||||
|
||||
/// Outcome of `CellState::deliver_completion`.
pub(crate) enum CompletionDelivery {
|
||||
/// The buffered completion was sent to the observer.
Delivered,
|
||||
/// No observer was supplied, or its receiver was gone; the event stays buffered.
Buffered,
|
||||
/// The phase was not `Completed`; the unused observer sender is handed back.
Rejected(Option<oneshot::Sender<Result<CellEvent, CellError>>>),
|
||||
}
|
||||
|
||||
/// Result of atomically publishing a completed cell and its session side effects.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub(crate) enum CompletionCommit {
|
||||
/// The commit closure ran and the event was stored as the completed outcome.
Committed,
|
||||
/// The cell was no longer running or its token was cancelled; the event is
/// returned and the commit closure was not run.
Rejected(CellEvent),
|
||||
}
|
||||
|
||||
/// Outcome of `CellState::route_observation`.
pub(crate) enum ObservationDelivery {
|
||||
/// The cell is still running; the sender is returned to the caller unused.
Running(oneshot::Sender<Result<CellEvent, CellError>>),
|
||||
/// A buffered completion was sent to the observer.
Delivered,
|
||||
/// Sending the buffered completion failed, so it remains buffered.
Buffered,
|
||||
/// The cell is terminating, claimed, or tombstoned; the observer was sent
/// `CellError::Closed`.
Closed,
|
||||
}
|
||||
|
||||
impl CellState {
|
||||
/// Creates a state in the `Running` phase that uses `cancellation_token` as
/// the cell's token.
pub(crate) fn new(cancellation_token: CancellationToken) -> Self {
|
||||
Self {
|
||||
phase: Mutex::new(CellPhase::Running),
|
||||
cancellation_token,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` while the phase is `Running` or `Completed` and the cell
/// token has not been cancelled.
///
/// A poisoned lock is read through instead of being propagated as a panic.
pub(crate) fn accepting_observations(&self) -> bool {
|
||||
let accepting_phase = matches!(
|
||||
*self
|
||||
.phase
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner),
|
||||
CellPhase::Running | CellPhase::Completed(_)
|
||||
);
|
||||
// The temporary lock guard is dropped at the end of the statement above, so
// the token is checked without holding the phase lock.
accepting_phase && !self.cancellation_token.is_cancelled()
|
||||
}
|
||||
|
||||
/// Claims the cell's terminal outcome on behalf of a termination request.
///
/// - `Running`: moves to `Terminating`, cancels the cell token, and returns a
///   future that resolves with whatever is later sent on the stored sender
///   (or `CellError::Closed` if that sender is dropped).
/// - `Completed`: moves to `CompletionClaimed`, cancels the token, and returns
///   the buffered event immediately.
/// - `Terminating` / `CompletionClaimed`: leaves the phase as it was and returns
///   `CellError::AlreadyTerminating`.
/// - `Tombstone`: returns `CellError::Closed`.
pub(crate) fn request_termination(&self) -> CellEventFuture {
|
||||
let mut phase = self
|
||||
.phase
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
// Take the phase out by value; every arm except `Tombstone` writes a phase back.
match std::mem::replace(&mut *phase, CellPhase::Tombstone) {
|
||||
CellPhase::Running => {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
*phase = CellPhase::Terminating { response_tx };
|
||||
self.cancellation_token.cancel();
|
||||
response_event(response_rx)
|
||||
}
|
||||
CellPhase::Terminating { response_tx } => {
|
||||
*phase = CellPhase::Terminating { response_tx };
|
||||
Box::pin(async { Err(CellError::AlreadyTerminating) })
|
||||
}
|
||||
CellPhase::Completed(event) => {
|
||||
// Keep a copy in the phase so later callers can still see the completion.
*phase = CellPhase::CompletionClaimed(event.clone());
|
||||
self.cancellation_token.cancel();
|
||||
ready_event(event)
|
||||
}
|
||||
CellPhase::CompletionClaimed(event) => {
|
||||
*phase = CellPhase::CompletionClaimed(event);
|
||||
Box::pin(async { Err(CellError::AlreadyTerminating) })
|
||||
}
|
||||
CellPhase::Tombstone => closed_event(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Publishes `event` as the cell's completion together with its side effect.
///
/// If the phase is `Running` and the cell token is not cancelled, runs `commit`
/// while the phase lock is held, stores `event` as `Completed`, and returns
/// `CompletionCommit::Committed`. Otherwise returns
/// `CompletionCommit::Rejected(event)` without calling `commit`.
pub(crate) fn commit_completion(
|
||||
&self,
|
||||
event: CellEvent,
|
||||
commit: impl FnOnce(),
|
||||
) -> CompletionCommit {
|
||||
let mut phase = self
|
||||
.phase
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
if !matches!(*phase, CellPhase::Running) || self.cancellation_token.is_cancelled() {
|
||||
return CompletionCommit::Rejected(event);
|
||||
}
|
||||
// The side effect and the phase change happen under the same lock
// acquisition, so other `CellState` methods cannot observe one without the
// other.
commit();
|
||||
*phase = CellPhase::Completed(event);
|
||||
CompletionCommit::Committed
|
||||
}
|
||||
|
||||
/// Tries to hand the buffered completion to `response_tx`.
///
/// - Phase is not `Completed`: the phase is left as it was and the unused
///   sender comes back in `CompletionDelivery::Rejected`.
/// - No sender, or the receiver is gone: the event stays buffered in
///   `Completed` and `CompletionDelivery::Buffered` is returned.
/// - Send succeeds: the phase is left as `Tombstone`, the cell token is
///   cancelled, and `CompletionDelivery::Delivered` is returned.
///
/// # Panics
///
/// Panics if a failed send hands back an `Err` payload; only `Ok(event)` is
/// ever sent here, so that would indicate a logic error.
pub(crate) fn deliver_completion(
|
||||
&self,
|
||||
response_tx: Option<oneshot::Sender<Result<CellEvent, CellError>>>,
|
||||
) -> CompletionDelivery {
|
||||
let mut phase = self
|
||||
.phase
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
// Take the event out by value; any other phase is put back untouched.
let event = match std::mem::replace(&mut *phase, CellPhase::Tombstone) {
|
||||
CellPhase::Completed(event) => event,
|
||||
previous => {
|
||||
*phase = previous;
|
||||
return CompletionDelivery::Rejected(response_tx);
|
||||
}
|
||||
};
|
||||
let Some(response_tx) = response_tx else {
|
||||
*phase = CellPhase::Completed(event);
|
||||
return CompletionDelivery::Buffered;
|
||||
};
|
||||
match response_tx.send(Ok(event)) {
|
||||
Ok(()) => {
|
||||
self.cancellation_token.cancel();
|
||||
CompletionDelivery::Delivered
|
||||
}
|
||||
// A failed send returns the value, so the event can be buffered again.
Err(Ok(event)) => {
|
||||
*phase = CellPhase::Completed(event);
|
||||
CompletionDelivery::Buffered
|
||||
}
|
||||
Err(Err(error)) => {
|
||||
panic!("completion delivery unexpectedly carried an actor error: {error:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Decides what an incoming observation request should receive.
///
/// - `Running`: the sender is returned in `ObservationDelivery::Running` so the
///   caller can keep it.
/// - `Completed`: the buffered event is sent. On success the phase is left as
///   `Tombstone`, the cell token is cancelled, and `Delivered` is returned. If
///   the receiver is gone, the event is buffered again and `Buffered` is returned.
/// - `Terminating`, `CompletionClaimed`, `Tombstone`: the phase is left as it
///   was, the sender is answered with `CellError::Closed`, and `Closed` is
///   returned.
///
/// # Panics
///
/// Panics if a failed send hands back an `Err` payload; only `Ok(event)` is
/// sent on the completion path, so that would indicate a logic error.
pub(crate) fn route_observation(
|
||||
&self,
|
||||
response_tx: oneshot::Sender<Result<CellEvent, CellError>>,
|
||||
) -> ObservationDelivery {
|
||||
let mut phase = self
|
||||
.phase
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
// Take the phase out by value; arms that must not change it write it back.
match std::mem::replace(&mut *phase, CellPhase::Tombstone) {
|
||||
CellPhase::Running => {
|
||||
*phase = CellPhase::Running;
|
||||
ObservationDelivery::Running(response_tx)
|
||||
}
|
||||
CellPhase::Completed(event) => match response_tx.send(Ok(event)) {
|
||||
Ok(()) => {
|
||||
self.cancellation_token.cancel();
|
||||
ObservationDelivery::Delivered
|
||||
}
|
||||
Err(Ok(event)) => {
|
||||
*phase = CellPhase::Completed(event);
|
||||
ObservationDelivery::Buffered
|
||||
}
|
||||
Err(Err(error)) => {
|
||||
panic!("completion delivery unexpectedly carried an actor error: {error:?}")
|
||||
}
|
||||
},
|
||||
CellPhase::Terminating {
|
||||
response_tx: termination_tx,
|
||||
} => {
|
||||
*phase = CellPhase::Terminating {
|
||||
response_tx: termination_tx,
|
||||
};
|
||||
let _ = response_tx.send(Err(CellError::Closed));
|
||||
ObservationDelivery::Closed
|
||||
}
|
||||
CellPhase::CompletionClaimed(event) => {
|
||||
*phase = CellPhase::CompletionClaimed(event);
|
||||
let _ = response_tx.send(Err(CellError::Closed));
|
||||
ObservationDelivery::Closed
|
||||
}
|
||||
CellPhase::Tombstone => {
|
||||
let _ = response_tx.send(Err(CellError::Closed));
|
||||
ObservationDelivery::Closed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn finish_termination(&self, event: CellEvent) -> Option<CellEvent> {
|
||||
let mut phase = self
|
||||
.phase
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let observer_event = match std::mem::replace(&mut *phase, CellPhase::Tombstone) {
|
||||
CellPhase::Running => Some(event),
|
||||
CellPhase::Terminating { response_tx } => {
|
||||
let _ = response_tx.send(Ok(event.clone()));
|
||||
Some(event)
|
||||
}
|
||||
CellPhase::Completed(completed_event) => Some(completed_event),
|
||||
CellPhase::CompletionClaimed(completed_event) => Some(completed_event),
|
||||
CellPhase::Tombstone => None,
|
||||
};
|
||||
self.cancellation_token.cancel();
|
||||
let _ = self
|
||||
.command_tx
|
||||
.send(CellCommand::Terminate { response_tx: None });
|
||||
observer_event
|
||||
}
|
||||
|
||||
/// Forces the phase to `Tombstone`, discarding whatever phase was stored
/// (including any pending termination sender or buffered event), and cancels
/// the cell token.
pub(crate) fn tombstone(&self) {
|
||||
*self
|
||||
.phase
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = CellPhase::Tombstone;
|
||||
self.cancellation_token.cancel();
|
||||
}
|
||||
|
||||
/// Returns a clone of the cell's cancellation token.
pub(crate) fn cancellation_token(&self) -> CancellationToken {
|
||||
self.cancellation_token.clone()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,15 +324,16 @@ pub(super) enum CellCommand {
|
||||
mode: ObserveMode,
|
||||
response_tx: oneshot::Sender<Result<CellEvent, CellError>>,
|
||||
},
|
||||
Terminate {
|
||||
response_tx: Option<oneshot::Sender<Result<CellEvent, CellError>>>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Wraps a oneshot receiver as a `CellEventFuture`; if the sender is dropped
/// without sending, the future resolves to `CellError::Closed`.
fn response_event(response_rx: oneshot::Receiver<Result<CellEvent, CellError>>) -> CellEventFuture {
|
||||
Box::pin(async move { response_rx.await.unwrap_or(Err(CellError::Closed)) })
|
||||
}
|
||||
|
||||
/// Builds a `CellEventFuture` that resolves to `Ok(event)`.
fn ready_event(event: CellEvent) -> CellEventFuture {
|
||||
Box::pin(async move { Ok(event) })
|
||||
}
|
||||
|
||||
/// Builds a `CellEventFuture` that resolves to `Err(CellError::Closed)`.
fn closed_event() -> CellEventFuture {
|
||||
Box::pin(async { Err(CellError::Closed) })
|
||||
}
|
||||
|
||||
@@ -29,7 +29,9 @@ use crate::cell_actor::CellError;
|
||||
use crate::cell_actor::CellEventFuture;
|
||||
use crate::cell_actor::CellHandle;
|
||||
use crate::cell_actor::CellHost;
|
||||
use crate::cell_actor::CellState;
|
||||
use crate::cell_actor::CellToolCall;
|
||||
use crate::cell_actor::CompletionCommit;
|
||||
|
||||
type RuntimeEventFuture = Pin<Box<dyn Future<Output = Result<CellEvent, Error>> + Send + 'static>>;
|
||||
|
||||
@@ -165,9 +167,15 @@ impl<D: SessionRuntimeDelegate> SessionRuntime<D> {
|
||||
if cells.contains_key(&cell_id) {
|
||||
return Err(Error::DuplicateCell(cell_id));
|
||||
}
|
||||
let (handle, initial_event, task) =
|
||||
CellActor::prepare(request, stored_values, host, initial_observe_mode)
|
||||
.map_err(Error::Runtime)?;
|
||||
let cell_state = Arc::new(CellState::new(CancellationToken::new()));
|
||||
let (handle, initial_event, task) = CellActor::prepare(
|
||||
request,
|
||||
stored_values,
|
||||
host,
|
||||
initial_observe_mode,
|
||||
cell_state,
|
||||
)
|
||||
.map_err(Error::Runtime)?;
|
||||
cells.insert(cell_id.clone(), handle);
|
||||
drop(cells);
|
||||
tokio::spawn(task);
|
||||
@@ -251,12 +259,21 @@ impl<D: SessionRuntimeDelegate> CellHost for RuntimeCellHost<D> {
|
||||
.await
|
||||
}
|
||||
|
||||
async fn commit_stored_values(&self, stored_value_writes: HashMap<String, JsonValue>) {
|
||||
self.inner
|
||||
.stored_values
|
||||
.lock()
|
||||
.await
|
||||
.extend(stored_value_writes);
|
||||
async fn commit_completion(
|
||||
&self,
|
||||
stored_value_writes: HashMap<String, JsonValue>,
|
||||
event: CellEvent,
|
||||
cell_state: Arc<CellState>,
|
||||
) -> CompletionCommit {
|
||||
let cancellation_token = cell_state.cancellation_token();
|
||||
let mut stored_values = tokio::select! {
|
||||
biased;
|
||||
_ = cancellation_token.cancelled() => {
|
||||
return CompletionCommit::Rejected(event);
|
||||
}
|
||||
stored_values = self.inner.stored_values.lock() => stored_values,
|
||||
};
|
||||
cell_state.commit_completion(event, || stored_values.extend(stored_value_writes))
|
||||
}
|
||||
|
||||
async fn closed(&self) {
|
||||
@@ -276,3 +293,7 @@ fn actor_error(cell_id: &CellId, error: CellError) -> Error {
|
||||
CellError::Closed => Error::ClosedCell(cell_id.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests.rs"]
|
||||
mod tests;
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use std::task::Waker;
|
||||
use std::time::Duration;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value as JsonValue;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::*;
|
||||
use crate::cell_actor::CompletionCommit;
|
||||
|
||||
/// Field-less `SessionRuntimeDelegate` handed to `SessionRuntime::new` by the test
/// in this file. It holds no state.
struct RecordingDelegate;
|
||||
|
||||
impl SessionRuntimeDelegate for RecordingDelegate {
|
||||
async fn invoke_tool(
|
||||
&self,
|
||||
_invocation: NestedToolCall,
|
||||
_cancellation_token: CancellationToken,
|
||||
) -> Result<JsonValue, String> {
|
||||
Ok(JsonValue::Null)
|
||||
}
|
||||
|
||||
async fn notify(
|
||||
&self,
|
||||
_call_id: String,
|
||||
_cell_id: CellId,
|
||||
_text: String,
|
||||
_cancellation_token: CancellationToken,
|
||||
) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cell_closed(&self, _cell_id: &CellId) {}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_load_it() {
|
||||
let runtime = SessionRuntime::new(Arc::new(RecordingDelegate));
|
||||
let cell_state = Arc::new(CellState::new(CancellationToken::new()));
|
||||
let host = RuntimeCellHost {
|
||||
cell_id: CellId::new("terminating-writer"),
|
||||
inner: Arc::clone(&runtime.inner),
|
||||
};
|
||||
let completion = CellEvent::Completed {
|
||||
content_items: vec![OutputItem::Text {
|
||||
text: "uncommitted output".to_string(),
|
||||
}],
|
||||
error_text: None,
|
||||
};
|
||||
|
||||
let stored_values = runtime.inner.stored_values.lock().await;
|
||||
let commit = host.commit_completion(
|
||||
HashMap::from([(
|
||||
"candidate".to_string(),
|
||||
JsonValue::String("lost".to_string()),
|
||||
)]),
|
||||
completion.clone(),
|
||||
Arc::clone(&cell_state),
|
||||
);
|
||||
tokio::pin!(commit);
|
||||
let waker = Waker::noop();
|
||||
let mut context = Context::from_waker(waker);
|
||||
assert!(matches!(commit.as_mut().poll(&mut context), Poll::Pending));
|
||||
|
||||
let termination = cell_state.request_termination();
|
||||
drop(stored_values);
|
||||
assert_eq!(commit.await, CompletionCommit::Rejected(completion));
|
||||
let terminated = CellEvent::Terminated {
|
||||
content_items: Vec::new(),
|
||||
};
|
||||
assert_eq!(
|
||||
cell_state.finish_termination(terminated.clone()),
|
||||
Some(terminated.clone())
|
||||
);
|
||||
assert_eq!(termination.await, Ok(terminated));
|
||||
assert!(
|
||||
!runtime
|
||||
.inner
|
||||
.stored_values
|
||||
.lock()
|
||||
.await
|
||||
.contains_key("candidate")
|
||||
);
|
||||
|
||||
let reader = runtime
|
||||
.execute(
|
||||
CreateCellRequest {
|
||||
tool_call_id: "reader".to_string(),
|
||||
enabled_tools: Vec::new(),
|
||||
source: r#"text(String(load("candidate")));"#.to_string(),
|
||||
},
|
||||
ObserveMode::YieldAfter(Duration::from_secs(1)),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
reader.initial_event().await,
|
||||
Ok(CellEvent::Completed {
|
||||
content_items: vec![OutputItem::Text {
|
||||
text: "undefined".to_string(),
|
||||
}],
|
||||
error_text: None,
|
||||
})
|
||||
);
|
||||
runtime.shutdown().await.unwrap();
|
||||
}
|
||||
Reference in New Issue
Block a user