mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
code-mode: make session shutdown authoritative (#29287)
## Summary

- Give each session and cell a hierarchical cancellation token.
- Track cell tasks so shutdown waits for admitted actors without polling the registry.
- Make shutdown authoritative across concurrent admission and non-cooperative callbacks.

## Why

A best-effort registry scan can miss cells admitted concurrently or blocked behind the registry lock.

## Impact

Session shutdown reliably stops every admitted cell and rejects new work once shutdown begins.

## Validation

- Stack-tip validation: `just test -p codex-code-mode -p codex-code-mode-protocol` (70 passed).
- Parent branch: `cconger/code-mode-runtime-compact-03c-terminal-state`.
This commit is contained in:
committed by
GitHub
Unverified
parent
f774455c3a
commit
9c79d87d06
@@ -92,16 +92,13 @@ impl CellHandle {
|
||||
/// Requests termination of this cell by delegating to the shared cell state.
///
/// Returns the `CellEventFuture` produced by `request_termination`.
/// NOTE(review): presumably that future resolves to the cell's terminal
/// event; `request_termination` is not visible in this hunk, so confirm.
pub(crate) fn terminate(&self) -> CellEventFuture {
|
||||
self.state.request_termination()
|
||||
}
|
||||
|
||||
pub(crate) fn shutdown(&self) {
|
||||
self.state.cancellation_token().cancel();
|
||||
}
|
||||
}
|
||||
|
||||
/// The single linearization point for a cell's terminal outcome.
|
||||
///
|
||||
/// Callback cancellation tokens are children of the cell token, so a terminal
|
||||
/// decision cancels runtime work and its callbacks together.
|
||||
/// The cancellation token is a child of the owning session token. Callback
|
||||
/// tokens are children of this token, so cancellation flows strictly from the
|
||||
/// session to the cell and then to its callbacks.
|
||||
///
|
||||
/// The mutex is held only for synchronous phase transitions and terminal
|
||||
/// delivery. Runtime execution, observation waits, and callbacks never run
|
||||
|
||||
@@ -420,7 +420,7 @@ await new Promise(() => {});
|
||||
.send(RuntimeCommand::TimeoutFired { id: 1 })
|
||||
.unwrap();
|
||||
assert!(
|
||||
tokio::time::timeout(Duration::from_millis(100), event_rx.recv())
|
||||
tokio::time::timeout(Duration::from_secs(1), event_rx.recv())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
@@ -289,7 +289,10 @@ async fn observed_natural_completion_wins_over_termination() {
|
||||
tokio::time::timeout(Duration::from_secs(1), async {
|
||||
loop {
|
||||
let response = service
|
||||
.execute(execute_request(r#"text(String(load("finished")));"#))
|
||||
.execute(ExecuteRequest {
|
||||
yield_time_ms: Some(60_000),
|
||||
..execute_request(r#"text(String(load("finished")));"#)
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
.initial_response()
|
||||
|
||||
@@ -360,7 +360,7 @@ await Promise.all([
|
||||
}
|
||||
);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(1100)).await;
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
|
||||
let resumed_response = tokio::time::timeout(
|
||||
Duration::from_secs(1),
|
||||
|
||||
@@ -4,13 +4,13 @@ use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use serde_json::Value as JsonValue;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
pub(crate) use self::types::CellEvent;
|
||||
pub(crate) use self::types::CellId;
|
||||
@@ -43,8 +43,9 @@ pub(crate) struct SessionRuntime<D: SessionRuntimeDelegate> {
|
||||
struct Inner<D: SessionRuntimeDelegate> {
|
||||
stored_values: Mutex<HashMap<String, JsonValue>>,
|
||||
cells: Mutex<HashMap<CellId, CellHandle>>,
|
||||
cell_tasks: TaskTracker,
|
||||
shutdown_token: CancellationToken,
|
||||
delegate: Arc<D>,
|
||||
shutting_down: AtomicBool,
|
||||
next_cell_id: AtomicU64,
|
||||
}
|
||||
|
||||
@@ -54,15 +55,16 @@ impl<D: SessionRuntimeDelegate> SessionRuntime<D> {
|
||||
inner: Arc::new(Inner {
|
||||
stored_values: Mutex::new(HashMap::new()),
|
||||
cells: Mutex::new(HashMap::new()),
|
||||
cell_tasks: TaskTracker::new(),
|
||||
shutdown_token: CancellationToken::new(),
|
||||
delegate,
|
||||
shutting_down: AtomicBool::new(false),
|
||||
next_cell_id: AtomicU64::new(1),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_alive(&self) -> bool {
|
||||
!self.inner.shutting_down.load(Ordering::Acquire)
|
||||
!self.inner.shutdown_token.is_cancelled()
|
||||
}
|
||||
|
||||
pub(crate) async fn execute(
|
||||
@@ -70,6 +72,9 @@ impl<D: SessionRuntimeDelegate> SessionRuntime<D> {
|
||||
request: CreateCellRequest,
|
||||
initial_observe_mode: ObserveMode,
|
||||
) -> Result<StartedCell, Error> {
|
||||
if self.inner.shutdown_token.is_cancelled() {
|
||||
return Err(Error::ShuttingDown);
|
||||
}
|
||||
let cell_id = self.allocate_cell_id();
|
||||
let initial_event = self
|
||||
.start_cell(cell_id.clone(), request, initial_observe_mode)
|
||||
@@ -122,21 +127,13 @@ impl<D: SessionRuntimeDelegate> SessionRuntime<D> {
|
||||
}
|
||||
|
||||
pub(crate) async fn shutdown(&self) -> Result<(), Error> {
|
||||
self.inner.shutting_down.store(true, Ordering::Release);
|
||||
let handles = self
|
||||
.inner
|
||||
.cells
|
||||
.lock()
|
||||
.await
|
||||
.values()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
for handle in handles {
|
||||
handle.shutdown();
|
||||
}
|
||||
while !self.inner.cells.lock().await.is_empty() {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
self.begin_shutdown();
|
||||
// Taking the registry lock ensures every cell that passed the shutdown
|
||||
// check has registered its actor with the tracker before we wait.
|
||||
let cells = self.inner.cells.lock().await;
|
||||
self.inner.cell_tasks.close();
|
||||
drop(cells);
|
||||
self.inner.cell_tasks.wait().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -161,13 +158,13 @@ impl<D: SessionRuntimeDelegate> SessionRuntime<D> {
|
||||
inner: Arc::clone(&self.inner),
|
||||
});
|
||||
let mut cells = self.inner.cells.lock().await;
|
||||
if self.inner.shutting_down.load(Ordering::Acquire) {
|
||||
if self.inner.shutdown_token.is_cancelled() {
|
||||
return Err(Error::ShuttingDown);
|
||||
}
|
||||
if cells.contains_key(&cell_id) {
|
||||
return Err(Error::DuplicateCell(cell_id));
|
||||
}
|
||||
let cell_state = Arc::new(CellState::new(CancellationToken::new()));
|
||||
let cell_state = Arc::new(CellState::new(self.inner.shutdown_token.child_token()));
|
||||
let (handle, initial_event, task) = CellActor::prepare(
|
||||
request,
|
||||
stored_values,
|
||||
@@ -177,18 +174,14 @@ impl<D: SessionRuntimeDelegate> SessionRuntime<D> {
|
||||
)
|
||||
.map_err(Error::Runtime)?;
|
||||
cells.insert(cell_id.clone(), handle);
|
||||
self.inner.cell_tasks.spawn(task);
|
||||
drop(cells);
|
||||
tokio::spawn(task);
|
||||
Ok(map_actor_event(cell_id, initial_event))
|
||||
}
|
||||
|
||||
fn begin_shutdown(&self) {
|
||||
self.inner.shutting_down.store(true, Ordering::Release);
|
||||
if let Ok(cells) = self.inner.cells.try_lock() {
|
||||
for handle in cells.values() {
|
||||
handle.shutdown();
|
||||
}
|
||||
}
|
||||
self.inner.shutdown_token.cancel();
|
||||
self.inner.cell_tasks.close();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -108,3 +108,80 @@ async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_loa
|
||||
);
|
||||
runtime.shutdown().await.unwrap();
|
||||
}
|
||||
|
||||
/// Test helper: builds a `CreateCellRequest` that runs `source`.
///
/// The request always uses the fixed tool call id `"call-1"` and an empty
/// `enabled_tools` list; only the source text varies between callers.
fn execute_request(source: &str) -> CreateCellRequest {
|
||||
CreateCellRequest {
|
||||
tool_call_id: "call-1".to_string(),
|
||||
enabled_tools: Vec::new(),
|
||||
source: source.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Verifies that a cell admission already queued when shutdown begins is
/// rejected with `Error::ShuttingDown` instead of being started.
///
/// The test holds the `cells` registry lock itself, polls `execute` and then
/// `shutdown` once each to confirm both are still pending, and only then
/// releases the lock. After release, `execute` must fail with
/// `Error::ShuttingDown` and `shutdown` must complete with `Ok(())`.
#[tokio::test]
|
||||
#[expect(
|
||||
clippy::await_holding_invalid_type,
|
||||
reason = "test holds the registry lock to force admission ahead of shutdown"
|
||||
)]
|
||||
async fn shutdown_rejects_cell_admission_queued_before_the_registry_lock() {
|
||||
let runtime = Arc::new(SessionRuntime::new(Arc::new(RecordingDelegate)));
|
||||
// Hold the registry lock so neither admission nor shutdown can pass it.
let cells = runtime.inner.cells.lock().await;
|
||||
|
||||
let execution = runtime.execute(
|
||||
execute_request("while (true) {}"),
|
||||
ObserveMode::YieldAfter(Duration::from_millis(/*millis*/ 1)),
|
||||
);
|
||||
tokio::pin!(execution);
|
||||
// Poll `execute` exactly once: it must still be pending while the registry
// lock is held. Any ready result here is a test failure.
std::future::poll_fn(|context| match execution.as_mut().poll(context) {
|
||||
Poll::Pending => Poll::Ready(()),
|
||||
Poll::Ready(Ok(_)) => panic!("execution completed before the registry lock was released"),
|
||||
Poll::Ready(Err(error)) => {
|
||||
panic!("execution failed before the registry lock was released: {error}")
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
let shutdown = runtime.shutdown();
|
||||
tokio::pin!(shutdown);
|
||||
// Poll `shutdown` exactly once as well: it must also be pending while the
// registry lock is held.
std::future::poll_fn(|context| match shutdown.as_mut().poll(context) {
|
||||
Poll::Pending => Poll::Ready(()),
|
||||
Poll::Ready(Ok(())) => panic!("shutdown completed before acquiring the registry lock"),
|
||||
Poll::Ready(Err(error)) => {
|
||||
panic!("shutdown failed before acquiring the registry lock: {error}")
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
// The runtime must already report not-alive after that single poll, even
// though `shutdown` has not finished.
assert!(!runtime.is_alive());
|
||||
drop(cells);
|
||||
// With the lock released, the queued admission is rejected and shutdown
// finishes successfully.
assert!(matches!(execution.await, Err(Error::ShuttingDown)));
|
||||
assert_eq!(shutdown.await, Ok(()));
|
||||
}
|
||||
|
||||
/// Verifies that dropping the `SessionRuntime` while the registry lock is
/// held still lets the running cell's task finish.
///
/// A cell running an infinite loop is started and observed to yield. The test
/// then keeps a clone of `inner`, takes the `cells` lock, drops the runtime
/// while the lock is held, and releases the lock. The tracked cell tasks must
/// complete within one second and the tracker must end up empty.
/// NOTE(review): presumably a `Drop` impl on `SessionRuntime` triggers the
/// shutdown path; that impl is not visible in this hunk, so confirm.
#[tokio::test]
|
||||
async fn drop_terminates_cells_when_the_registry_is_locked() {
|
||||
let runtime = SessionRuntime::new(Arc::new(RecordingDelegate));
|
||||
let started = runtime
|
||||
.execute(
|
||||
execute_request("while (true) {}"),
|
||||
ObserveMode::YieldAfter(Duration::from_millis(/*millis*/ 1)),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(started.cell_id, CellId::new("1"));
|
||||
assert_eq!(
|
||||
started.initial_event().await,
|
||||
Ok(CellEvent::Yielded {
|
||||
content_items: Vec::new(),
|
||||
})
|
||||
);
|
||||
|
||||
// Keep `inner` alive past the runtime so the task tracker can still be
// inspected after the drop.
let inner = Arc::clone(&runtime.inner);
|
||||
// Take the registry lock first, then drop the runtime while it is held.
let cells = inner.cells.lock().await;
|
||||
drop(runtime);
|
||||
drop(cells);
|
||||
|
||||
// The never-ending cell must have been stopped: waiting on the tracker has
// to finish within the timeout.
tokio::time::timeout(Duration::from_secs(/*secs*/ 1), inner.cell_tasks.wait())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(inner.cell_tasks.is_empty());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user